# Model training - Semantic Segmentation

This script performs the semantic segmentation model training of the Coco dataset.

It loads the pre-processed data from the directory, build the model network, train it.

- It uses masks that were built assigning each pixel in the image to its belonging classes (one mask for each image).
- It uses Sparse Cross Entropy (because of the point above, and the fact that each pixel only belongs to one class).

In [1]:
import os
import cv2
import torch
import json
import time
import logging
import numpy as np
from tqdm import tqdm  # Import tqdm for progress bar

from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from models.unet_parts import *

## Create Dataset and Dataloader

In [11]:
class COCOSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            mask_dir (str): Path to the directory containing instance masks.
            transform (callable, optional): Optional transform to be applied to images.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith(".png")])[:100] # only consider first 100
        self.mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith(".png")])[:100]

    def __len__(self):
        """
        Return the number of files in the image dataset (each image correspond to one mask)
        """
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the images and masks
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_name)
        mask_name = image_name  # Image and mask have the same filename
        mask_path = os.path.join(self.mask_dir, mask_name)

        if not mask_name:
            return None  # No mask found, handle accordingly
        
        # Set final image sizes (650x700), which includes all sizes
        image_size_h = 650
        image_size_w = 700

        # Load image
        image_original = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # reading as is

        # Pad image to match desired size
        original_h, original_w, _ = image_original.shape
        pad_h = max(0, (image_size_h - original_h) // 2)
        pad_w = max(0, (image_size_w - original_w) // 2)

        image_padded = np.pad(image_original, ((pad_h, image_size_h - original_h - pad_h), (pad_w, image_size_w - original_w - pad_w), (0, 0)), mode='constant', constant_values=0)
        
        # Load mask (grayscale) and expand values
        mask_original = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        
        # Pad mask to match desired size instead of interpolation
        original_h, original_w = mask_original.shape

        pad_h = max(0, (image_size_h - original_h) // 2)
        pad_w = max(0, (image_size_w - original_w) // 2)

        mask_padded = np.pad(mask_original, ((pad_h, image_size_h - original_h - pad_h), (pad_w, image_size_w - original_w - pad_w)), mode='constant', constant_values=0)
        print('check')
        # Ensure mask values remain categorical (0 to 255 after expansion)
        mask_tensor = torch.tensor(mask_padded, dtype=torch.float32)
        print('check')
        # Convert image to tensor
        image_tensor = torch.tensor(image_padded, dtype=torch.float32).permute(2, 0, 1) # in tensors, channels must be first dimension
        print('check')
        return image_tensor, mask_tensor

In [12]:
image_val_dir = "/home/maver02/Development/Datasets/COCO/preprocess_coco_2_v1/val/images"
image_train_dir = "/home/maver02/Development/Datasets/COCO/preprocess_coco_2_v1/train/images"

mask_val_dir = "/home/maver02/Development/Datasets/COCO/preprocess_coco_2_v1/val/masks"
mask_train_dir = "/home/maver02/Development/Datasets/COCO/preprocess_coco_2_v1/train/masks"

instances_val_dir = "/home/maver02/Development/Datasets/COCO/annotations/instances_val2017.json"
instances_train_dir = "/home/maver02/Development/Datasets/COCO/annotations/instances_val2017.json"

In [13]:
test_data = COCOSegmentationDataset(image_val_dir, mask_val_dir)
train_data = test_data # use test data for now as it is smaller

batch_size = 2  # Reduce to avoid OOM
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

## Model creation

In [16]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cpu"
logging.info(f'Using device {device}')

# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel

# Create model instance and move to device
model = UNet(n_channels=3, n_classes=91).to(device)

logging.info(f'Network:\n'
                f'\t{model.n_channels} input channels\n'
                f'\t{model.n_classes} output channels (classes)\n'
                f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

print(model)

INFO: Using device cpu
INFO: Network:
	3 input channels
	91 output channels (classes)
	Transposed conv upscaling


UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [17]:
# Define loss, optimizer, epochs
loss_fn = nn.CrossEntropyLoss() # As we have multiclass represented as pixel integers in masks
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
epochs = 3

## Training

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """
    In a single training loop, the model makes predictions on the training dataset (fed to it in batches), 
    and backpropagates the prediction error to adjust the model’s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    """
    We also check the model’s performance against the test dataset to ensure it is learning.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [19]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
check
check
check
check
check
check


RuntimeError: expected scalar type Long but found Float