In [1]:
import os
import numpy as np
from skimage.io import imread
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from data_augmentation import SegmentationDataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF

In [2]:
test_path = "../Dataset/Test/color"
test_label_path = "../Dataset/Test/label"
train_path = "../Dataset/TrainVal/color"
train_label_path = "../Dataset/TrainVal/label"

In [3]:
# ------------------------------------------
# 1. Dataset Class for Image Segmentation
# ------------------------------------------

transform = A.Compose([
    ######### TODO: Maybe take out this fist padding ##########
    # A.PadIfNeeded(min_height=300, min_width=300, border_mode=0, value=(0, 0, 0)),  # Pad small images to 300x300
    ######################################################
    A.LongestMaxSize(max_size=300, interpolation=0),  # Resize longest side to 300 (if necessary)
    A.PadIfNeeded(min_height=300, min_width=300, border_mode=0, value=(0, 0, 0)),  # Pad remaining images to 300x300
    A.RandomCrop(256, 256),  # Crop to fixed size
    A.HorizontalFlip(p=0.5),  # Flip images & masks with 50% probability
    A.Rotate(limit=20, p=0.5),  # Random rotation (-20° to 20°)
    # A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),  # Elastic distortion
    # A.GridDistortion(p=0.3),  # Slight grid warping
    # A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),  # Color jitter
    # A.GaussianBlur(blur_limit=(3, 7), p=0.2),  # Random blur
    # A.GaussNoise(var_limit=(10, 50), p=0.2),  # Random noise
    # A.CoarseDropout(max_holes=2, max_height=50, max_width=50, p=0.3),  # Cutout occlusion
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Standard normalization
    ToTensorV2()  # Convert to PyTorch tensor
])

train_dataset = SegmentationDataset(image_dir=train_path, mask_dir=train_label_path, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

  A.PadIfNeeded(min_height=300, min_width=300, border_mode=0, value=(0, 0, 0)),  # Pad remaining images to 300x300


In [4]:
# ------------------------------------------
# 2. Define the UNet Model
# ------------------------------------------
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
        
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)

        self.bottleneck = conv_block(512, 1024)

        self.decoder4 = conv_block(1024, 512)  # Should match concat(e4, upconv4)
        self.decoder3 = conv_block(512, 256)   # Should match concat(e3, upconv3)
        self.decoder2 = conv_block(256, 128)   # Should match concat(e2, upconv2)
        self.decoder1 = conv_block(128, 64)    # Should match concat(e1, upconv1)


        self.final_layer = nn.Conv2d(64, out_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # 1024 → 512
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)   # 512 → 256
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)   # 256 → 128
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)    # 128 → 64

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        e4 = self.encoder4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))
    
        d4 = self.upconv4(b)    # [1024 → 512]
        d4 = torch.cat((e4, d4), dim=1)  # [512 + 512 = 1024]
        d4 = self.decoder4(d4)  # Output: [1024 → 512]

        d3 = self.upconv3(d4)   # [512 → 256]
        d3 = torch.cat((e3, d3), dim=1)  # [256 + 256 = 512]
        d3 = self.decoder3(d3)  # Output: [512 → 256]

        d2 = self.upconv2(d3)   # [256 → 128]
        d2 = torch.cat((e2, d2), dim=1)  # [128 + 128 = 256]
        d2 = self.decoder2(d2)  # Output: [256 → 128]

        d1 = self.upconv1(d2)   # [128 → 64]
        d1 = torch.cat((e1, d1), dim=1)  # [64 + 64 = 128]
        d1 = self.decoder1(d1)  # Output: [128 → 64]

        return self.final_layer(d1)  # Output: [64 → num_classes]


In [5]:
from tqdm import tqdm

# ------------------------------------------
# 3. Train the UNet Model
# ------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = UNet(in_channels=3, out_channels=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=3)

def train(model, dataloader, optimizer, criterion, epochs=50):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks.long())
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress_bar.set_postfix({'Loss': loss.item()})
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(dataloader):.4f}")

Using device: cpu


In [6]:
train(model, train_dataloader, optimizer, criterion)

# ------------------------------------------
# 4. Save and Evaluate Model
# ------------------------------------------
torch.save(model.state_dict(), "unet_baseline.pth")

Epoch 1/50:   0%|          | 0/460 [09:05<?, ?it/s, Loss=1.01] 

KeyboardInterrupt: 