In [None]:
! pip install -r '../requirements.txt'

In [188]:
import os
import torch
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch import nn
import numpy as np
from src.model import UNet
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF

In [189]:
# Pick the best available backend: Apple-silicon MPS first, then CUDA, and
# finally the CPU so the notebook still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [190]:
# collect the .jpg image paths and .png mask paths, each list sorted by path
image_dir = '../data/images'
mask_dir = '../data/masks'

image_files = [os.path.join(image_dir, name) for name in os.listdir(image_dir) if name.endswith('.jpg')]
image_files.sort()

mask_files = [os.path.join(mask_dir, name) for name in os.listdir(mask_dir) if name.endswith('.png')]
mask_files.sort()

# images and masks are paired by position later on, so the counts must agree
assert len(image_files) == len(mask_files), "Number of images and masks must match!"

In [191]:
# image preprocessing pipeline: resize to 256x256, then convert to a tensor
image_transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# resize targets to 68x68 (match model output)
def mask_transform(mask):
    """Turn a mask image into a 68x68 tensor of integer class labels.

    Args:
        mask: mask image accepted by ``TF.resize`` (the dataset passes a
            grayscale PIL image).

    Returns:
        ``torch.long`` CPU tensor of shape (68, 68) in which pixel value 255
        is relabelled to class 1; every other value is left unchanged.
    """
    # nearest-neighbour resizing copies existing pixel values instead of
    # blending them, so no new in-between label values are introduced
    mask = TF.resize(mask, size=(68, 68), interpolation=TF.InterpolationMode.NEAREST)
    # Keep the sample on the CPU: the training and evaluation code moves whole
    # batches to the device, and CPU samples stay compatible with DataLoader
    # worker processes.
    mask = torch.tensor(np.array(mask), dtype=torch.long)
    mask[mask == 255] = 1
    return mask
    

In [192]:
# hold out 20% of the image/mask pairs for validation (fixed seed => same split every run)
train_images, val_images, train_masks, val_masks = train_test_split(
    image_files,
    mask_files,
    test_size=0.2,
    random_state=42,
)

In [193]:
# torch Dataset that reads an image/mask pair from disk and preprocesses it on access
class SegmentationDataset(Dataset):
    """Index-aligned pairs of image files and mask files.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths; entry ``i`` belongs to
            ``image_paths[i]``.
        image_transform: optional callable applied to the loaded RGB image.
        mask_transform: optional callable applied to the loaded grayscale mask.
    """

    def __init__(self, image_paths, mask_paths, image_transform=None, mask_transform=None):
        self.image_paths, self.mask_paths = image_paths, mask_paths
        self.image_transform, self.mask_transform = image_transform, mask_transform

    def __len__(self):
        # the dataset size is taken from the image list
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        # image is read as 3-channel RGB, mask as single-channel grayscale ("L")
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")

        # transforms are optional; without them the PIL images are returned as-is
        image = self.image_transform(image) if self.image_transform else image
        mask = self.mask_transform(mask) if self.mask_transform else mask
        return image, mask

In [194]:
# wrap the train/val splits in datasets (same transforms for both) and batch them
train_dataset = SegmentationDataset(
    train_images,
    train_masks,
    image_transform=image_transform,
    mask_transform=mask_transform,
)
val_dataset = SegmentationDataset(
    val_images,
    val_masks,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

# batches of 8; only the training data is shuffled
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8)

print('size of training data', len(train_dataset))
print('size of validation data', len(val_dataset))

# sanity check: show the tensor shapes of the first training batch only
for images, masks in train_loader:
    print(f"Image batch shape: {images.shape}")
    print(f"Mask batch shape: {masks.shape}")
    break

size of training data 2075
size of validation data 519
Image batch shape: torch.Size([8, 3, 256, 256])
Mask batch shape: torch.Size([8, 68, 68])


In [195]:
# training setup: number of segmentation classes and a reminder of the split sizes

num_classes = 2
for split_label, split_dataset in (("train data size:", train_dataset), ("val data size:", val_dataset)):
    print(split_label, len(split_dataset))

train data size: 2075
val data size: 519


In [196]:
# optionally load a previously saved checkpoint to resume training (uncomment to use)
#checkpoint = torch.load('../checkpoints/checkpoint.pth')

In [197]:
# init model: construct the UNet, passing num_classes to its constructor
unet = UNet(num_classes)
# uncomment to restore weights from a loaded checkpoint before moving the model
#unet.load_state_dict(checkpoint['state_dict'])
# move the model to the selected device
unet = unet.to(device)

In [198]:
# init loss with class weights
def calculate_class_weights(data_loader, num_classes):
    """Compute inverse-frequency class weights from the masks of a data loader.

    Args:
        data_loader: iterable yielding ``(images, masks)`` batches, where
            ``masks`` is an integer tensor holding one class label per pixel
            (any shape, any device).
        num_classes: number of classes; only labels in ``range(num_classes)``
            are counted.

    Returns:
        1-D float CPU tensor of length ``num_classes`` holding
        ``total_pixels / (num_classes * pixels_of_class)``. A class that never
        occurs gets a weight of ``inf``.
    """
    # exact integer counts, kept on the CPU regardless of where the masks live
    pixel_counts = torch.zeros(num_classes, dtype=torch.long)

    for _, masks in data_loader:
        for c in range(num_classes):
            # masks hold class indices (not one-hot channels), so count the
            # pixels equal to c; .item() brings the count back to the host
            pixel_counts[c] += (masks == c).sum().item()

    # total number of counted pixels in the dataset
    total_pixels = pixel_counts.sum()

    # inverse frequency: rarer classes receive larger weights
    class_weights = total_pixels.float() / (num_classes * pixel_counts.float())
    return class_weights

# weighted cross-entropy loss; the weights are hard-coded here, and the
# commented line shows how to recompute them from the training loader instead
#class_weights = calculate_class_weights(train_loader, num_classes)
class_weights = torch.tensor((0.6341, 2.3639))
criterion = nn.CrossEntropyLoss(class_weights)

In [199]:
# Adam over all model parameters with a learning rate of 1e-4
optimizer = torch.optim.Adam(params=unet.parameters(), lr=1e-4)
# uncomment to restore optimizer state from a loaded checkpoint
#optimizer.load_state_dict(checkpoint['optimizer'])

In [200]:
# use to calculate mean intersection over union for evaluating performance
# (inputs are class-index labels, e.g. the argmax of the model output)
def get_miou(predictions, targets, n_classes=None):
    """Mean intersection-over-union between predicted and true class labels.

    Args:
        predictions: integer tensor of predicted class indices (any shape).
        targets: integer tensor of ground-truth class indices with the same
            number of elements as ``predictions``.
        n_classes: number of classes to score. Defaults to the notebook-level
            ``num_classes`` when omitted.

    Returns:
        0-dim float tensor (on the device of ``predictions``) with the mean of
        the per-class IoUs. Classes that appear in neither tensor are ignored;
        if no class appears at all the result is NaN.
    """
    if n_classes is None:
        n_classes = num_classes

    # work on whatever device the predictions live on
    work_device = predictions.device
    predictions = predictions.flatten()
    targets = targets.flatten().to(work_device)

    # confusion matrix: rows are predicted classes, cols are target classes
    conf_matrix = torch.zeros(n_classes, n_classes, device=work_device)
    for c in range(n_classes):
        predicted_c = predictions == c
        for t in range(n_classes):
            conf_matrix[c, t] = torch.sum(predicted_c & (targets == t))

    tp = conf_matrix.diagonal()
    fp = conf_matrix.sum(dim=1) - tp  # predicted as c, but target differs
    fn = conf_matrix.sum(dim=0) - tp  # target is c, but predicted otherwise
    union = tp + fp + fn

    # a class with an empty union has no defined IoU: mark it NaN so that
    # nanmean leaves it out of the average (also avoids dividing by 0)
    undefined = torch.full_like(union, float('nan'))
    class_ious = torch.where(union > 0, tp / union.clamp(min=1), undefined)

    return torch.nanmean(class_ious)

In [201]:
# use to evaluate with val data
def eval_model(model, val_loader, criterion):
    """Evaluate ``model`` on a validation loader without tracking gradients.

    Args:
        model: network to evaluate; called as ``model(images)``.
        val_loader: loader yielding ``(images, masks)`` batches.
        criterion: loss called as ``criterion(output, masks)`` on CPU tensors.

    Returns:
        Tuple ``(avg_val_loss, avg_val_miou)``: the loss and the mIoU, each
        averaged over the number of batches in ``val_loader``.
    """
    loss_sum = 0.0
    miou_sum = 0.0
    batches = len(val_loader)

    # evaluate in eval mode, then restore whatever mode the model was in so
    # the surrounding training loop is unaffected
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)

                out = model(images)
                # the loss is computed on the CPU, mirroring the training loop
                loss = criterion(out.to('cpu'), masks.to('cpu'))
                loss_sum += loss.item()

                # hard class predictions from the per-class scores
                predictions = torch.argmax(out, dim=1)
                miou_sum += get_miou(predictions, masks).item()
    finally:
        model.train(was_training)

    avg_val_loss = loss_sum / batches
    avg_val_miou = miou_sum / batches
    # return the per-batch average mIoU (not the running sum)
    return avg_val_loss, avg_val_miou

In [202]:
# TensorBoard writer; event files go under ../runs/unet
writer = SummaryWriter(log_dir='../runs/unet')

In [None]:
# training loop

cur_epoch = 0
#cur_epoch = checkpoint['epoch']

epochs = 20

# make sure the checkpoint directory exists before the first save
checkpoint_path = '../checkpoints/checkpoint.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(cur_epoch, cur_epoch + epochs):
    # (re-)enter training mode at the start of every epoch
    unet.train()

    mious = 0
    losses = 0
    n = len(train_loader)
    for batch_idx, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)
        out = unet(images)

        # have to put temporarily on cpu as torch has some sort of issue with mps here :(
        loss = criterion(out.to('cpu'), masks.to('cpu'))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # turn per-class scores into hard predictions (no gradient needed)
        predictions = torch.argmax(out.detach(), dim=1)

        loss_value = loss.item()
        print('loss', loss_value)
        miou = get_miou(predictions, masks).item()
        print('miou', miou)

        losses += loss_value
        mious += miou

        # log train loss and mIoU for this individual batch
        global_step = epoch * n + batch_idx
        writer.add_scalar('Loss/train', loss_value, global_step)
        writer.add_scalar('mIoU/train', miou, global_step)

    # log train loss and mIoU averaged over the epoch
    print("epoch mIoU:", mious / n)
    print("epoch loss:", losses / n)
    # tag fixed: the original had a stray leading space (' train loss/epoch')
    writer.add_scalar('train loss/epoch', losses / n, epoch)
    writer.add_scalar('train mIoU/epoch', mious / n, epoch)

    # log val loss and mIoU
    val_loss, val_miou = eval_model(unet, val_loader, criterion)
    writer.add_scalar('val loss/epoch', val_loss, epoch)
    writer.add_scalar('val mIoU/epoch', val_miou, epoch)

    # every third epoch, log up to 5 samples from the last training batch
    if epoch % 3 == 0:
        sample_images = images[:5]
        sample_preds = predictions[:5]
        sample_masks = masks[:5]
        print(sample_images.shape, sample_masks.shape, sample_preds.shape)

        writer.add_images("Images", sample_images, epoch)
        # label maps have no channel axis; add one for add_images
        writer.add_images("Predictions", sample_preds.unsqueeze(1), epoch)
        writer.add_images("Masks", sample_masks.unsqueeze(1), epoch)

    # save progress; 'epoch' stores the next epoch to run when resuming
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': unet.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    torch.save(checkpoint, checkpoint_path)

loss 0.6949171423912048
miou 0.40026551485061646
loss 0.6989352107048035
miou 0.3800819516181946
loss 0.6814723610877991
miou 0.47411423921585083
loss 0.6883294582366943
miou 0.4058724045753479
loss 0.6892361640930176
miou 0.4031760096549988


In [None]:
# close the TensorBoard SummaryWriter once training has finished
writer.close()