# U-Net Model

## Import Libraries, Modules, & Scripts

In [1]:
import torch
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from unet import UNet
from dataset import dataset

## Set Hyperparameters and Device

In [2]:
LEARNING_RATE = 3e-4
BATCH_SIZE = 16
EPOCHS = 1
DATA_PATH = '25-ds-casia-un-tvt'
MODEL_SAVE_PATH = 'un-model/ds4-un-model.pth'
RESUME_TRAINING = False

### Select Device

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


## Create Datasets and DataLoaders

In [4]:
train_dataset = dataset(root_path=DATA_PATH, split='train')
val_dataset = dataset(root_path=DATA_PATH, split='validation')
test_dataset = dataset(root_path=DATA_PATH, split='test')

In [5]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Initialize Model, Optimizer, and Loss Function

### Class Weighting

In [6]:
total_px = 0
total_tp_px = 0

In [7]:
for img, mask in train_dataloader:
    total_px += mask.nelement()
    total_tp_px += mask.sum().item()

# Calculate the ratio
au_px = total_px - total_tp_px
pos_weight = au_px / total_tp_px

In [8]:
print(f"Total pixels: {total_px}")
print(f"Total tampered pixels: {total_tp_px}")
print(f"Calculated pos_weight: {pos_weight}")

Total pixels: 58851328
Total tampered pixels: 6278448.774902344
Calculated pos_weight: 8.373545936259612


In [9]:
pos_weight_tensor = torch.tensor([pos_weight]).to(device)

In [10]:
model = UNet(in_channels=3, num_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

## If Resume Training, then Load Weights

In [11]:
if RESUME_TRAINING:
    try:
        model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        print("Loaded model weights from checkpoint.")
    except FileNotFoundError:
        print("No checkpoint found. Starting training from scratch.")

## Training Loop

### Training Loop

In [12]:
for epoch in range(EPOCHS):
    model.train()
    train_running_loss = 0
    train_total = 0
    train_correct = 0
    
    for idx, img_mask in enumerate(tqdm(train_dataloader)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].float().to(device)

        y_pred = model(img)
        optimizer.zero_grad()

        loss = criterion(y_pred, mask)
        train_running_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Calculate accuracy for binary segmentation
        predicted = (torch.sigmoid(y_pred) > 0.5).float()
        train_total += mask.nelement()
        train_correct += (predicted == mask).sum().item()

    train_loss = train_running_loss / (idx + 1)
    train_accuracy = 100 * train_correct / train_total 

    model.eval()
    val_running_loss = 0
    val_total = 0
    val_correct = 0
    
    with torch.no_grad():
        for idx, img_mask in enumerate(tqdm(val_dataloader)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)

            y_pred = model(img)
            loss = criterion(y_pred, mask)

            val_running_loss += loss.item()

            # Calculate validation accuracy for binary segmentation
            predicted = (torch.sigmoid(y_pred) > 0.5).float()
            val_total += mask.nelement()
            val_correct += (predicted == mask).sum().item()

        val_loss = val_running_loss / (idx + 1)
        val_accuracy = 100 * val_correct / val_total

    print("-" * 70)
    print(f"EPOCH {epoch + 1}")
    print(f"Training Accuracy: {train_accuracy:.4f}% | Training Loss: {train_loss:.4f}")
    print(f"Validation Accuracy: {val_accuracy:.4f}% | Validation Loss: {val_loss:.4f}")
    print("-" * 70)

100%|██████████████████████████████████████████████████████████████████████████████████| 57/57 [09:30<00:00, 10.02s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:55<00:00,  3.27s/it]

----------------------------------------------------------------------
EPOCH 1
Training Accuracy: 53.3512% | Training Loss: 1.2183
Validation Accuracy: 63.8073% | Validation Loss: 1.0550
----------------------------------------------------------------------





### Training Loop Curves

### Save the Model

In [13]:
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

Model saved to un-model/ds4-un-model.pth


## Test Loop

In [10]:
model.eval()
test_running_loss = 0
with torch.no_grad():
    for idx, img_mask in enumerate(tqdm(test_dataloader)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].float().to(device)

        y_pred = model(img)
        loss = criterion(y_pred, mask)

        test_running_loss += loss.item()

    test_loss = test_running_loss / (idx + 1)
    print(f"Test Loss: {test_loss:.4f}")

  0%|                                                                                            | 0/8 [00:05<?, ?it/s]


KeyboardInterrupt: 

### Test Loop Curves