### Importing Libraries

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import torch.optim as optim

import os

from PIL import Image, ImageOps

import random

#import any other library you need below this line

### Loading data

Upload the data in zip format to Colab. Then run the cell below.

In [None]:
# -n: never overwrite files that are already extracted, so re-running this cell
#     (e.g. Restart & Run All) does not stall on unzip's interactive prompt.
# -q: keep the per-file listing out of the notebook output.
!unzip -n -q data.zip

### Defining the Dataset Class

In [4]:
import cv2
import torch
import torchvision.transforms.transforms
from torch.utils.data import Dataset, DataLoader

import numpy as np

import os
import random

from PIL import Image, ImageOps

# import any other libraries you need below this line
import torchvision.transforms.functional as v1
import torchvision.transforms.v2 as v2


class Cell_data(Dataset):
    """Paired grayscale scan / label-mask dataset.

    ``*.bmp`` files are listed from ``<data_dir>/scans`` and
    ``<data_dir>/labels``, sorted, and paired by position.  The first
    ``train_test_split`` fraction of the pairs forms the training split and
    the remainder the test split.

    Each item is a tuple ``(image, mask)`` of float32 tensors shaped
    ``(1, size, size)``.  The image is min-max normalised to ``[0, 1]``; the
    mask keeps the pixel values stored in the file.

    Args:
        data_dir (str): directory containing the ``scans`` and ``labels`` folders.
        size (int): side length every image and mask is resized to.
        train (bool): truthy selects the training split, falsy the test split.
        train_test_split (float): fraction of the data used for training.
        augment_data (bool): apply one random augmentation per training item.

    Raises:
        ValueError: if the number of scans and labels differ.
    """

    # Number of augmentation modes ``_augment`` chooses from.
    _N_AUGMENT_MODES = 6

    def __init__(self, data_dir, size, train=True, train_test_split=0.8, augment_data=True):
        super(Cell_data, self).__init__()
        self.data_dir = data_dir
        self.size = size
        self.train = train
        self.train_test_split = train_test_split
        self.augment_data = augment_data

        # collect scan and label paths; sorting keeps the two lists aligned
        image_path = os.path.join(data_dir, 'scans')
        mask_path = os.path.join(data_dir, 'labels')

        images = sorted([os.path.join(image_path, file) for file in os.listdir(image_path) if file.endswith('.bmp')])
        masks = sorted([os.path.join(mask_path, file) for file in os.listdir(mask_path) if file.endswith('.bmp')])
        if len(images) != len(masks):
            raise ValueError('found %d scans but %d labels in %s' % (len(images), len(masks), data_dir))

        # split train set & test set
        idx = int(train_test_split * len(images))
        if train:
            self.images = images[:idx]
            self.masks = masks[:idx]
        else:
            self.images = images[idx:]
            self.masks = masks[idx:]

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair stored at position ``idx``."""
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # Augment only training items, and only when augmentation is enabled.
        if self.train and self.augment_data:
            image, mask = self._augment(image, mask)

        return image, mask

    def __len__(self):
        return len(self.images)

    def _augment(self, image, mask):
        """Apply one randomly chosen augmentation to an image/mask pair.

        Every geometric transform samples its random parameters once and
        applies them to both tensors, so the mask stays aligned with the image.
        Masks are resampled with nearest-neighbour interpolation so no new
        label values are created.  Gamma correction changes intensities only
        and therefore leaves the mask untouched.
        """
        nearest = v1.InterpolationMode.NEAREST
        bilinear = v1.InterpolationMode.BILINEAR

        # torch's RNG is used so the draw follows torch's seeding.
        augment_mode = int(torch.randint(0, self._N_AUGMENT_MODES, (1,)).item())
        if augment_mode == 0:
            # flip vertically
            image, mask = v1.vflip(image), v1.vflip(mask)
        elif augment_mode == 1:
            # flip horizontally
            image, mask = v1.hflip(image), v1.hflip(mask)
        elif augment_mode == 2:
            # zoom: one random crop window, resized back to (size, size)
            top, left, height, width = torchvision.transforms.RandomResizedCrop.get_params(
                image, scale=[0.08, 1.0], ratio=[3.0 / 4.0, 4.0 / 3.0])
            out_size = [self.size, self.size]
            image = v1.resized_crop(image, top, left, height, width, out_size,
                                    interpolation=bilinear, antialias=True)
            mask = v1.resized_crop(mask, top, left, height, width, out_size,
                                   interpolation=nearest)
        elif augment_mode == 3:
            # rotate both tensors by the same random angle
            angle = float(torch.empty(1).uniform_(0.0, 360.0).item())
            image = v1.rotate(image, angle)
            mask = v1.rotate(mask, angle)
        elif augment_mode == 4:
            # non-rigid transformation: one shared displacement field, applied
            # directly to the tensors (no lossy round trip through PIL)
            displacement = torchvision.transforms.ElasticTransform.get_params(
                alpha=[50.0, 50.0], sigma=[5.0, 5.0], size=list(image.shape[-2:]))
            image = v1.elastic_transform(image, displacement, interpolation=bilinear)
            mask = v1.elastic_transform(mask, displacement, interpolation=nearest)
        else:
            # gamma correction on the [0, 1] image
            gamma = float(torch.empty(1).uniform_(0.5, 1.5).item())
            image = v1.adjust_gamma(image, gamma)
        return image, mask

    # Helper function to load images, given file path, return a tensor
    def load_image(self, path):
        """Read a scan as grayscale, resize it and min-max normalise to [0, 1].

        Returns a float32 tensor of shape ``(1, size, size)``.
        Raises ``FileNotFoundError`` if OpenCV cannot read ``path``.
        """
        raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if raw is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError('could not read image: %s' % path)
        image = cv2.resize(raw, (self.size, self.size))
        image = cv2.normalize(image, None, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        image_tensor = torch.from_numpy(image).unsqueeze(0)  # add the channel dimension
        return image_tensor

    # Helper function to load masks, given file path, return a tensor
    def load_mask(self, path):
        """Read a label mask as grayscale and resize it without blending labels.

        Returns a float32 tensor of shape ``(1, size, size)`` holding the pixel
        values stored in the file.
        Raises ``FileNotFoundError`` if OpenCV cannot read ``path``.
        """
        raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if raw is None:
            raise FileNotFoundError('could not read mask: %s' % path)
        # nearest-neighbour keeps every output pixel equal to an existing label value
        image = cv2.resize(raw, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
        image_tensor = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)  # add the channel dimension
        return image_tensor


### Define the Model
1. Define the Convolution blocks
2. Define the down path
3. Define the up path
4. Combine the down and up paths to get the final model

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# import any other libraries you need below this line
import torchvision.transforms.functional as v1

class twoConvBlock(nn.Module):
    """Part 1  The Convolutional blocks

    Layer order: 3x3 conv -> ReLU -> 3x3 conv -> batch norm -> ReLU.
    Both convolutions use ``padding=1``, so height and width are preserved.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        # Both convolutions are 3 x 3 with one pixel of zero padding.
        first_conv = nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
        layers = [
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.BatchNorm2d(output_channel),  # normalisation follows the second conv only
            nn.ReLU(inplace=True),
        ]
        self.doubleConvBlock = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the five-layer stack and return the result."""
        out = self.doubleConvBlock(x)
        return out


class downStep(nn.Module):
    """Part 2  The Contracting path

    One encoder stage: a ``twoConvBlock`` followed by 2 x 2 max pooling.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.convBlock = twoConvBlock(input_channel, output_channel)
        self.maxPool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the conv-block output before pooling; ``pooled`` is the
        same tensor after the 2 x 2 max pool.
        """
        features = self.convBlock(x)
        return features, self.maxPool(features)


class upStep(nn.Module):
    """Part 3  The Expansive path

    One decoder stage: a stride-2 transposed convolution doubles the spatial
    size, the matching encoder feature map is aligned to it and concatenated
    along the channel axis, and a ``twoConvBlock`` fuses the result.
    """

    def __init__(self, input_channel, output_channel):
        super(upStep, self).__init__()
        self.upConv = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=2,
                                         stride=2)  # transpose convolution
        # the concatenation doubles the channel count seen by the conv block
        self.convBlock = twoConvBlock(output_channel * 2, output_channel)

    def forward(self, x, skip_connection):
        """Upsample ``x``, align ``skip_connection`` to it, concatenate and convolve.

        Args:
            x: 4-D feature map from the previous (deeper) stage.
            skip_connection: 4-D encoder feature map from the mirrored stage.

        Returns:
            Output of the conv block; its height and width equal those of the
            upsampled ``x``.
        """
        x = self.upConv(x)

        # Size difference along height (dim 2) and width (dim 3).
        diffY = skip_connection.size(2) - x.size(2)
        diffX = skip_connection.size(3) - x.size(3)

        # Negative padding crops, positive padding zero-pads.  The parentheses
        # matter: ``-diffX // 2`` would parse as ``(-diffX) // 2`` and remove
        # one pixel too many whenever the difference is odd.
        top, left = diffY // 2, diffX // 2
        skip_connection = F.pad(skip_connection, [-left, -(diffX - left),
                                                  -top, -(diffY - top)])

        x = torch.cat([x, skip_connection], dim=1)
        return self.convBlock(x)


class UNet(nn.Module):
    """Encoder-decoder network built from ``downStep``, ``twoConvBlock`` and ``upStep``.

    Four encoder stages (64, 128, 256, 512 channels) feed a 1024-channel
    bottom block; four decoder stages mirror the encoder and consume its
    pre-pooling feature maps as skip connections.  A 1 x 1 convolution maps
    the final 64 channels to ``n_classes`` output channels.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Contracting path
        self.inc = downStep(n_channels, 64)
        self.down1 = downStep(64, 128)
        self.down2 = downStep(128, 256)
        self.down3 = downStep(256, 512)

        # Bottom block: convolutions only, no pooling
        self.bot = twoConvBlock(512, 1024)

        # Expansive path
        self.up1 = upStep(1024, 512)
        self.up2 = upStep(512, 256)
        self.up3 = upStep(256, 128)
        self.up4 = upStep(128, 64)

        # Per-pixel output projection
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the full network on ``x`` and return the ``outc`` output."""
        # Encoder: keep each stage's pre-pooling features for the decoder.
        skips = []
        for encoder in (self.inc, self.down1, self.down2, self.down3):
            features, x = encoder(x)
            skips.append(features)

        x = self.bot(x)

        # Decoder: consume the skips deepest-first.
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, skips.pop())

        return self.outc(x)


In [6]:
# Line 1: whether CUDA is usable from this kernel.
# Line 2: number of GPUs visible to torch.
print(torch.cuda.is_available(), torch.cuda.device_count(), sep='\n')

True
1


### Training

In [None]:
# Parameters

# learning rate
lr = 1e-2

# number of training epochs
epoch_n = 20

# input image-mask size
image_size = 572
# root directory of project
root_dir = os.getcwd()

# training batch size
batch_size = 32

# number of output channels (classes) the network predicts per pixel.
# This is a property of the labels, not of the batch size.
# NOTE(review): 2 assumes the masks hold class indices {0, 1} — confirm against
# the data; nn.CrossEntropyLoss requires every label to be in [0, n_classes).
n_classes = 2

# use checkpoint model for training
load = False

# use GPU for training
gpu = True

data_dir = os.path.join(root_dir, 'data/cells')


trainset = Cell_data(data_dir = data_dir, size = image_size)
trainloader = DataLoader(trainset, batch_size = batch_size, shuffle=True, num_workers=4)

# num_workers is a DataLoader option, not a Cell_data argument
testset = Cell_data(data_dir = data_dir, size = image_size, train = False)
testloader = DataLoader(testset, batch_size = batch_size, num_workers=4)

# fall back to the CPU when CUDA is requested but not available
device = torch.device('cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

model = UNet(n_channels=1, n_classes=n_classes).to(device)

if load:
  print('loading model')
  # map_location lets a checkpoint saved on GPU be restored on the chosen device
  model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

for e in range(epoch_n):
  epoch_loss = 0.0
  model.train()
  for i, (image, label) in enumerate(trainloader):
    image = image.to(device)
    # (B, 1, H, W) float mask -> (B, H, W) integer class map
    label = label.squeeze(1).long().to(device)

    optimizer.zero_grad()

    pred = model(image)

    # Centre-crop the label to the prediction's spatial size.  Slicing by
    # "start + output size" stays correct when the size difference is odd.
    out_h, out_w = pred.shape[2], pred.shape[3]
    crop_x = (label.shape[1] - out_h) // 2
    crop_y = (label.shape[2] - out_w) // 2
    label = label[:, crop_x: crop_x + out_h, crop_y: crop_y + out_w]

    loss = criterion(pred, label)

    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()

    # the criterion already averages over the batch, so report it as-is
    print('batch %d --- Loss: %.4f' % (i, loss.item()), flush=True)
  # average of the per-batch mean losses
  print('Epoch %d / %d --- Loss: %.4f' % (e + 1, epoch_n, epoch_loss / max(len(trainloader), 1)), flush=True)

  torch.save(model.state_dict(), 'checkpoint.pt')


In [None]:
  # Evaluate the current model on the test split: pixel accuracy and mean loss.
  model.eval()

  total = 0
  correct = 0
  total_loss = 0

  with torch.no_grad():
    for i, data in enumerate(testloader):
      image, label = data

      image = image.to(device)
      # (B, 1, H, W) float mask -> (B, H, W) integer class map, as in training;
      # without the squeeze, shape[1] is the channel axis and the crop is wrong
      label = label.squeeze(1).long().to(device)

      pred = model(image)

      # centre-crop the label to the prediction's spatial size
      out_h, out_w = pred.shape[2], pred.shape[3]
      crop_x = (label.shape[1] - out_h) // 2
      crop_y = (label.shape[2] - out_w) // 2
      label = label[:, crop_x: crop_x + out_h, crop_y: crop_y + out_w]

      loss = criterion(pred, label)
      total_loss += loss.item()

      _, pred_labels = torch.max(pred, dim = 1)

      total += label.numel()
      correct += (pred_labels == label).sum().item()

    # the criterion returns a per-batch mean, so average over batches
    print('Accuracy: %.4f ---- Loss: %.4f' % (correct / max(total, 1), total_loss / max(len(testloader), 1)))

### Testing and Visualization

In [None]:
# Predict a mask for every test sample and keep it next to its cropped label.
model.eval()


output_masks = []
output_labels = []

with torch.no_grad():
  for i in range(len(testset)):
    image, labels = testset[i]

    # the dataset yields (1, H, W); only the batch dimension is missing
    input_image = image.unsqueeze(0).to(device)
    pred = model(input_image)

    # arg-max over the class axis -> (H_out, W_out) array of class indices
    output_mask = torch.max(pred, dim = 1)[1].cpu().squeeze(0).numpy()

    # drop the channel axis, then centre-crop the label to the prediction size
    labels = labels.squeeze(0)
    out_h, out_w = output_mask.shape
    crop_x = (labels.shape[0] - out_h) // 2
    crop_y = (labels.shape[1] - out_w) // 2
    labels = labels[crop_x: crop_x + out_h, crop_y: crop_y + out_w].numpy()

    output_masks.append(output_mask)
    output_labels.append(labels)


In [None]:
# One row per test sample: ground-truth label on the left, prediction on the right.
n_samples = len(testset)

# squeeze=False keeps `axes` two-dimensional even when there is a single row
fig, axes = plt.subplots(n_samples, 2, figsize = (20, 20), squeeze=False)

for i in range(n_samples):
  axes[i, 0].imshow(output_labels[i])
  axes[i, 0].axis('off')
  axes[i, 1].imshow(output_masks[i])
  axes[i, 1].axis('off')

if n_samples:
  axes[0, 0].set_title('Ground truth')
  axes[0, 1].set_title('Prediction')