In [1]:
# CS4287-Assign1-22352228-22340343.ipynb
# Names and IDs: Cormac Greaney - 22352228, Jan Lawinski - 22340343
# Date: October 2025
# Description: Image Segmentation on Oxford-IIIT Pet Dataset using VGG16-based U-Net
# Code runs to completion: Yes
# Reused source: https://pytorch.org/vision/stable/models.html (for pretrained VGG16)


# ==============================================================
# Here we have all of the imports used for our project
# ==============================================================
import torch, torchvision
from torchvision.transforms import functional as TF
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.datasets.utils import download_and_extract_archive
from PIL import Image
import numpy as np, os, matplotlib.pyplot as plt
from sklearn.model_selection import KFold


# ======================================================================
# Here we check for my GPU so that we can drastically reduce the runtime 
# but will revert to the cpu if its run on a machine without a GPU
# ======================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ==============================================================
# Here we download the Oxford-IIIT Pet datasets images and masks
# ==============================================================
root = "./data/oxford_pets/"
download_and_extract_archive(
    url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
    download_root=root, extract_root=root)
download_and_extract_archive(
    url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
    download_root=root, extract_root=root)


# =======================================================================
# Here we create our custom dataset class for the Oxford-IIIT Pet Dataset
# =======================================================================
class OxfordPetSegmentation(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(".jpg")]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, "trimaps", img_name.replace(".jpg", ".png"))

        # === Here we load the image and mask ===
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        # === Here we resize the image and mask === 
        image = TF.resize(image, (256, 256), interpolation=Image.BILINEAR)
        mask = TF.resize(mask, (256, 256), interpolation=Image.NEAREST)

        # === Here we convert image to tensor and normalize the channels the match our pretrained VGG16 ===
        image = TF.to_tensor(image)
        image = TF.normalize(image, [0.485, 0.456, 0.406],
                                   [0.229, 0.224, 0.225])
        
        # === Here we convert mask to a NumPy array ===
        mask = np.array(mask)

        # === Here we adjust the mask labels to be 0,1,2 ===
        mask = np.clip(mask, 1, 3) - 1

        # === Here we convert to a long tensor ===
        mask = torch.from_numpy(mask).long()

        # === Here we return a tuple for training ===
        return image, mask


# === Transforms - not using atm ===
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


# ==============================================================
# Here we instantiate our dataset
# ==============================================================
dataset = OxfordPetSegmentation(
    img_dir=os.path.join(root, "images"),
    mask_dir=os.path.join(root, "annotations"),
    transform=transform
)


# ===================================================================
# Here we create our dataloader and wrap our dataset
# we will use a batch size of 4 for training with shuffling turned on
# ===================================================================
loader = DataLoader(dataset, batch_size=4, shuffle=True)


# ==============================================================
# Here we pull a sample batch to verify everything is working
# we print the shape of the batch as a sanity check
# ==============================================================
print("Dataset loaded, sample batch shape:", next(iter(loader))[0].shape)


# ==============================================================
# Here we define our VGG16-based U-Net model
# ==============================================================
class UNetVGG16(nn.Module):

    # === Here we define the cunstructor with 3 classes, background, pet, border ===
    def __init__(self, n_classes=3):
        super().__init__()

        # === Here we load the pretrained VGG16 model ===
        vgg = models.vgg16_bn(weights=models.VGG16_BN_Weights.IMAGENET1K_V1)

        # === Here we extract the features layers into a python list so it can be sliced ===
        features = list(vgg.features.children())


        # =============================================================================================
        # Here we create slices of VGG16 for encoder path, following the standard split we found online
        # https://medium.com/@mygreatlearning/everything-you-need-to-know-about-vgg16-7315defb5918
        # ============================================================================================= 
        self.enc1 = nn.Sequential(*features[:6])    # 64 filters
        self.enc2 = nn.Sequential(*features[6:13])  # 128 filters
        self.enc3 = nn.Sequential(*features[13:23]) # 256 filters
        self.enc4 = nn.Sequential(*features[23:33]) # 512 filters
        self.center = nn.Sequential(*features[33:43]) # 512 bottleneck


        # ===============================================================
        # Here we define the decoder path with upsampling and conv layers
        # ===============================================================
        self.up4 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.dec4 = nn.Sequential(
            nn.Conv2d(512 + 512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(256 + 256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128 + 128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.final = nn.Conv2d(64, n_classes, kernel_size=1)


    # ===============================================================
    # Here we define our forward pass
    # ===============================================================
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        c = self.center(e4)

        d4 = self.up4(c)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)


# ==========================================================================
# Here we instantiate our model and move it to the device we checked earlier
# Hopefully a GPU!
# ==========================================================================
model = UNetVGG16().to(device)


# ===============================================================================
# Here we print the number of trainable parameters in the model as a sanity check
# ===============================================================================
print(f"Model params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


# ================================================================================
# Here we define our loss function - CrossEntropyLoss for multi-class segmentation
# ================================================================================
criterion = nn.CrossEntropyLoss()


# ==========================================================================================
# Here we use the more modern AdamW optimizer with decoupled weight decay for regularization
# We set a learning rate of 3e-4 and weight decay of 1e-4
# https://www.datacamp.com/tutorial/adamw-optimizer-in-pytorch
# ==========================================================================================
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)


# ==============================================================================================
# Here we define a learning rate scheduler to reduce our learning rate
# it will be halved when a monitored metric stops showing improvement within our patience period
# ==============================================================================================
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)


# ==============================================================
# Here we define a simple training loop for one epoch
# ==============================================================
def train_one_epoch(model, loader, criterion, optimizer):

    # === Here we set the model to training mode ===
    model.train()

    # === Here we initialize total loss for the epoch ===
    total_loss = 0

    # === Here we loop over the data loader ===
    for imgs, masks in loader:

        # === Here we move the images and masks to the device for ===
        imgs, masks = imgs.to(device), masks.to(device)

        # === Here we zero the gradients, perform a forward pass, compute the loss, perform a backward pass, and update weights ===
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # === Here we add the loss to the total loss ===
        total_loss += loss.item()

    # === Here we return the average loss for the epoch ===    
    return total_loss / len(loader)


# ===========================================================================
# Here we define our helper functions for evaluating the segmentation quality
# ===========================================================================
def compute_iou(preds, labels, num_classes=3):
    """Compute mean Intersection over Union (mIoU) across classes."""
    ious = []
    preds = torch.argmax(preds, dim=1)  # convert from logits to class IDs
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        intersection = (pred_inds & target_inds).float().sum()
        union = (pred_inds | target_inds).float().sum()
        if union == 0:
            ious.append(torch.tensor(float('nan')))  # ignore empty classes
        else:
            ious.append(intersection / union)
    return torch.tensor(ious).nanmean().item()


def compute_dice(preds, labels, num_classes=3):
    """Compute mean Dice coefficient across classes."""
    dice_scores = []
    preds = torch.argmax(preds, dim=1)
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        intersection = (pred_inds & target_inds).float().sum()
        dice = (2. * intersection) / (pred_inds.float().sum() + target_inds.float().sum() + 1e-8)
        dice_scores.append(dice)
    return torch.tensor(dice_scores).nanmean().item()


# ==============================================================
# Here we define our validation loop to measure model performance
# ==============================================================
def validate(model, loader, criterion):
    model.eval()  # set to evaluation mode
    total_loss = 0
    total_iou = 0
    total_dice = 0
    count = 0

    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            total_iou += compute_iou(outputs, masks)
            total_dice += compute_dice(outputs, masks)
            count += 1

    avg_loss = total_loss / count
    avg_iou = total_iou / count
    avg_dice = total_dice / count

    return avg_loss, avg_iou, avg_dice


# ===============================================================
# Here we run a demo of our one epoch function and store the loss
# ===============================================================
demo_loss = train_one_epoch(model, loader, criterion, optimizer)


# ================================================================
# Here we run a validation demo to check IoU and Dice calculations
# ================================================================
val_loss, val_iou, val_dice = validate(model, loader, criterion)
print(f"Validation → Loss: {val_loss:.4f}, mIoU: {val_iou:.4f}, Dice: {val_dice:.4f}")


# =============================================================
# Here we print the demos loss to 4 decimal places
# =============================================================
print(f"One-epoch demo loss: {demo_loss:.4f}")


# ==============================================================
# Next Steps and TODOs
# ==============================================================
# TODO:
# - Implement validation loop computing IoU/Dice
# - Add K-Fold CV using sklearn.model_selection.KFold
# - Add ablation study cells: change LR, dropout, augmentation
# - Save sample predictions and plot alongside ground-truth




Using device: cuda
Using downloaded and verified file: ./data/oxford_pets/images.tar.gz
Extracting ./data/oxford_pets/images.tar.gz to ./data/oxford_pets/
Using downloaded and verified file: ./data/oxford_pets/annotations.tar.gz
Extracting ./data/oxford_pets/annotations.tar.gz to ./data/oxford_pets/
Dataset loaded, sample batch shape: torch.Size([4, 3, 256, 256])
Model params: 25,867,075
Validation → Loss: 0.2248, mIoU: 0.7727, Dice: 0.8605
One-epoch demo loss: 0.3032
