# Basic UNet Example

Based on:
https://github.com/overshiki/unet-pytorch/blob/master/utils.py

## Things to do

 * Clean up the plots to provide proper axes labels etc (make it look pretty!)
 * Data augmentation (done carefully: the same geometric transformation must be applied to the mask as to the image)
 * Early stopping criterion
 * Validation in the epoch loop
 * Train / validation / test split and test at end of epochs
 * Try different hyperparameters: what is the "best" learning rate? What is the "best" optimizer?
 * How could you get more accurate masks? (In particular the finer detail)
 * How much slower is it on a CPU rather than a GPU?
 * How many training images are required for "accurate" results? (And what is an "accurate" result?)

In [None]:
import zipfile
from io import BytesIO

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets,transforms, models


# Set Hyper Parameters

In [None]:
# Number of image / mask pairs per mini-batch (passed to the DataLoader).
batch_size: int = 10
# Number of complete passes over the training data in the training loop.
epochs: int = 30

# CPU or GPU

In [None]:
#
# Select the compute device: use the GPU when CUDA is available and fall
# back to the CPU otherwise, so the notebook runs unchanged on any machine.
#
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f'Device is {device}')

if device == 'cuda':
    # torch.cuda.device(0) is a context-manager object, so formatting it only
    # shows an object repr; current_device() returns the active GPU index.
    active_gpu = torch.cuda.current_device()
    print(f'CUDA device {active_gpu}')
    print(f'Number of devices: {torch.cuda.device_count()}')
    print(f'Device name: {torch.cuda.get_device_name(active_gpu)}')

## Load in the Data

## Transforms

In [None]:
# Side length, in pixels, used for both the resize and the centre crop.
_side = 224

# Steps are applied in list order to every PIL image handed to the pipeline.
_steps = [
    transforms.Resize(_side),
    transforms.CenterCrop(_side),
    # converts 0-255 to 0-1 and rowxcolxchan to chanxrowxcol
    transforms.ToTensor(),
]

train_transform = transforms.Compose(_steps)

## Dataset

In [None]:
class CarDataset(Dataset):
    """
    Image / mask pairs read directly from two zip archives.

    One archive holds the images and the other the masks.  The mask entry
    for an image is found by rewriting the image's entry name:
    'train/' becomes 'train_masks/' and '.jpg' becomes '_mask.gif'.
    Both archives are opened in the constructor and kept open for the
    lifetime of the dataset.
    """

    def __init__(self, zip_filename_cars, zip_filename_masks, transform):
        """
        Open both archives and index the image entries.

        zip_filename_cars  -- path of the zip archive containing the images
        zip_filename_masks -- path of the zip archive containing the masks
        transform          -- callable applied to the image and, separately,
                              to the mask; None leaves both untouched
        """
        super().__init__()

        # Store variables we are interested in...

        self._zip_filename_cars = zip_filename_cars
        self._zip_filename_masks = zip_filename_masks

        self._zip_filename_cars_zf = zipfile.ZipFile(self._zip_filename_cars, "r")
        self._zip_filename_masks_zf = zipfile.ZipFile(self._zip_filename_masks, "r")

        # Index the archive once.  namelist() builds a new list on every
        # call, so calling it per item is wasted work.  Directory entries
        # (names ending in '/') are not images and are skipped, which also
        # removes the assumption that the archive starts with exactly one
        # directory entry.
        self._image_names = [
            name
            for name in self._zip_filename_cars_zf.namelist()
            if not name.endswith('/')
        ]

        self._transforms = transform

    def __getitem__(self, index):
        """
        Get a single image / label pair.

        Returns (image, mask) where mask is the first channel of the
        transformed mask compared against zero.  Raises IndexError when
        index is outside the dataset.
        """

        #
        # Read in the image
        #
        name = self._image_names[index]
        image = Image.open(BytesIO(self._zip_filename_cars_zf.read(name)))

        #
        # Read in the mask
        #
        mask_name = name.replace('train/', 'train_masks/').replace('.jpg', '_mask.gif')
        mask = Image.open(BytesIO(self._zip_filename_masks_zf.read(mask_name)))

        #
        # Do transformations on it (typically data augmentation)
        #
        # NOTE(review): image and mask go through the transform in two
        # separate calls, so a transform with random behaviour would not be
        # guaranteed to treat both identically -- confirm before adding
        # random augmentation.
        #
        if self._transforms is not None:
            image = self._transforms(image)
            mask = self._transforms(mask)

        #
        # Return the image mask pair
        #
        # NOTE(review): mask[0] assumes the transform produced something
        # indexable along a leading channel axis (e.g. a tensor).
        #
        return image, mask[0] > 0

    def __len__(self):
        """
        Return length of the dataset (number of image entries)
        """
        return len(self._image_names)


## Instantiate the Dataset and DataLoader

In [None]:
# Dataset of image / mask pairs read from the two zip archives.
train_dataset = CarDataset('train.zip', 'train_masks.zip', transform=train_transform)

# Shuffle so that every epoch visits the samples in a new random order;
# without it each epoch would see identical batches in archive order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Show an Example (for Sanity)

In [None]:
# Sanity check: draw one image next to its mask.
image, mask = train_dataset[2]

plt.figure()
panels = (
    # chan x row x col -> row x col x chan, the layout imshow expects
    lambda: np.transpose(np.array(image), (1, 2, 0)),
    lambda: np.squeeze(np.array(mask)),
)
for position, make_panel in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(make_panel())
plt.show()

## Create Network

In [None]:
class contracting(nn.Module):
    """
    Encoder half of the network.

    Five stages of two 3x3 convolutions (stride 1, padding 1), each
    followed by a ReLU, with channel widths 3 -> 64 -> 128 -> 256 -> 512
    -> 1024.  A 2x2 max-pool with stride 2 sits between consecutive stages.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build one stage: conv - ReLU - conv - ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Stages are created in order layer1..layer5 so that parameter
        # registration (and therefore state_dict layout) is unchanged.
        self.layer1 = self._double_conv(3, 64)
        self.layer2 = self._double_conv(64, 128)
        self.layer3 = self._double_conv(128, 256)
        self.layer4 = self._double_conv(256, 512)
        self.layer5 = self._double_conv(512, 1024)

        # Shared pooling module used between all stages.
        self.down_sample = nn.MaxPool2d(2, stride=2)

    def forward(self, X):
        """
        Run the encoder.

        Returns the tuple (X5, X4, X3, X2, X1): the output of every stage,
        deepest stage first.
        """
        features = [self.layer1(X)]
        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            # Each deeper stage consumes the pooled output of the previous one.
            features.append(stage(self.down_sample(features[-1])))
        return tuple(reversed(features))


class expansive(nn.Module):
    """
    Decoder half of the network.

    Each step upsamples with a 2x2 transposed convolution (stride 2) that
    halves the channel count, concatenates the matching encoder output
    along the channel axis, and applies two 3x3 conv + ReLU layers.  A
    final 3x3 convolution maps the 64 remaining channels to 2 outputs.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build one refinement stage: conv - ReLU - conv - ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Modules are created in the same order as before (layer1..layer5,
        # then the upsamplers) so parameter registration is unchanged.

        # Output head: 64 feature channels -> 2 output channels.
        self.layer1 = nn.Conv2d(64, 2, 3, stride=1, padding=1)

        # Refinement stages; the input width is doubled by the concatenation.
        self.layer2 = self._double_conv(128, 64)
        self.layer3 = self._double_conv(256, 128)
        self.layer4 = self._double_conv(512, 256)
        self.layer5 = self._double_conv(1024, 512)

        # Learned 2x upsamplers, named <from level><to level>.
        self.up_sample_54 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_sample_43 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_sample_32 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_sample_21 = nn.ConvTranspose2d(128, 64, 2, stride=2)

    def forward(self, X5, X4, X3, X2, X1):
        """
        Run the decoder on the encoder outputs (deepest first) and return
        the 2-channel output of the final convolution.
        """
        steps = (
            (self.up_sample_54, X4, self.layer5),
            (self.up_sample_43, X3, self.layer4),
            (self.up_sample_32, X2, self.layer3),
            (self.up_sample_21, X1, self.layer2),
        )

        current = X5
        for upsample, skip, refine in steps:
            # Upsampled features come first in the channel concatenation.
            merged = torch.cat([upsample(current), skip], dim=1)
            current = refine(merged)

        return self.layer1(current)


class unet(nn.Module):
    """
    Full network: the contracting encoder followed by the expansive decoder.
    """

    def __init__(self):
        super().__init__()

        self.down = contracting()  # Encoder
        self.up = expansive()      # Decoder

    def forward(self, X):
        """
        Encode X, then decode from the encoder outputs.

        The encoder returns its outputs deepest first (X5, X4, X3, X2, X1),
        which is the positional order the decoder takes them in.
        """
        encoder_outputs = self.down(X)
        return self.up(*encoder_outputs)
    

## Train Network

## Create network optimizer and loss function

In [None]:
# Build the network and move its parameters onto the selected device
# (Module.to returns the module itself, so the call can be chained).
model = unet().to(device)

# Adam over all network parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Cross-entropy between the network output and the integer class labels.
loss = nn.CrossEntropyLoss()

## Run the training

In [None]:
# Mean training loss of every completed epoch (plotted after training).
epoch_loss = []

for epoch in range(epochs):
    print('='*30)
    print('Epoch {} / {}'.format(epoch, epochs))

    # Make sure the network is in training mode for this epoch.
    model.train()

    # Per-epoch accumulators
    overlap = 0        # pixels that are 1 in both prediction and label
    union = 0          # pixels that are 1 in prediction or label
    running_loss = 0.0
    count = 0          # number of batches seen

    # Loop over the batches
    for index, (X, Y) in enumerate(train_dataloader):
        print(f'\tBatch {index}')

        X = X.to(device)
        Y = Y.to(device)

        # Call the model (image to mask)
        R = model(X)

        # Compute the loss (CrossEntropyLoss needs integer class labels)
        L = loss(R, Y.long())

        # Standard optimisation step
        optimizer.zero_grad()
        L.backward()
        optimizer.step()

        # Compute Stats -- no gradients are needed for these
        with torch.no_grad():
            # Predicted class per pixel: index of the largest output channel
            pred = R.argmax(dim=1)
            pred_mask = pred == 1
            label_mask = Y == 1

            # Convert to Python ints once so the accumulators stay plain numbers
            pred_sum = pred_mask.sum().item()
            label_sum = label_mask.sum().item()
            overlap_sum = (pred_mask & label_mask).sum().item()
        print(f'\t label_sum {label_sum}  pred_sum {pred_sum}  overlap_sum {overlap_sum}')

        union_sum = pred_sum + label_sum - overlap_sum

        # IoU for accuracy
        overlap += overlap_sum
        union += union_sum
        running_loss += L.item()
        count += 1

    # Guard the divisions: an empty loader gives count == 0, and an epoch
    # with no foreground pixels in predictions or labels gives union == 0.
    _loss = running_loss / count if count else float('nan')
    _accuracy = overlap / union if union else float('nan')
    string = "epoch: {}, accuracy: {}, loss: {}".format(epoch, _accuracy, _loss)
    print(string)

    epoch_loss.append(_loss)

# Plot the Loss Curve

In [None]:
# Mean training loss per epoch, as collected by the training loop.
plt.figure()
plt.plot(epoch_loss)
plt.title('Epoch Loss')
plt.xlabel('Epoch')
plt.ylabel('Mean training loss')
plt.show()

# Plot an Example Result

In [None]:
# Show one label / prediction pair from the last training batch.
# The last batch can hold fewer samples than batch_size, so clamp the
# index rather than assuming that sample 3 exists.
sample_index = min(3, len(Y) - 1)

plt.figure(1)
plt.subplot(1,2,1)
plt.imshow(Y[sample_index].cpu())
plt.clim((0,1))
plt.title('Label')
plt.subplot(1,2,2)
plt.imshow(pred[sample_index].cpu())
plt.clim((0,1))
plt.title('Prediction')
plt.show()
