### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import torch.optim as optim

import os

from PIL import Image, ImageOps

import random

#import any other library you need below this line

### Loading data

Upload the data in zip format to Colab. Then run the cell below.

In [2]:
#!unzip data.zip

### Defining the Dataset Class

In [5]:
class Cell_data(Dataset):
    def __init__(self, data_dir, size, train = 'True', train_test_split = 0.8, augment_data = True):
        ##########################inputs##################################
        #data_dir(string) - directory of the data#########################
        #size(int) - size of the images you want to use###################
        #train(boolean) - train data or test data#########################
        #train_test_split(float) - the portion of the data for training###
        #augment_data(boolean) - use data augmentation or not#############
        super(Cell_data, self).__init__()
        # todo
        #initialize the data class
        self.scan_dir = data_dir+"/scans/"
        self.label_dir = data_dir+"/labels/"
        self.image_files = os.listdir(self.scan_dir)
        self.size = size
        self.train = train
        self.train_test_split = train_test_split
        self.augment_data = augment_data

        for image in self.image_files:
            scan_image = Image.open(self.scan_dir+image)
            label_image = Image.open(self.label_dir+image)
            #data augmentation part
            if augment_data:
                augment_mode = np.random.randint(0, 4)
                if augment_mode == 0:
                    #todo 
                    #flip image vertically
                    augmented_scan_image = TF.vflip(scan_image)
                    augmented_label_image = TF.vflip(label_image)
                elif augment_mode == 1:
                    #todo
                    #flip image horizontally
                    augmented_scan_image = TF.hflip(scan_image)
                    augmented_label_image = TF.hflip(label_image)
                elif augment_mode == 2:
                    #todo
                    #zoom image
                    zoom_size = random.randint(self.size*0.75, self.size)
                    augmented_scan_image = TF.center_crop(scan_image, zoom_size)
                    augmented_label_image = TF.center_crop(label_image, zoom_size)
                else:
                    #todo
                    #rotate image
                    angle = random.randint(-30, 30)
                    augmented_scan_image = TF.rotate(scan_image, angle)
                    augmented_label_image = TF.rotate(label_image, angle)

    def __getitem__(self, idx):
        # todo
        #load image and mask from index idx of your data
        filename = self.image_files[idx]

        #todo
        #return image and mask in tensors

    def __len__(self):
        return len(self.image_files)

### Define the Model
1. Define the Convolution blocks
2. Define the down path
3. Define the up path
4. combine the down and up path to get the final model

In [3]:
class twoConvBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(twoConvBlock, self).__init__()
        #todo
        #initialize the block
        self.conv_layer1 = nn.Conv2d(input_channels, output_channels, kernel_size=5, stride=1)
        self.conv_layer2 = nn.Conv2d(output_channels, output_channels, kernel_size=5, stride=1)
        self.batch_norm_layer = nn.BatchNorm2d(output_channels)

    def forward(self, image):
        #todo
        #implement the forward path
        image = self.conv_layer1(image)
        image = F.relu(image)
        image = self.conv_layer2(image)
        image = self.batch_norm_layer(image)
        image = F.relu(image)
        return image

class downStep(nn.Module):
    def __init__(self):
        super(downStep, self).__init__()
        #todo
        #initialize the down path
        self.max_pool_layer = nn.MaxPool2d(kernel_size=2,stride=2)

    def forward(self, image):
        #todo
        #implement the forward path
        image = self.max_pool_layer(image)
        return image

class upStep(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(upStep, self).__init__()
        #todo
        #initialize the up path
        self.up_sampling_layer = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)

    def forward(self, image):
        #todo
        #implement the forward path
        image = self.up_sampling_layer(image)
        return image

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        #todo
        #initialize the complete model
        self.conv1 = twoConvBlock(1, 64)
        self.conv2 = twoConvBlock(63, 128)
        self.conv3 = twoConvBlock(128, 256)
        self.conv4 = twoConvBlock(256, 512)
        self.conv5 = twoConvBlock(512, 1024)
        self.conv6 = twoConvBlock(1024, 512)
        self.conv7 = twoConvBlock(512, 256)
        self.conv8 = twoConvBlock(256, 128)
        self.conv9 = twoConvBlock(128, 64)
        self.conv10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1)
        self.down_step = downStep()
        self.up_step = upStep()

    def forward(self, image):
        #todo
        #implement the forward path
        image = self.down_step(self.conv1(image))
        image = self.down_step(self.conv2(image))
        image = self.down_step(self.conv3(image))
        image = self.down_step(self.conv4(image))
        image = self.up_step(self.conv5(image))
        image = self.up_step(self.conv6(iamge))
        image = self.up_step(self.conv7(image))
        image = self.up_step(self.conv8(image))
        image = self.up_step(self.conv9(image))
        image = self.conv10(image)
        return image


### Training

In [6]:
#Paramteres

#learning rate
lr = 1e-2

#number of training epochs
epoch_n = 20

#input image-mask size
image_size = 572
#root directory of project
root_dir = os.getcwd()

#training batch size
batch_size = 4

#use checkpoint model for training
load = False

#use GPU for training
gpu = True

data_dir = os.path.join(root_dir, 'data/cells')


trainset = Cell_data(data_dir = data_dir, size = image_size)
trainloader = DataLoader(trainset, batch_size = 4, shuffle=True)

testset = Cell_data(data_dir = data_dir, size = image_size, train = False)
testloader = DataLoader(testset, batch_size = 4)

AttributeError: 'Cell_data' object has no attribute 'images'

In [None]:
device = torch.device('cuda:0' if gpu else 'cpu')

model = UNet().to('cuda:0').to(device)

if load:
    print('loading model')
    model.load_state_dict(torch.load('checkpoint.pt'))

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=lr, momentum=0.99, weight_decay=0.0005)

model.train()
for e in range(epoch_n):
    epoch_loss = 0
    model.train()
    for i, data in enumerate(trainloader):
        image, label = data

        image = image.unsqueeze(1).to(device)
        label = label.long().to(device)

        pred = model(image)

        crop_x = (label.shape[1] - pred.shape[2]) // 2
        crop_y = (label.shape[2] - pred.shape[3]) // 2

        label = label[:, crop_x: label.shape[1] - crop_x, crop_y: label.shape[2] - crop_y]
    
        loss = criterion(pred, label)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

        print('batch %d --- Loss: %.4f' % (i, loss.item() / batch_size))
    print('Epoch %d / %d --- Loss: %.4f' % (e + 1, epoch_n, epoch_loss / trainset.__len__()))

    torch.save(model.state_dict(), 'checkpoint.pt')

    model.eval()

    total = 0
    correct = 0
    total_loss = 0

    with torch.no_grad():
        for i, data in enumerate(testloader):
            image, label = data

            image = image.unsqueeze(1).to(device)
            label = label.long().to(device)

            pred = model(image)
            crop_x = (label.shape[1] - pred.shape[2]) // 2
            crop_y = (label.shape[2] - pred.shape[3]) // 2

            label = label[:, crop_x: label.shape[1] - crop_x, crop_y: label.shape[2] - crop_y]

            loss = criterion(pred, label)
            total_loss += loss.item()

            _, pred_labels = torch.max(pred, dim = 1)

            total += label.shape[0] * label.shape[1] * label.shape[2]
            correct += (pred_labels == label).sum().item()

        print('Accuracy: %.4f ---- Loss: %.4f' % (correct / total, total_loss / testset.__len__()))


### Testing and Visualization

In [None]:
model.eval()


output_masks = []
output_labels = []

with torch.no_grad():
    for i in range(testset.__len__()):
        image, labels = testset.__getitem__(i)
    
        input_image = image.unsqueeze(0).unsqueeze(0).to(device)
        pred = model(input_image)

        output_mask = torch.max(pred, dim = 1)[1].cpu().squeeze(0).numpy()

        crop_x = (labels.shape[0] - output_mask.shape[0]) // 2
        crop_y = (labels.shape[1] - output_mask.shape[1]) // 2
        labels = labels[crop_x: labels.shape[0] - crop_x, crop_y: labels.shape[1] - crop_y].numpy()
    
        output_masks.append(output_mask)
        output_labels.append(labels)


In [None]:
fig, axes = plt.subplots(testset.__len__(), 2, figsize = (20, 20))

for i in range(testset.__len__()):
    axes[i, 0].imshow(output_labels[i])
    axes[i, 0].axis('off')
    axes[i, 1].imshow(output_masks[i])
    axes[i, 1].axis('off')