<a href="https://colab.research.google.com/github/circle0103/image_segmentation/blob/main/Copy_of_EML_Assignment_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [None]:
class retinaDataset(torch.utils.data.Dataset):
    """Retina training images paired with their manual vessel segmentations.

    Sample ``index`` is read from ``{index + 1}_training.tif`` (input image)
    and ``{index + 1}_manual1.gif`` (manual segmentation), relative to the
    current working directory.

    Args:
        length: number of samples exposed by the dataset (default 15).
    """

    def __init__(self, length = 15):
        'Initialization'
        # Not used by this class; kept so existing attribute access keeps working.
        self.X = []
        self.y = []
        self.length = length

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def transform(self, image):
        """Convert a PIL image to a tensor with ``ToTensor`` (no normalization)."""
        return transforms.ToTensor()(image)

    def transform_norm(self, image):
        """Convert a 3-channel PIL image to a tensor normalized with mean=std=0.5.

        For 8-bit input, ``ToTensor`` yields values in [0, 1] and this
        ``Normalize`` maps them to [-1, 1].
        """
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])(image)

    def __getitem__(self, index):
        """Return ``(X, y)``: the normalized image tensor and the mask tensor."""
        # PIL opens files lazily and keeps the handle open; the ``with`` blocks
        # close both files once their pixels have been converted to tensors.
        with Image.open(f'{index + 1}_training.tif') as image:
            X = self.transform_norm(image)
        with Image.open(f'{index + 1}_manual1.gif') as label:
            y = self.transform(label)
        return X, y
        

In [None]:
# Run on the first GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

full_set = retinaDataset()
# All samples in file order, one per batch (used for previewing images).
fullloader = torch.utils.data.DataLoader(
    full_set, batch_size=1, shuffle=False, num_workers=0
)

# Random 10 / 5 train / validation split of the 15 training images.
trainset, valset = torch.utils.data.random_split(full_set, [10, 5])
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=3, shuffle=True, num_workers=0
)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=1, shuffle=True, num_workers=0
)


cuda:0


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow_color(img,i):
    """Undo the mean=std=0.5 normalization of a (C, H, W) tensor and plot it.

    Prints the array shape first. The title labels the plot as the ``i``-th
    group of five images (``image 5*i+1 to 5*i+5``).
    NOTE(review): callers iterating a batch_size=1 loader get a title range
    that does not match the single image shown.
    """
    arr = (img / 2 + 0.5).numpy()  # back from [-1, 1] to [0, 1]
    print(arr.shape)
    first = i * 5 + 1
    plt.imshow(arr.transpose(1, 2, 0))  # channels last for matplotlib
    plt.title(f'image {first} to {first + 4}')
    plt.show()

def imshow(img):
    """Plot a CPU image tensor laid out as (C, H, W), without unnormalizing."""
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

def imshow_gpu(img):
    """Plot a (C, H, W) image tensor that may live on any device.

    The tensor is detached before conversion, so network outputs that still
    require grad can be shown; ``Tensor.numpy()`` raises on such tensors.
    """
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # channels last for matplotlib
    plt.show()



In [None]:
def imshow_rgb(i, channel='B'):
    """Display one colour channel of training image number ``i`` (1-based).

    Args:
        i: file number; the image is read from ``{i}_training.tif``.
        channel: band to show after converting to RGB: 'R', 'G' or 'B'
            (default 'B', the behaviour of the original helper).
    """
    # Close the file once the channel has been extracted.
    with Image.open(f'{i}_training.tif') as im:
        band = im.convert('RGB').getchannel(channel)
    # `display` is not defined in this file; presumably IPython's notebook display.
    display(band)


In [None]:
# Show the blue channel of every image in the full set. Only the file numbers
# are needed, so iterate indices instead of loading (and normalizing) every
# sample through the DataLoader.
for i in range(len(full_set)):
    imshow_rgb(i + 1)

In [None]:
# Preview every sample of the full set in colour (imshow_color un-normalizes).
for i, (data, label) in enumerate(fullloader):
    grid = torchvision.utils.make_grid(data)
    imshow_color(grid, i)
  

In [None]:
# Grab one shuffled training batch for inspection. Recent PyTorch DataLoader
# iterators have no `.next()` method; use the built-in next().
dataiter = iter(trainloader)
images, labels = next(dataiter)

We use the U-Net architecture for our model, since U-Net performs well empirically and is a popular convolutional neural network architecture for semantic segmentation.

U-Net consists of a contracting path and an expansive path. 

* The contracting path follows a typical convolutional neural network architecture. We first apply 3x3 convolutions, each followed by a ReLU, to the input image with 3 input channels (RGB) and get 64 output channels. The original U-Net uses unpadded ("valid") convolutions; our implementation uses `padding='same'` and adds batch normalization. Next we down-sample with a 2x2 max pooling operation with stride 2. Max pooling halves the spatial size; the convolutions at the next level then double the number of feature channels. This process is repeated at every level of the contracting path, and at each iteration the number of input channels is the number of output channels of the previous convolutional layer. Our implementation uses `channel_list = [64, 128]`, i.e. two levels followed by a 256-channel bottleneck. At the end of the last level, we perform an up-sampling instead of a down-sampling, which starts the expansive path of the architecture.

* At the beginning of the expansive path, we perform a 2x2 up-convolution that halves the number of feature channels. We then concatenate the result with the corresponding output from the contracting path (see figure above). Next we again apply 3x3 convolutions, each followed by ReLU, to the concatenated output. This process is repeated once for each level of the contracting path, and at each iteration the number of input channels is the number of output channels of the previous convolutional layer. At the end of the last iteration, instead of another up-sampling, we apply a 1x1 convolution to get the output segmentation map. In the original U-Net this map has 2 channels, one per label (blood vessel and non-blood vessel in our case). Our implementation instead uses a single output channel followed by a sigmoid, which gives the probability that each pixel is a blood vessel.

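In the original U-Net, the output from the contracting path must be cropped before concatenation, because valid convolutions lose the border pixels. Since our implementation uses 'same' padding, the feature maps only differ in size when max pooling halves an odd dimension. In that case the up-sampled map is resized to match the contracting-path output before concatenation.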

In [None]:
class doubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages that preserve spatial size.

    Both convolutions use stride 1 and ``padding='same'``, so an input of
    shape (N, in_channels, H, W) produces (N, out_channels, H, W).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(doubleConv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding='same')
        self.relu = nn.ReLU()
        # Each convolution gets its OWN BatchNorm. Sharing one layer between
        # them mixes the running statistics and affine parameters of two
        # different feature distributions.
        self.norm = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding='same')
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.relu(self.norm(self.conv1(x)))
        return self.relu(self.norm2(self.conv2(x)))



In [None]:
class UNET(nn.Module):
    """U-Net for segmentation with a sigmoid (per-pixel probability) output.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output map (default 1).
        channel_list: feature channels at each contraction level; the
            bottleneck uses ``2 * channel_list[-1]``. Any non-empty sequence
            is accepted. The default is a tuple, which avoids the shared
            mutable-default pitfall.

    Raises:
        ValueError: if ``channel_list`` is empty.
    """

    def __init__(self, in_channels = 3, out_channels = 1, channel_list = (64, 128)):
        super(UNET, self).__init__()
        channel_list = list(channel_list)
        if not channel_list:
            raise ValueError("channel_list must contain at least one level")
        self.contract = nn.ModuleList()
        self.expand = nn.ModuleList()
        self.expand_UpSamp = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Contraction path: each level maps the previous level's channels to `num`.
        for num in channel_list:
            self.contract.append(doubleConv(in_channels, num))
            in_channels = num

        # Bottleneck at the end of the contraction path doubles the channels.
        self.end_contract = doubleConv(channel_list[-1], channel_list[-1] * 2)

        for num in reversed(channel_list):
            # 2x2 transposed conv with stride 2: doubles H and W, halves channels.
            self.expand_UpSamp.append(nn.ConvTranspose2d(num * 2, num, 2, 2))
            # Its output is concatenated with the skip connection, giving 2*num channels.
            self.expand.append(doubleConv(num * 2, num))

        # 1x1 convolution producing the final output map.
        self.final = nn.Conv2d(channel_list[0], out_channels, kernel_size = 1)

    def forward(self, x):
        # Outputs of the contraction levels, reused as skip connections.
        skips = []
        for contract in self.contract:
            x = contract(x)
            skips.append(x)
            x = self.pool(x)

        x = self.end_contract(x)

        # The deepest skip connection pairs with the first up-sampling stage.
        for up, conv, skip in zip(self.expand_UpSamp, self.expand, reversed(skips)):
            x = up(x)
            # Max pooling floors odd sizes, so the up-sampled map can be one
            # pixel smaller than the skip connection; resize it (bilinear) to match.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            # Concatenate along the channel dimension, then convolve.
            x = conv(torch.cat((skip, x), dim=1))

        return torch.sigmoid(self.final(x))

      
  


    


We define the loss function to be the binary cross-entropy loss (`nn.BCELoss`, applied to the sigmoid output of the network) and the optimizer to be stochastic gradient descent with momentum.

In [None]:
unet = UNET(in_channels=3, out_channels=1)  # RGB input, one-channel vessel map


In [None]:
# Move the network to the selected device (in place), then create plain SGD
# with momentum over its parameters.
unet = unet.to(device)
optimizer = optim.SGD(params=unet.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Average training loss of each epoch, plotted in the next cell.
loss_list = []
criterion = nn.BCELoss()
unet.train()
for epoch in range(20):  # 20 passes over the training set
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred = unet(inputs)
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Record and report the mean batch loss of the epoch, not just the last batch.
    epoch_loss = running_loss / len(trainloader)
    loss_list.append(epoch_loss)
    print('Avg. loss this epoch: ', epoch_loss)


print('Finished Training')

In [None]:
# Epoch numbers (1-based) for the x axis.
ls = list(range(1, 1 + len(loss_list)))

plt.plot(ls, loss_list)
plt.title("Training Binary Cross Entropy Loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.xticks(ls)
plt.show()

Train on a smaller dataset

In [None]:
# Random 5-image subset of the training split; the other 5 images are discarded.
smallset, _ = torch.utils.data.random_split(trainset, lengths=[5, 5])
smallloader = torch.utils.data.DataLoader(
    dataset=smallset, batch_size=3, shuffle=True, num_workers=0
)

To check whether our model is overfitting or underfitting, we compute the error on the validation set.

In [None]:
# validation error
# Per-image validation losses, kept apart from the training curve in loss_list.
val_loss_list = []
criterion = nn.BCELoss()
# Evaluate in eval mode (BatchNorm uses its running statistics and does not
# update them) and without building autograd graphs; restore the mode afterwards.
was_training = unet.training
unet.eval()
acc_loss = 0
with torch.no_grad():
    for inputs, labels in valloader:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = unet(inputs)
        loss = criterion(pred, labels)
        val_loss_list.append(loss.item())
        acc_loss += loss.item()
unet.train(was_training)

print("Validation error:", acc_loss / len(valloader))

Validation error: 0.2518784314393997


Our model seems to underfit, since the training error is still high after 20 epochs of training. The dataset is too small in this case, so we use data augmentation to address this problem.


In [None]:
class retinaDataset2(torch.utils.data.Dataset):
    """Augmented retina dataset of (image, vessel mask) pairs.

    Each sample gets ONE randomly chosen augmentation:
    - a rotation by a uniform angle in [-90, 90] degrees;
    - a horizontal flip with probability 0.5;
    - autocontrast with probability 0.5;
    - a vertical flip with probability 0.5.

    This mirrors the transforms of ``transform``. Geometric augmentations are
    applied identically to the image and its mask so the two stay aligned.
    Autocontrast changes intensities only, so it is applied to the image but
    not to the mask.

    Args:
        length: number of samples. Sample ``index`` is read from
            ``{index + 1}_training.tif`` and ``{index + 1}_manual1.gif``.
        normalize: if True (default), normalize the image tensor with
            mean=std=0.5 per channel. This is the same input scaling used
            for the non-augmented training/validation data and the test set.
    """

    def __init__(self, length=15, normalize=True):
        'Initialization'
        # Not used by this class; kept so existing attribute access keeps working.
        self.X = []
        self.y = []
        self.length = length
        self.normalize = normalize

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def transform(self, image):
        """Randomly augment a single PIL image and convert it to a tensor.

        NOTE: the random draw is independent on every call, so do not call
        this separately on an image and its mask; use ``transform_pair``.
        """
        X = transforms.RandomChoice([transforms.RandomRotation([-90,90]), transforms.RandomHorizontalFlip(),transforms.RandomAutocontrast(),transforms.RandomVerticalFlip()])(image)
        X = transforms.ToTensor()(X)

        return X

    def transform_pair(self, image, label):
        """Apply the same random augmentation to ``image`` and ``label``.

        Returns:
            ``(X, y)``: the (optionally normalized) image tensor and the mask tensor.
        """
        TF = transforms.functional
        choice = int(torch.randint(4, (1,)))
        if choice == 0:
            angle = float(torch.empty(1).uniform_(-90.0, 90.0))
            # Default nearest-neighbour interpolation keeps the mask values intact.
            image, label = TF.rotate(image, angle), TF.rotate(label, angle)
        elif choice == 1:
            if torch.rand(1) < 0.5:
                image, label = TF.hflip(image), TF.hflip(label)
        elif choice == 2:
            if torch.rand(1) < 0.5:
                image = TF.autocontrast(image)  # intensity-only: mask unchanged
        else:
            if torch.rand(1) < 0.5:
                image, label = TF.vflip(image), TF.vflip(label)

        X = transforms.ToTensor()(image)
        if self.normalize:
            X = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(X)
        y = transforms.ToTensor()(label)
        return X, y

    def __getitem__(self, index):
        'Generates one sample of data'
        # Close both files once the augmented tensors have been produced.
        with Image.open(f'{index + 1}_training.tif') as image, \
                Image.open(f'{index + 1}_manual1.gif') as label:
            return self.transform_pair(image, label)

In [None]:
# Average training loss of each augmented-training epoch.
loss_list_aug = []
full_aug_set = retinaDataset2()
# Reuse the original split: a fresh random_split would put images the model
# was already trained on into the validation set.
aug_train = torch.utils.data.Subset(full_aug_set, trainset.indices)
# Validate on the held-out, non-augmented images so the metric is deterministic.
aug_val = valset
# Built once: a shuffling DataLoader reshuffles on every pass, and the dataset
# draws a fresh augmentation on every access anyway.
trainloader_aug = torch.utils.data.DataLoader(aug_train, batch_size=3, shuffle=True, num_workers=0)
valloader_aug = torch.utils.data.DataLoader(aug_val, batch_size=1, shuffle=True, num_workers=0)
criterion = nn.BCELoss()
unet.train()
for epoch in range(20):  # loop over the dataset multiple times
    acc_loss = 0
    for inputs, labels in trainloader_aug:
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred = unet(inputs)
        loss = criterion(pred, labels)
        acc_loss += loss.item()
        loss.backward()
        optimizer.step()

    epoch_loss = acc_loss / len(trainloader_aug)
    loss_list_aug.append(epoch_loss)
    # print statistics
    print('Avg. loss this epoch: ', epoch_loss)


print('Finished Training')




In [None]:
# Show the network's predictions on the training images (inference only):
# eval mode so BatchNorm uses running statistics, no autograd graph, and the
# previous train/eval mode is restored afterwards.
was_training = unet.training
unet.eval()
with torch.no_grad():
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        outputs = unet(images)
        imshow_gpu(torchvision.utils.make_grid(outputs))
unet.train(was_training)


In [None]:
# validation error (after augmented training)
# Per-image validation losses, kept apart from the training curve in loss_list.
val_loss_list = []
criterion = nn.BCELoss()
# Evaluate in eval mode (BatchNorm uses its running statistics and does not
# update them) and without building autograd graphs; restore the mode afterwards.
was_training = unet.training
unet.eval()
acc_loss = 0
with torch.no_grad():
    for inputs, labels in valloader:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = unet(inputs)
        loss = criterion(pred, labels)
        val_loss_list.append(loss.item())
        acc_loss += loss.item()
unet.train(was_training)

print("Validation error:", acc_loss / len(valloader))


Validation error: 0.3114476203918457


In [None]:
def imshow2(img):
    """Plot a (C, H, W) image tensor that may require grad or live on the GPU.

    The tensor is detached and moved to the CPU before conversion, because
    ``Tensor.numpy()`` rejects both grad-tracking and CUDA tensors.
    """
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # channels last for matplotlib
    plt.show()

In [None]:
class testDataset(torch.utils.data.Dataset):
    """Unlabelled retina test images.

    Sample ``index`` is read from ``{index + 1}_test.tif`` in the current
    working directory and normalized with mean=std=0.5 per channel.

    Args:
        length: number of samples exposed by the dataset (default 5).
    """

    def __init__(self, length = 5):
        'Initialization'
        self.length = length

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def transform(self, image):
        """Convert a 3-channel PIL image to a tensor normalized with mean=std=0.5."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])(image)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Close the file once its pixels have been converted to a tensor.
        with Image.open(f'{index + 1}_test.tif') as image:
            return self.transform(image)

In [None]:
# Test images in file order, one image per batch.
testSet = testDataset(length=5)
testloader = torch.utils.data.DataLoader(
    testSet, batch_size=1, shuffle=False, num_workers=0
)

In [None]:
def single_dice_coef(label, pred, threshold=0.5):
    """Dice coefficient between a ground-truth mask and a predicted probability map.

    Args:
        label: tensor with the ground-truth mask (any device); used as-is.
        pred: tensor of predicted probabilities (any device, may require grad);
            binarized with ``pred > threshold``.
        threshold: probability above which a pixel counts as foreground
            (default 0.5, the previously hard-coded value).

    Returns:
        ``2 * sum(label * pred) / (sum(label) + sum(pred))``, or ``1`` when
        both masks are empty.
    """
    label = label.detach().cpu().numpy()
    pred = (pred.detach() > threshold).float().cpu().numpy()
    intersection = np.sum(label * pred)
    label_sum, pred_sum = np.sum(label), np.sum(pred)
    if label_sum == 0 and pred_sum == 0:
        return 1
    return (2 * intersection) / (label_sum + pred_sum)


# Dice coefficient of each validation image, then the average. Inference runs
# in eval mode without autograd; the previous train/eval mode is restored.
acc_dice = 0
was_training = unet.training
unet.eval()
with torch.no_grad():
    for images, labels in valloader_aug:
        pred = unet(images.to(device))
        dice = single_dice_coef(labels, pred.cpu())
        acc_dice += dice
        print("Dice coefficient: ", dice)
unet.train(was_training)

print("Average dice coefficient:", acc_dice / len(valloader_aug))
  
  


Dice coefficient:  0.12690446316544182
Dice coefficient:  0.11412106262688
Dice coefficient:  0.14114550240537352
Dice coefficient:  0.5625997769370393
Dice coefficient:  0.09195789068884296
4
Average dice coefficient: 0.2073457391647155


In [None]:
# Show the binarized prediction for each test image. Inference runs in eval
# mode without autograd; the previous train/eval mode is restored afterwards.
# NOTE(review): this uses a 0.3 threshold while the dice evaluation uses 0.5;
# confirm which is intended.
was_training = unet.training
unet.eval()
with torch.no_grad():
    for i, images in enumerate(testloader):
        images = images.to(device)

        outputs = (unet(images) > 0.3).float()
        plt.title(f"Predition of test image {i+1}")
        imshow_gpu(torchvision.utils.make_grid(outputs))
unet.train(was_training)