In [None]:
# Colab-only setup: mount Google Drive inside the runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/My \Drive/Acad/ADS/Project/

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """Three-level U-Net producing per-pixel sigmoid probabilities.

    The encoder applies four double-convolution stages separated by 2x2 max
    pooling; the decoder upsamples three times and concatenates each result
    with the encoder feature map of the same level (skip connection).

    Args:
        in_channels: Number of channels in the input images (default 3).
        out_channels: Number of output maps (default 1). A sigmoid is applied
            to every output channel independently.
        base_channels: Width of the first encoder stage (default 64); each
            deeper stage doubles it.

    With the default arguments the layer names, shapes and registration order
    are identical to the original fixed-size network, so previously saved
    weights remain loadable.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()

        # Channel widths of the four resolution levels (64/128/256/512 by default).
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8

        # Contracting Path
        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        self.pool = nn.MaxPool2d(2)

        # Expansive Path
        # Each dec* block receives the upsampled map concatenated with the
        # skip connection, hence twice the channels of its level.
        self.up3 = self.upconv_block(c4, c3)
        self.dec3 = self.conv_block(2 * c3, c3)
        self.up2 = self.upconv_block(c3, c2)
        self.dec2 = self.conv_block(2 * c2, c2)
        self.up1 = self.upconv_block(c2, c1)
        self.dec1 = self.conv_block(2 * c1, c1)

        self.out_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def upconv_block(self, in_channels, out_channels):
        """Return a 2x bilinear upsampling followed by a 3x3 convolution."""
        block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        return block

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate a decoder map with its skip connection along channels.

        Pooling floors odd sizes, so after upsampling the decoder map can be
        smaller than the encoder map it must be joined with. In that case it
        is bilinearly resized to the skip map's height/width first; when the
        sizes already agree (inputs divisible by 8) nothing is resampled.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = nn.functional.interpolate(
                upsampled,
                size=skip.shape[-2:],
                mode='bilinear',
                align_corners=True,
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Batched image tensor of shape (N, in_channels, H, W). H and W
                must each be at least 8 so that three poolings leave a
                non-empty map; they no longer need to be multiples of 8.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
        """
        # Contracting Path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Expansive Path
        d3 = self.dec3(self._merge(self.up3(e4), e3))
        d2 = self.dec2(self._merge(self.up2(d3), e2))
        d1 = self.dec1(self._merge(self.up1(d2), e1))

        return torch.sigmoid(self.out_conv(d1))

In [None]:
'''
Notebook to train UNet for segmentation
'''

# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torchvision
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

def dice_loss(pred, target, smooth = 1e-6):
    """Return one minus the Dice coefficient of ``pred`` and ``target``.

    Sums are taken over every element of the inputs (no per-sample
    averaging). ``smooth`` is added to numerator and denominator so the
    ratio stays defined when both inputs sum to zero.
    """
    overlap = (pred * target).sum()
    total_mass = pred.sum() + target.sum()

    numerator = 2. * overlap + smooth
    denominator = total_mass + smooth
    return 1 - numerator / denominator

class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    ``alpha`` is the weight of the BCE term; the Dice term gets ``1 - alpha``.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, target):
        """Return ``alpha * BCE + (1 - alpha) * Dice`` for the given tensors.

        ``pred`` is passed straight to ``F.binary_cross_entropy``, which
        expects probabilities rather than logits.
        """
        # 'mean' is the default reduction of F.binary_cross_entropy.
        bce_term = self.alpha * F.binary_cross_entropy(pred, target)
        dice_term = (1 - self.alpha) * dice_loss(pred, target)
        return bce_term + dice_term

# Load training data.
# NOTE: despite the HR/LR suffixes, x_trainHR holds the segmentation labels
# (targets) and x_trainLR holds the input images fed to the network.
# Labels (Outputs)
x_trainHR = np.load('./Data/unet/train_labels.npy').astype(np.float32)
# Images (Conditions)
x_trainLR = np.load('./Data/unet/train_images.npy').astype(np.float32)
# from_numpy wraps the float32 arrays without the extra copy made by the
# legacy torch.Tensor(ndarray) constructor.
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
# Print data dimensions
# NOTE(review): the loss assumes labels have the same shape as the model
# output and values in [0, 1] -- confirm against the .npy files.
print(x_trainHR.shape)
print(x_trainLR.shape)

# Create dataset and dataloader for efficient data loading and batching.
# Each item is a (label, image) pair, in that order.
dataset = TensorDataset(x_trainHR, x_trainLR)
# Reshuffle every epoch so batches are not always made of the same samples.
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

num_batches = len(dataloader)
l = num_batches  # legacy alias, kept in case other cells still reference it
# Fall back to the CPU instead of crashing when no GPU runtime is attached.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
loss_function = CombinedLoss(alpha=0.5)
epochs = 500

# Make sure the checkpoint directory exists before the first save.
os.makedirs("Weights", exist_ok=True)
checkpoint_path = os.path.join("Weights", "UNet_ckpt_2.pt")

for epoch in range(epochs):
    print(f"Starting epoch {epoch + 1}:")

    epoch_loss = 0
    pbar = tqdm(dataloader)

    for labels, images in pbar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_function(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call copies the value to the host.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # The value shown is the combined BCE + Dice loss, not Dice alone.
        pbar.set_postfix(Loss=batch_loss)

    avg_loss = epoch_loss / num_batches
    print(f'Average Loss for Epoch {epoch + 1}: {avg_loss:.5f}\n')

    # Save the model after every epoch, overwriting the previous checkpoint.
    # The whole module object is pickled (not just a state_dict), so loading
    # it later requires the UNet class definition to be importable.
    torch.save(model, checkpoint_path)