# Object Segmentation Training Workshop
### By: Aaron Gabrielle C. Dichoso
### From: DLSU - Center of Imaging and Visual Innovations (CIVI)
May 27, 2025

## 1. Importing Libraries

In [None]:
import torch
import torchvision
from torchvision.datasets import CocoDetection
from torchvision.models.segmentation import fcn_resnet50
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask

## 2. Notebook configurations
Batch Size: The number of images passed to the model in a single forward pass during training.

Classes: A subset of 12 object classes from the standard MSCOCO categories, plus 1 background class:

["dog", "cat", "person", "chair", "mouse", "remote", "keyboard", "cell phone", "cup", "fork", "knife", "spoon"]

Epochs: The number of complete passes over the training images during training.

Learning Rate: Controls the size of the weight adjustments the optimizer applies to the model at each training step.

Device: Use CUDA if available, else use the CPU

In [None]:
# Object classes the model is trained on: a 12-class subset of the MSCOCO categories.
# This list is the single source of truth; NUM_CLASSES is derived from it so the
# two cannot drift apart when classes are added or removed.
TARGET_CLASSES = [
    "dog", "cat", "person", "chair", "mouse", "remote",
    "keyboard", "cell phone", "cup", "fork", "knife", "spoon",
]

BATCH_SIZE = 4                          # images per forward pass
NUM_CLASSES = len(TARGET_CLASSES) + 1   # object classes + 1 background class
NUM_EPOCHS = 45                         # complete passes over the training set
LEARNING_RATE = 1e-3                    # strength of each optimizer update
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when present

# Dataset layout: an image directory and an annotation JSON file per split.
COCO_PATH = './dataset/'
TRAIN_IMG_DIR = os.path.join(COCO_PATH, 'coco_sample/train')
TEST_IMG_DIR = os.path.join(COCO_PATH, 'coco_sample/test')
VAL_IMG_DIR = os.path.join(COCO_PATH, 'coco_sample/val')
TRAIN_ANN_FILE = os.path.join(COCO_PATH, 'coco_sample/train.json')
TEST_ANN_FILE = os.path.join(COCO_PATH, 'coco_sample/test.json')
VAL_ANN_FILE = os.path.join(COCO_PATH, 'coco_sample/val.json')

## 3. Dataset Loader

A custom dataset loader is used in this notebook to adapt the original MSCOCO dataset:

1. Instead of the 91 classes in MSCOCO, only 13 classes are used (12 object classes plus background)
2. It lets you apply transformations to each image as it is loaded

In [None]:
class CocoSegmentation(torch.utils.data.Dataset):
    """Semantic-segmentation view of a COCO-format dataset (workshop exercise).

    Participants implement this class. Based on how it is used later in the
    notebook, a finished implementation is expected to:

    * return ``(image, mask)`` from ``__getitem__``, where ``image`` is a
      channels-first tensor (it is displayed with ``image.permute(1, 2, 0)``)
      and ``mask`` holds one class index per pixel;
    * expose ``class_names``, a mapping from label index to class name
      (it is queried with ``class_names.get(label, "N/A")``);
    * return the number of samples from ``__len__``.

    NOTE(review): the mask dtype, image size and label numbering are not fixed
    anywhere in this notebook, so decide them when implementing. Presumably all
    samples need a common size, because the default DataLoader collation stacks
    them into one batch tensor. Confirm this against your transforms.

    Until implemented, the data-access methods raise ``NotImplementedError``.
    The previous bare stubs silently returned ``None``, which only failed later
    with confusing errors inside ``len()`` or the DataLoader.

    Args:
        root: Directory holding the split's images (presumably one of the
            ``*_IMG_DIR`` paths).
        annFile: Path to the split's annotation JSON file.
    """

    def __init__(self, root, annFile):
        # TODO(workshop): load the annotations (e.g. with pycocotools' COCO),
        # build the list of samples and the label -> class-name mapping.
        self.root = root
        self.ann_file = annFile

    def __getitem__(self, idx):
        raise NotImplementedError(
            "CocoSegmentation.__getitem__ is a workshop exercise: "
            "return (image, mask) for sample idx."
        )

    def __len__(self):
        raise NotImplementedError(
            "CocoSegmentation.__len__ is a workshop exercise: "
            "return the number of samples."
        )

## 4. Load Datasets

In [None]:
# Workshop placeholders: build the three dataset splits here ...
train_dataset = test_dataset = val_dataset = None

# ... and wrap each split in a DataLoader.
train_loader = val_loader = test_loader = None

## 5. FCN Model Configurations

Load the FCN model (ResNet-50 backbone) from torchvision and configure its output layer for the custom number of classes.

In [None]:
# Build FCN-ResNet50 with a segmentation head sized for NUM_CLASSES.
# `weights=None` is the torchvision >= 0.13 spelling of the deprecated
# `pretrained=False`: no pretrained segmentation weights are loaded.
# NOTE(review): per the torchvision docs, `weights_backbone` keeps its default
# (ImageNet weights for the ResNet-50 backbone). Confirm for your installed version.
model = fcn_resnet50(weights=None, num_classes=NUM_CLASSES)
model.to(DEVICE)

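Criterion: The loss function that measures how far the model's predictions are from the ground-truth masks; training tries to minimize it.

Optimizer: The algorithm that updates the model's weights using the gradients of the loss.

Scheduler: Adjusts (typically decays) the learning rate as the epochs progress.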

In [None]:
# Workshop placeholders: define the loss function, the optimizer and the
# learning-rate scheduler here (all three are left unset in this template).
criterion = optimizer = scheduler = None

## 6. Training Loop

General Flow is as follows:
1. Set the model into training mode
2. For each epoch, do the following (steps a–e are repeated for every batch of training data):
    a. Load the images and masks to the device
    b. Zero out the existing gradients
    c. Perform a forward pass
    d. Compute the prediction performance (loss)
    e. Perform a backward pass (backpropagation) and update the weights
    f. Once the epoch's batches are done, save the model checkpoint

In [None]:
from tqdm import tqdm
import os
import torch

def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs and save checkpoints.

    Each epoch makes one pass over ``train_loader``. For every batch it runs a
    forward pass, computes the loss, backpropagates and steps the optimizer.
    After the pass it optionally measures the loss on ``val_loader``, steps
    ``scheduler`` once and writes checkpoints to ``checkpoints/``:

    * ``checkpoints/last_model.pth``: weights after the most recent epoch.
    * ``checkpoints/best_model.pth``: weights with the lowest validation loss
      seen so far (only written when ``val_loader`` yields batches).

    Args:
        model: Segmentation model whose forward pass returns a dict with an
            ``'out'`` entry holding per-class scores (as used elsewhere in this
            notebook).
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches, or
            ``None`` to skip validation.
        criterion: Loss function, called as ``criterion(outputs, masks)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
            ``ReduceLROnPlateau`` is stepped with the validation loss (or the
            training loss when there is no validation data).
        num_epochs: Number of passes over the training data.
        device: Device the batches are moved to; should match the model's.
    """
    os.makedirs("checkpoints", exist_ok=True)  # Create a directory to save checkpoints

    best_val_loss = float('inf')

    for epoch in range(num_epochs):  # For each epoch
        # Re-enter training mode every epoch: the validation helper switches to eval mode.
        model.train()

        running_loss, num_batches = 0.0, 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device)  # Load Images and Masks to Device

            optimizer.zero_grad()             # Zero out gradients left from the previous step
            outputs = model(images)['out']    # Forward pass
            loss = criterion(outputs, masks)  # Prediction performance / loss
            loss.backward()                   # Backward pass (backpropagation)
            optimizer.step()                  # Update weights

            running_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        train_loss = running_loss / max(num_batches, 1)
        val_loss = _validation_loss(model, val_loader, criterion, device)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # This scheduler needs a metric to decide when to decay the LR.
                scheduler.step(val_loss if val_loss is not None else train_loss)
            else:
                scheduler.step()

        # Save the model checkpoint (latest, and best-so-far on validation loss).
        torch.save(model.state_dict(), os.path.join("checkpoints", "last_model.pth"))
        if val_loss is not None and val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join("checkpoints", "best_model.pth"))

        summary = f"Epoch {epoch + 1}/{num_epochs} - train loss: {train_loss:.4f}"
        if val_loss is not None:
            summary += f" - val loss: {val_loss:.4f}"
        print(summary)


def _validation_loss(model, val_loader, criterion, device):
    """Return the mean per-batch loss on ``val_loader``, or None when there is nothing to evaluate."""
    if val_loader is None:
        return None
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            total_loss += criterion(model(images)['out'], masks).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else None


In [None]:
# Launch training with the components and settings configured above
# (positional order follows train's signature).
train(
    model, train_loader, val_loader,   # model and data
    criterion, optimizer, scheduler,   # optimisation components
    NUM_EPOCHS, DEVICE,                # run settings
)

## 7. Evaluation Metrics

In [None]:
def pixel_accuracy(preds, labels):
    """Return the fraction of pixels whose predicted class equals the label.

    Args:
        preds: Predicted class indices (torch tensor or NumPy array).
        labels: Ground-truth class indices, same shape as ``preds``.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float, or ``nan`` for empty input.
    """
    matches = preds == labels
    # torch tensors expose numel(); NumPy arrays expose the integer attribute .size.
    total = matches.numel() if hasattr(matches, "numel") else matches.size
    if total == 0:
        return float('nan')
    return float(matches.sum()) / total

def mean_iou(preds, labels, num_classes):
    """Return the mean intersection-over-union across classes.

    For each class ``c`` in ``range(num_classes)``, IoU is
    ``|pred == c AND label == c| / |pred == c OR label == c|``. Classes that
    appear in neither ``preds`` nor ``labels`` are skipped rather than counted
    as 0 or 1, so they do not distort the mean.

    Args:
        preds: Predicted class indices (torch tensor or NumPy array).
        labels: Ground-truth class indices, same shape as ``preds``.
        num_classes: Number of classes to score (labels ``0 .. num_classes-1``).

    Returns:
        Mean IoU as a Python float, or ``nan`` if no class is present.
    """
    class_ious = []
    for cls in range(num_classes):
        pred_mask = preds == cls
        label_mask = labels == cls
        union = int((pred_mask | label_mask).sum())
        if union == 0:
            continue  # class absent from both prediction and ground truth
        intersection = int((pred_mask & label_mask).sum())
        class_ious.append(intersection / union)
    if not class_ious:
        return float('nan')
    return sum(class_ious) / len(class_ious)

In [None]:
def evaluate(model, dataloader, num_classes=None):
    """Print and return pixel accuracy and mean IoU of ``model`` on ``dataloader``.

    Both metrics are computed per batch with ``pixel_accuracy`` and
    ``mean_iou`` and then averaged over batches.

    Args:
        model: Segmentation model returning a dict whose ``'out'`` entry has
            shape ``(N, C, H, W)``.
        dataloader: Iterable of ``(images, masks)`` batches.
        num_classes: Number of classes scored by mean IoU. Defaults to the
            number of output channels ``C``.

    Returns:
        ``(pixel_accuracy, mean_iou)`` averaged over batches, or
        ``(nan, nan)`` when the loader yields no batches.
    """
    model.eval()
    # Evaluate on whichever device holds the model's weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    accs, ious = [], []
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)['out']
            preds = outputs.argmax(dim=1)  # predicted class index per pixel
            n_classes = num_classes if num_classes is not None else outputs.shape[1]
            accs.append(pixel_accuracy(preds, masks))
            ious.append(mean_iou(preds, masks, n_classes))

    if not accs:
        # np.mean([]) would just print nan with a RuntimeWarning.
        print("No batches to evaluate.")
        return float('nan'), float('nan')

    mean_acc = float(np.mean(accs))
    mean_iou_value = float(np.mean(ious))
    print(f"Pixel Accuracy: {mean_acc:.4f}")
    print(f"Mean IoU: {mean_iou_value:.4f}")
    return mean_acc, mean_iou_value

In [None]:
# Report segmentation metrics on the held-out test split.
evaluate(model=model, dataloader=test_loader)

## 8. Visualization

In [None]:
import random
import torch.nn.functional as FNN

def visualize_random_samples(model, dataset, class_names, num_samples=5):
    """Plot image, ground truth and prediction for random samples of ``dataset``.

    For each sampled index, the image, its ground-truth mask and the model's
    predicted mask are shown side by side. A printout follows that lists every
    label present in the prediction, with the mean softmax confidence over the
    pixels assigned to that label.

    Args:
        model: Segmentation model returning a dict whose ``'out'`` entry has
            shape ``(1, C, H, W)`` for a single image.
        dataset: Indexable dataset of ``(image, mask)`` pairs. ``image`` is
            expected to be channels-first, since it is shown via
            ``permute(1, 2, 0)``.
        class_names: Mapping from label index to class name. Labels missing
            from the mapping are printed as ``"N/A"``.
        num_samples: Number of samples to show, capped at ``len(dataset)``.
    """
    model.eval()
    # random.sample raises ValueError when asked for more items than exist.
    sample_count = min(num_samples, len(dataset))
    indices = random.sample(range(len(dataset)), sample_count)

    # Run inference on whichever device holds the model's weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for idx in indices:
        image, mask = dataset[idx]
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))['out']  # shape: (1, C, H, W)
            # Move to the CPU once so all indexing below happens on one device.
            probs = FNN.softmax(output, dim=1).squeeze(0).cpu()    # shape: (C, H, W)
        pred = torch.argmax(probs, dim=0)                           # shape: (H, W)
        num_classes = probs.shape[0]

        # Average confidence for each predicted class at the pixels predicted as that class.
        unique_labels = [int(label) for label in torch.unique(pred)]
        confs = {
            label: probs[label][pred == label].mean().item()
            for label in unique_labels
        }

        # imshow rescales each image to its own min/max by default, so the same
        # class could get different colours in the two masks. Pin the colour
        # range to the full label set, and use nearest interpolation so
        # resampling never shows blended "in-between" labels.
        label_style = {'vmin': 0, 'vmax': num_classes - 1, 'interpolation': 'nearest'}

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.imshow(image.permute(1, 2, 0).cpu())
        plt.title("Image")
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(mask, **label_style)
        plt.title("Ground Truth")
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(pred.numpy(), **label_style)
        plt.title("Prediction")
        plt.axis('off')

        plt.show()

        print(f"Sample index: {idx}")
        print("Class labels in prediction and average confidence:")
        for label in unique_labels:
            name = class_names.get(label, "N/A")
            confidence = confs[label]
            print(f"Label {label}: {name} - Confidence: {confidence:.3f}")
        print("-" * 30)


In [None]:
# Show predictions for 5 random test images next to their ground-truth masks.
visualize_random_samples(
    model,
    test_dataset,
    test_dataset.class_names,
    num_samples=5,
)