In [20]:
! pip install -r '../requirements.txt'


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [21]:
import os
import torch
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch import nn
import numpy as np
from src.model import UNet
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF
import albumentations as A
from torch.optim.lr_scheduler import ReduceLROnPlateau
from matplotlib import pyplot as plt

In [22]:
device=torch.device('mps')
image_dim = 256

In [23]:
# load image and mask files
image_dir = '../data/images'
mask_dir = '../data/masks'

image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.jpg')])
mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith('.png')])
assert len(image_files) == len(mask_files), "Number of images and masks must match!"

In [24]:


# define image and mask transforms
image_transform = transforms.Compose([
    transforms.Resize((image_dim, image_dim)),
    transforms.ToTensor()
])

# resize targets to 68x68 (match model output)
def mask_transform(mask):
    mask = TF.resize(mask, size=(image_dim, image_dim), interpolation=TF.InterpolationMode.NEAREST)
    mask = torch.tensor(np.array(mask), dtype=torch.long, device=device)
    mask[mask==255]=1
    return mask
    
# define augmentation pipeline for a more flexible model
augmentation = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], additional_targets={'mask': 'mask'})

In [25]:
# train test split
train_images, val_images, train_masks, val_masks = train_test_split(image_files, mask_files, test_size=0.2, random_state=42)

In [26]:
# create torch dataset to easily load and preprocess images
class SegmentationDataset(Dataset):
    
    def __init__(self, image_paths, mask_paths, image_transform, mask_transform, augmentation=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.augmentation = augmentation

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # load image and mask
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")  # Grayscale
        
        if self.augmentation:
            # image and mask must be numpy arrays for augmentation
            image = np.array(image)
            mask = np.array(mask)
            augmented = self.augmentation(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
            # convert image and mask back to PIL for additional transformations
            image = Image.fromarray(image)
            mask = Image.fromarray(mask)
            
        image = self.image_transform(image)
        mask = self.mask_transform(mask)
        

        return image, mask

In [27]:
# define train and val datasets and dataloaders
train_dataset = SegmentationDataset(train_images, train_masks, image_transform=image_transform, mask_transform=mask_transform, augmentation=augmentation)
val_dataset = SegmentationDataset(val_images, val_masks, image_transform=image_transform, mask_transform=mask_transform, augmentation=None) # don't touch validation images

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print('size of training data', len(train_dataset))
print('size of validation data', len(val_dataset))

# check data
for images, masks in train_loader:
    print(f"Image batch shape: {images.shape}")
    print(f"Mask batch shape: {masks.shape}")
    break

size of training data 2075
size of validation data 519
Image batch shape: torch.Size([8, 3, 256, 256])
Mask batch shape: torch.Size([8, 256, 256])


In [28]:
# finally start training the model
num_classes = 2
lr = 1e-4
epochs = 50

In [29]:
# get checkpoint to save model
#checkpoint = torch.load('../checkpoints/checkpoint.pth')

In [30]:
# init loss with class weights
def calculate_class_weights(data_loader, num_classes):
    pixel_counts = torch.zeros(num_classes)  # To store counts for each class

    # Loop through the dataset
    for _, masks in data_loader:
        # Count pixels per class
        for c in range(num_classes):
            pixel_counts[c] += masks[:,c,:,:].sum()
        
        print(pixel_counts)

    # Total pixels in the dataset
    total_pixels = pixel_counts.sum()

    # Inverse frequency
    class_weights = total_pixels / (num_classes * pixel_counts)
    return class_weights

#class_weights = calculate_class_weights(train_loader, num_classes)
class_weights = torch.tensor([0.6341, 2.3639]) # use already calculated weights to save time

In [31]:
# init model
unet = UNet(num_classes)
#unet.load_state_dict(checkpoint['state_dict'])
unet = unet.to(device)

In [32]:
# create a custom dice loss class
class DiceLoss(nn.Module):
    
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.
        
    def forward(self, preds, targets):
        preds = preds.softmax(dim=1)[:, 1, :, :]# [batch_size, height, width]
        preds_flat = preds.view(preds.size(0),-1)
        target_flat = targets.view(targets.size(0),-1)
        intersection = (preds_flat * target_flat).sum(dim=1)
        dice = (2. * intersection + self.smooth) / (preds_flat.sum(dim=1) + target_flat.sum(dim=1) + self.smooth)
        return 1 - dice.mean()

# custom focal loss class
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Apply sigmoid activation for binary segmentation
        probs = torch.sigmoid(logits)
        targets = targets.float()

        # Calculate the focal loss
        pt = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = self.alpha * (1 - pt).pow(self.gamma)
        loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none') * focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# create hybrid dice and bce loss to optimize for both
class HybridDiceBCELoss(nn.Module):
    def __init__(self, weights=None):
        super(HybridDiceBCELoss, self).__init__()
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.CrossEntropyLoss(weight=weights)
        
    def forward(self, preds, targets):
        dice = self.dice_loss(preds, targets)
        bce_loss = self.bce_loss(preds, targets)
        return dice + bce_loss
    
criterion = HybridDiceBCELoss(class_weights)

In [33]:
# init optimizer
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
#optimizer.load_state_dict(checkpoint['optimizer'])

In [34]:
# add scheduler to prevent overfitting
scheduler = ReduceLROnPlateau(optimizer,
                              mode='min',
                              factor=0.1,
                              patience=3,
                              threshold=1e-3,
                              verbose=True)

In [35]:
# use to calculate mean intersection over union for evaluating performance (with one hot encodings)
def get_miou(predictions, targets):
    # confusion matrix between predicted and ground truth class labels
    conf_matrix = torch.zeros(num_classes, num_classes, device=device)

    predictions = predictions.flatten()
    targets = targets.flatten()

    #print('pred',predictions)
    #print('targets',targets)

    # rows are predictions, cols are targets
    for c in range(num_classes):
        for t in range(num_classes):
            conf_matrix[c,t] += torch.sum((predictions == c) & (targets == t))

    #print(conf_matrix)

    # get iou for each, then calculate mean later
    class_ious = []
    for c in range(num_classes):
        tp = conf_matrix[c,c]
        fp = conf_matrix[c,:].sum() - conf_matrix[c,c]
        fn = conf_matrix[:,c].sum() - conf_matrix[c,c]

        U = tp + fp + fn
        #print(U)

        if U == 0: # prevent division by 0
            class_ious.append(float('nan'))
        else:
            class_ious.append(tp.float()/U)

    #print(class_ious)

    return torch.nanmean(torch.tensor(class_ious, device=device))

In [36]:
# use to evaluate with val data
def eval_model(model, val_loader, criterion):
    model.eval()
    loss_sum = 0
    miou_sum = 0
    batches = len(val_loader)
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)

            out = model(images)
            loss = criterion(out.to('cpu'), masks.to('cpu'))
            loss_sum += loss.item()

            predictions = torch.argmax(out, dim=1)
            miou = get_miou(predictions, masks.to(device)).item()
            miou_sum += miou

    avg_val_loss = loss_sum / batches
    avg_val_miou = miou_sum / batches
    return avg_val_loss, avg_val_miou

In [37]:
# init summary writer
writer = SummaryWriter('../runs/unet')

In [38]:
# training loop

cur_epoch = 0
#cur_epoch = checkpoint['epoch']

early_stop_tolerance = 3
best_val_loss = float('inf')
epochs_wo_improve = 0

unet.train()
for epoch in range(cur_epoch, cur_epoch + epochs):
    mious = 0
    losses = 0
    n = len(train_loader)
    for batch_idx, (images, masks)in enumerate(train_loader):
        #print(images.shape, masks.shape)
        #print(torch.unique(masks))
        images = images.to(device)
        masks = masks.to(device)
        out = unet(images)
        #out = out.to(device)
        #print(out.shape)
        #print(masks.shape)
        #print(f"Output device: {out.dtype}")
        #print(f"Mask device: {masks.dtype}")
        # have to put temporarily on cpu as torch has some sort of issue with mps here :(
        loss = criterion(out.to('cpu'), masks.to('cpu'))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # turn probabilities into hard predictions
        predictions = torch.argmax(out, dim=1)

        print('loss',loss.item())
        #print('pred',torch.unique(predictions))
        #print('mask',torch.unique(masks))
        #print('out',torch.unique(out))
        miou = get_miou(predictions, masks).item()
        print('miou',miou)

        losses += loss.item()
        mious += miou

        # add train and mIoU for individual
        writer.add_scalar('Loss/train', loss.item(), epoch * n + batch_idx)
        writer.add_scalar('mIoU/train', miou, epoch * n + batch_idx)

    # add train loss and mIoU for batch
    print("epoch mIoU:",mious/n)
    print("epoch loss:",losses/n)
    writer.add_scalar(' train loss/epoch', losses/n, epoch)
    writer.add_scalar('train mIoU/epoch', mious/n, epoch)

    # add val loss and mIoU
    val_loss, val_miou = eval_model(unet, val_loader, criterion)
    unet.train()
    writer.add_scalar('val loss/epoch', val_loss, epoch)
    writer.add_scalar('val mIoU/epoch', val_miou, epoch)
    #add images at end of each epoch to visualize predictions
    sample_images = images[:5]
    sample_preds = predictions[:5]
    sample_masks = masks[:5]
    #print(sample_images.shape, sample_masks.shape, sample_preds.shape)

    writer.add_images("Images", sample_images, epoch)
    writer.add_images("Predictions", sample_preds.unsqueeze(1), epoch)
    writer.add_images("Masks", sample_masks.unsqueeze(1), epoch)

    # save progress
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': unet.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    torch.save(checkpoint, '../checkpoints/checkpoint.pth')
    scheduler.step(val_loss)

    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        epochs_wo_improve = 0
    else:
        epochs_wo_improve += 1

    if epochs_wo_improve >= early_stop_tolerance:
        print(f"Early stopping at epoch {epoch+1}")
        break

loss 1.5511934757232666
miou 0.17394506931304932
loss 1.5200855731964111
miou 0.19974520802497864
loss 1.4140112400054932
miou 0.26971787214279175
loss 1.2769780158996582
miou 0.37916165590286255
loss 1.23909592628479
miou 0.4684666395187378
loss 1.3251407146453857
miou 0.4030883312225342
loss 0.9536775946617126
miou 0.730985164642334
loss 1.4475440979003906
miou 0.38022080063819885
loss 1.2327176332473755
miou 0.5557988882064819
loss 1.1454436779022217
miou 0.5791677832603455
loss 1.246260166168213
miou 0.4347606897354126
loss 1.0435811281204224
miou 0.4976620078086853
loss 1.070786476135254
miou 0.5088837146759033
loss 1.3743486404418945
miou 0.44786781072616577
loss 1.2000210285186768
miou 0.4489957094192505
loss 1.2041504383087158
miou 0.5120629072189331
loss 1.164251446723938
miou 0.5160503387451172
loss 1.1034302711486816
miou 0.4778629541397095
loss 1.181567668914795
miou 0.4185677170753479
loss 1.0245792865753174
miou 0.4678594172000885
loss 0.8446817398071289
miou 0.6296760439

In [39]:
# close writer
writer.close()

In [40]:
def visualize_predictions(inputs, predictions, ground_truths, n_samples=5):
    """
    Visualize input images, predicted masks, and ground truth masks.

    Parameters:
    - inputs: Batch of input images (Tensor or NumPy array).
    - predictions: Batch of predicted masks (Tensor or NumPy array).
    - ground_truths: Batch of ground truth masks (Tensor or NumPy array).
    - n_samples: Number of samples to visualize.
    """
    n_samples = min(n_samples, len(inputs))
    plt.figure(figsize=(15, 5 * n_samples))

    for i in range(n_samples):
        plt.subplot(n_samples, 3, i * 3 + 1)
        plt.imshow(inputs[i][0], cmap='gray')
        plt.title('Input Image')
        plt.axis('off')

        plt.subplot(n_samples, 3, i * 3 + 2)
        plt.imshow(predictions[i][0], cmap='gray')
        plt.title('Predicted Mask')
        plt.axis('off')

        plt.subplot(n_samples, 3, i * 3 + 3)
        plt.imshow(ground_truths[i][0], cmap='gray')
        plt.title('Ground Truth Mask')
        plt.axis('off')

    plt.tight_layout()
    plt.show()


In [41]:
unet.eval()
for i in range(3):
    image, mask = val_dataset.__getitem__(i)
    
    pred = unet(image)

    plt.imshow(image[0], cmap='gray')
    plt.title('Input Image')
    plt.axis('off')

    plt.imshow(pred[0], cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')

    plt.imshow(mask[0], cmap='gray')
    plt.title('Ground Truth Mask')
    plt.axis('off')

RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device=mps:0')  must be on the same device