<a href="https://colab.research.google.com/github/jean980308/Deep-Learning-Based-Image-Recognition/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Loss

In [None]:
import torch
import torch.nn as nn

def dice_loss(pred, target, smooth=1.):
    """Return the mean soft Dice loss between ``pred`` and ``target``.

    Both tensors are reduced over dimensions 2 and 3, so they need at least
    four dimensions (presumably ``(batch, class, height, width)`` -- confirm
    against the caller). ``smooth`` is added to numerator and denominator so
    that an all-zero pair yields a loss of 0 instead of a division by zero.

    Args:
        pred: Predicted scores; the caller is expected to pass probabilities.
        target: Ground-truth masks with the same shape as ``pred``.
        smooth: Additive smoothing constant.

    Returns:
        A scalar tensor: ``1 - dice`` averaged over the remaining dimensions.
    """
    def _spatial_sum(tensor):
        # Collapse dim 2 twice: the second call removes what used to be dim 3.
        return tensor.sum(dim=2).sum(dim=2)

    pred = pred.contiguous()
    target = target.contiguous()

    overlap = _spatial_sum(pred * target)
    total = _spatial_sum(pred) + _spatial_sum(target)

    dice = (2. * overlap + smooth) / (total + smooth)

    return (1 - dice).mean()

# helper

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_img_array(img_array, ncol=3):
    """Draw every image of ``img_array`` on a grid with ``ncol`` columns.

    Images fill the grid row by row. All axes share their x and y ranges.

    Args:
        img_array: Sequence of images accepted by ``Axes.imshow``.
        ncol: Number of grid columns.
    """
    # Ceiling division: a partially filled last row still needs its own axes.
    nrow = -(-len(img_array) // ncol)

    # squeeze=False keeps `plots` two-dimensional even for a single row or
    # column, so the [row, col] indexing below always works.
    _, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all',
                            figsize=(ncol * 4, nrow * 4), squeeze=False)

    for i, img in enumerate(img_array):
        plots[i // ncol, i % ncol].imshow(img)

    # Hide the unused cells of a partially filled last row.
    for unused in plots.flat[len(img_array):]:
        unused.axis('off')

from functools import reduce
def plot_side_by_side(img_arrays):
    """Plot several image sequences in parallel columns.

    ``img_arrays`` is a sequence of image sequences. The i-th images of all
    sequences are placed on the same grid row, one column per sequence.
    """
    # zip(*...) groups the i-th image of every sequence into one tuple (one
    # grid row); concatenating the tuples yields the row-major image order.
    rows = zip(*img_arrays)
    interleaved = reduce(lambda acc, row: acc + row, rows)

    column_count = len(img_arrays)
    plot_img_array(np.array(interleaved), ncol=column_count)

import itertools
def plot_errors(results_dict, title):
    """Plot one curve per entry of ``results_dict`` and show the figure.

    Args:
        results_dict: Mapping from curve label to a sequence of values.
            Entries are drawn in sorted order, so labels must be orderable.
        title: Figure title; formatted with ``str.format``.
    """
    markers = itertools.cycle(('+', 'x', 'o'))

    plt.title('{}'.format(title))
    # Labels are the same for every curve, so set them once.
    plt.ylabel('dice_coef')
    plt.xlabel('epoch')

    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)

    # A legend is only meaningful (and warning-free) once a curve exists.
    if results_dict:
        plt.legend(loc=3, bbox_to_anchor=(1, 0))

    plt.show()

def masks_to_colorimg(masks):
    """Render a stack of per-class masks as a single RGB image.

    A pixel counts as belonging to a class when its mask value exceeds 0.5.
    Its colour is the mean of the colours of all classes it belongs to;
    pixels belonging to no class are white (255, 255, 255).

    Args:
        masks: Array of shape ``(channels, height, width)``. Channel ``i``
            uses colour ``i`` of a fixed six-colour palette; the palette is
            reused cyclically when there are more than six channels.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``.
    """
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])

    channels, height, width = masks.shape

    # One palette row per channel, so any channel count is accepted.
    palette = colors[np.arange(channels) % len(colors)].astype(np.float64)

    active = np.asarray(masks) > 0.5          # (channels, height, width)
    hits = active.sum(axis=0)                 # classes present per pixel
    # Sum of the colours of all active classes per pixel: (height, width, 3).
    color_sums = np.tensordot(active.astype(np.float64), palette, axes=([0], [0]))

    colorimg = np.full((height, width, 3), 255, dtype=np.float32)
    covered = hits > 0
    # Mean colour in float64, stored as float32, then truncated to uint8 --
    # the same arithmetic as a per-pixel np.mean over the selected colours.
    colorimg[covered] = color_sums[covered] / hits[covered][:, None]

    return colorimg.astype(np.uint8)

# simulation

In [None]:
import numpy as np
import random

def generate_random_data(height, width, count):
    """Generate ``count`` random images together with their target masks.

    Returns:
        ``(X, Y)`` where ``X`` is a ``uint8`` array of shape
        ``(count, height, width, 3)`` holding the single image channel
        repeated three times and scaled by 255, and ``Y`` stacks the mask
        arrays returned by ``generate_img_and_mask``.
    """
    samples = [generate_img_and_mask(height, width) for _ in range(count)]
    images, masks = zip(*samples)

    # (count, 1, H, W) -> repeat the channel -> channels-last uint8.
    scaled = np.asarray(images) * 255
    X = scaled.repeat(3, axis=1)
    X = X.transpose([0, 2, 3, 1]).astype(np.uint8)

    Y = np.asarray(masks)

    return X, Y

def generate_img_and_mask(height, width):
    """Create one random shape image and its six per-shape target masks.

    Returns:
        ``(arr, masks)``: ``arr`` is a ``float32`` array of shape
        ``(1, height, width)`` containing all shapes; ``masks`` is a
        ``float32`` array with one ``(height, width)`` layer per shape, in
        the order filled square, filled circle, triangle, ring, mesh square
        area, plus.
    """
    shape = (height, width)

    def blank():
        # Fresh empty canvas for a single-shape mask.
        return np.zeros(shape, dtype=bool)

    # The order of these draws determines how the random stream is consumed.
    triangle_location = get_random_location(*shape)
    circle_location1 = get_random_location(*shape, zoom=0.7)
    circle_location2 = get_random_location(*shape, zoom=0.5)
    mesh_location = get_random_location(*shape)
    square_location = get_random_location(*shape, zoom=0.8)
    plus_location = get_random_location(*shape, zoom=1.2)

    # Input image: all shapes drawn onto one canvas.
    canvas = blank()
    canvas = add_triangle(canvas, *triangle_location)
    canvas = add_circle(canvas, *circle_location1)
    canvas = add_circle(canvas, *circle_location2, fill=True)
    canvas = add_mesh_square(canvas, *mesh_location)
    canvas = add_filled_square(canvas, *square_location)
    canvas = add_plus(canvas, *plus_location)
    image = canvas.reshape((1, height, width)).astype(np.float32)

    # Target masks: one layer per shape. The mesh square's target is the
    # filled area it covers rather than the mesh pattern itself.
    layers = [
        add_filled_square(blank(), *square_location),
        add_circle(blank(), *circle_location2, fill=True),
        add_triangle(blank(), *triangle_location),
        add_circle(blank(), *circle_location1),
        add_filled_square(blank(), *mesh_location),
        add_plus(blank(), *plus_location),
    ]
    masks = np.stack(layers).astype(np.float32)

    return image, masks

def add_square(arr, x, y, size):
    """Draw the one-pixel outline of a square onto ``arr`` in place.

    The outline is centred on ``(x, y)`` (first and second array index) and
    extends ``int(size / 2)`` pixels in each direction. The caller must keep
    the square inside the array: indices are not clipped.

    Returns:
        ``arr`` itself, after modification.
    """
    s = int(size / 2)
    # Slices are half-open, so the stop index needs +1 to reach the far
    # corner (x+s, y+s); without it the outline is left open.
    arr[x-s, y-s:y+s+1] = True
    arr[x+s, y-s:y+s+1] = True
    arr[x-s:x+s+1, y-s] = True
    arr[x-s:x+s+1, y+s] = True

    return arr

def add_filled_square(arr, x, y, size):
    """Return ``arr`` OR-ed with a filled square centred on ``(x, y)``.

    The square covers the indices strictly between ``x - s`` and ``x + s``
    (and likewise for ``y``), with ``s = int(size / 2)``. ``arr`` itself is
    not modified.
    """
    half = int(size / 2)

    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]

    inside_conditions = [
        rows > x - half,
        rows < x + half,
        cols > y - half,
        cols < y + half,
    ]
    square = logical_and(inside_conditions)

    return np.logical_or(arr, square)

def logical_and(arrays):
    """Return the element-wise logical AND of all arrays in ``arrays``.

    ``arrays`` must be an indexable, non-empty sequence; the result starts
    out all-True with the shape of its first element.
    """
    combined = np.ones(arrays[0].shape, dtype=bool)
    for operand in arrays:
        combined = np.logical_and(combined, operand)

    return combined

def add_mesh_square(arr, x, y, size):
    """Return ``arr`` OR-ed with a dotted square centred on ``(x, y)``.

    Within the open square of half-width ``int(size / 2)`` only pixels whose
    two indices are both odd are set. ``arr`` itself is not modified.
    """
    half = int(size / 2)

    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]

    mesh_conditions = [
        rows > x - half,
        rows < x + half,
        rows % 2 == 1,
        cols > y - half,
        cols < y + half,
        cols % 2 == 1,
    ]
    mesh = logical_and(mesh_conditions)

    return np.logical_or(arr, mesh)

def add_triangle(arr, x, y, size):
    """Draw a filled lower-left triangle onto ``arr`` in place.

    The triangle is the lower triangle of a ``size`` x ``size`` box whose
    top-left corner is at ``(x - s, y - s)`` with ``s = int(size / 2)``. The
    box must lie completely inside ``arr``; otherwise the shapes do not
    match and NumPy raises ``ValueError``.

    Returns:
        ``arr`` itself, after modification.
    """
    s = int(size / 2)

    triangle = np.tril(np.ones((size, size), dtype=bool))

    rows = slice(x - s, x - s + triangle.shape[0])
    cols = slice(y - s, y - s + triangle.shape[1])
    # OR with the existing content: a plain assignment would erase pixels
    # already set in the empty half of the box, unlike the other add_* helpers.
    arr[rows, cols] = np.logical_or(arr[rows, cols], triangle)

    return arr

def add_circle(arr, x, y, size, fill=False):
    """Return ``arr`` OR-ed with a circle of radius ``size`` around ``(x, y)``.

    With ``fill=True`` every pixel closer than ``size`` to the centre is set.
    Otherwise only a ring is set: pixels whose distance ``d`` satisfies
    ``0.7 * size <= d < size``. ``arr`` itself is not modified.
    """
    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]
    distance = np.sqrt((rows - x) ** 2 + (cols - y) ** 2)

    inside = distance < size
    if fill:
        shape_mask = np.logical_and(inside, True)
    else:
        shape_mask = np.logical_and(inside, distance >= size * 0.7)

    return np.logical_or(arr, shape_mask)

def add_plus(arr, x, y, size):
    """Draw a two-pixel-thick plus sign onto ``arr`` in place.

    Both bars are centred on ``(x, y)`` and span ``int(size / 2)`` pixels to
    either side (half-open, so the far end is excluded).

    Returns:
        ``arr`` itself, after modification.
    """
    half = int(size / 2)

    thin_rows = slice(x - 1, x + 1)
    thin_cols = slice(y - 1, y + 1)
    long_rows = slice(x - half, x + half)
    long_cols = slice(y - half, y + half)

    arr[thin_rows, long_cols] = True   # bar along the second axis
    arr[long_rows, thin_cols] = True   # bar along the first axis

    return arr

def get_random_location(width, height, zoom=1.0):
    """Draw a random centre and size for a shape.

    ``x`` is drawn from 10%..90% of ``width`` and ``y`` from 10%..90% of
    ``height``; ``size`` is 6%..12% of the smaller of the two, scaled by
    ``zoom``. All three values are truncated to ``int``.

    Returns:
        The tuple ``(x, y, size)``.
    """
    # Draw order (x, y, size) fixes how the random stream is consumed.
    x_fraction = random.uniform(0.1, 0.9)
    x = int(width * x_fraction)

    y_fraction = random.uniform(0.1, 0.9)
    y = int(height * y_fraction)

    shorter_side = min(width, height)
    size = int(shorter_side * random.uniform(0.06, 0.12) * zoom)

    return x, y, size

# Synthetic images for demo training

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Generate three 256x256 samples to inspect by eye.
input_images, target_masks = generate_random_data(256, 256, count=3)

print("input_images shape and range", input_images.shape,
      input_images.min(), input_images.max())
print("target_masks shape and range", target_masks.shape,
      target_masks.min(), target_masks.max())

# Per-sample lists of displayable uint8 images: the inputs as they are, the
# mask stacks collapsed into one colour image each.
input_images_rgb = [image.astype(np.uint8) for image in input_images]

target_masks_rgb = [masks_to_colorimg(mask_stack) for mask_stack in target_masks]






# Plot the synthetic images next to their target masks

In [None]:
# One grid row per sample: input image on the left, coloured masks on the right.
plot_side_by_side([input_images_rgb, target_masks_rgb])

# Prepare Dataset and DataLoader

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision    import transforms , datasets, models

class SimDataset(Dataset):
  """Dataset of randomly generated 256x256 shape images and their masks.

  All ``count`` samples are generated once in the constructor. ``transform``,
  when truthy, is applied to the image only; the mask is returned unchanged.
  """

  def __init__(self, count, transform=None):
    images, masks = generate_random_data(256, 256, count=count)
    self.input_images = images
    self.target_masks = masks
    self.transform = transform

  def __len__(self):
    """Return the number of generated samples."""
    return len(self.input_images)

  def __getitem__(self, idx):
    """Return ``[image, mask]`` for sample ``idx``."""
    image, mask = self.input_images[idx], self.target_masks[idx]
    if not self.transform:
      return [image, mask]

    return [self.transform(image), mask]

# Image pipeline: HWC uint8 array -> CHW float tensor -> per-channel
# normalisation. (The constants match the widely used ImageNet statistics;
# reverse_transform below must use the same values.)
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_set = SimDataset(1000, transform=trans)
val_set = SimDataset(200, transform=trans)
image_datasets = dict(train=train_set, val=val_set)

batch_size = 25

# One loader per phase, built in the order train, val.
# NOTE(review): the validation loader is shuffled as well; this does not
# change the averaged metrics but is usually unnecessary.
dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    for phase, dataset in image_datasets.items()
}



# Check the outputs from DataLoader

In [None]:
import torchvision.utils

def reverse_transform(inp):
  """Turn a normalised CHW tensor back into a displayable HWC uint8 image.

  Undoes the per-channel normalisation with the same mean/std constants used
  by ``trans``, clips the result to [0, 1] and rescales it to 0..255.
  """
  mean = np.array([0.485, 0.456, 0.406])
  std = np.array([0.229, 0.224, 0.225])

  # Channels last, so the 3-element mean/std broadcast over the last axis.
  channels_last = inp.numpy().transpose((1, 2, 0))
  restored = np.clip(std * channels_last + mean, 0, 1)

  return (restored * 255).astype(np.uint8)

# Pull a single batch from the (shuffled) training loader.
inputs , masks =next(iter(dataloaders['train']))

print(inputs.shape,masks.shape)

# Show the fourth image of the batch after undoing the normalisation.
plt.imshow(reverse_transform(inputs[3]))

# DoubleConv

In [None]:
def double_conv(in_channels,out_channels):
  """Return two stacked 3x3 convolutions, each followed by an in-place ReLU.

  The first convolution maps ``in_channels`` to ``out_channels``, the second
  keeps ``out_channels``. ``padding=1`` preserves the spatial size.
  """
  layers = []
  for conv_in in (in_channels, out_channels):
    layers.append(nn.Conv2d(conv_in, out_channels, 3, padding=1))
    layers.append(nn.ReLU(inplace=True))

  return nn.Sequential(*layers)

# Define a UNet module

In [None]:
class UNet(nn.Module):
  """U-Net with three pooling stages and bilinear upsampling.

  The encoder widens the features 64 -> 128 -> 256 -> 512 and halves the
  resolution between stages. The decoder upsamples by 2, concatenates the
  matching encoder output along the channel axis and applies a double
  convolution. A final 1x1 convolution yields ``n_class`` output channels;
  no activation is applied to them.

  Because of the three 2x poolings, height and width should be divisible by
  8 so that upsampled tensors match their skip connections.
  """

  def __init__(self, n_channels, n_class):
    super().__init__()

    # Encoder
    self.dconv_down1 = double_conv(n_channels, 64)
    self.dconv_down2 = double_conv(64, 128)
    self.dconv_down3 = double_conv(128, 256)
    self.dconv_down4 = double_conv(256, 512)

    self.maxpool = nn.MaxPool2d(2)

    # Decoder: input width = skip channels + upsampled channels
    self.dconv_up3 = double_conv(256 + 512, 256)
    self.dconv_up2 = double_conv(128 + 256, 128)
    self.dconv_up1 = double_conv(128 + 64, 64)

    self.conv_last = nn.Conv2d(64, n_class, 1)

  def forward(self, x):
    # Contracting path; keep every stage's output for its skip connection.
    skip1 = self.dconv_down1(x)
    skip2 = self.dconv_down2(self.maxpool(skip1))
    skip3 = self.dconv_down3(self.maxpool(skip2))

    x = self.dconv_down4(self.maxpool(skip3))

    # Expanding path, deepest skip connection first.
    decoder_stages = (
        (skip3, self.dconv_up3),
        (skip2, self.dconv_up2),
        (skip1, self.dconv_up1),
    )
    for skip, decoder in decoder_stages:
      x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
      x = decoder(torch.cat([x, skip], dim=1))

    return self.conv_last(x)

# Define loss and metric of training

In [None]:
from collections import defaultdict
import torch.nn.functional as F

# File to which train_model writes the weights with the lowest validation loss.
checkpoint_path = "checkpoint.pth"

def calc_loss(pred, target, metrics, bce_weight=0.5):
  """Return the weighted BCE + Dice loss and accumulate running metrics.

  Args:
    pred: Raw model outputs (logits); the sigmoid is applied here.
    target: Ground-truth masks with the same shape as ``pred``.
    metrics: Mutable mapping; its ``'bce'``, ``'dice'`` and ``'loss'``
      entries are increased by the respective batch value times the batch
      size, so dividing by the sample count later gives per-sample means.
    bce_weight: Weight of the BCE term; Dice gets ``1 - bce_weight``.

  Returns:
    The combined loss tensor, suitable for ``backward()``.
  """
  bce = F.binary_cross_entropy_with_logits(pred, target)
  dice = dice_loss(torch.sigmoid(pred), target)

  loss = bce * bce_weight + dice * (1 - bce_weight)

  # Record each quantity under its own key (all three used to record BCE).
  n_samples = target.size(0)
  metrics['bce'] += bce.item() * n_samples
  metrics['dice'] += dice.item() * n_samples
  metrics['loss'] += loss.item() * n_samples

  return loss

def print_metrics(metrics, epoch_samples, phase):
  """Print the per-sample average of every metric on one line.

  Args:
    metrics: Mapping from metric name to its sum over the epoch.
    epoch_samples: Number of samples the sums were accumulated over.
    phase: Label printed in front of the metrics.
  """
  # '.4f' gives four decimals; the former '4f' only set a minimum width.
  outputs = ["{}: {:.4f}".format(name, total / epoch_samples)
             for name, total in metrics.items()]
  print("{}:{}".format(phase, ",".join(outputs)))

  



# Define the main training loop

In [None]:
def train_model(model,optimizer,scheduler,num_epochs=25):
  """Train ``model`` and return it with the best validation weights loaded.

  Every epoch runs a 'train' and a 'val' pass over the module-level
  ``dataloaders``, moving each batch to the module-level ``device``. After
  the training pass the scheduler is stepped; after the validation pass the
  weights are written to ``checkpoint_path`` whenever the average
  validation loss improved.

  Args:
    model: Network to optimise; expected to live on ``device`` already.
    optimizer: Optimiser over the model's parameters.
    scheduler: Learning-rate scheduler, stepped once per epoch.
    num_epochs: Number of epochs to run.

  Returns:
    ``model``. If a checkpoint was written during this call its weights are
    loaded back first; otherwise the model is returned as it is.
  """
  best_loss=1e10
  # Tracks whether this call wrote checkpoint_path, so that a missing or
  # stale file is never loaded (num_epochs == 0, or a loss that is NaN).
  checkpoint_saved=False

  for epoch in range(num_epochs):
    print('Epoch{}/{}'.format(epoch,num_epochs-1))
    print('-'*10)

    since=time.time()

    for phase in ['train','val']:
      if phase =='train':
        model.train()
      else:
        model.eval()
      metrics= defaultdict(float)
      epoch_samples=0

      for inputs, labels in dataloaders[phase]:
        inputs=inputs.to(device)
        labels=labels.to(device)

        optimizer.zero_grad()

        # Build the autograd graph only while training.
        with torch.set_grad_enabled(phase == 'train'):
          outputs=model(inputs)
          loss = calc_loss(outputs,labels,metrics)

          if phase == 'train':
            loss.backward()
            optimizer.step()
        epoch_samples += inputs.size(0)

      print_metrics(metrics,epoch_samples,phase)
      epoch_loss = metrics['loss'] / epoch_samples

      if phase == 'train':
        scheduler.step()
        for param_group in optimizer.param_groups:
          print("LR",param_group['lr'])

      if phase == 'val' and epoch_loss < best_loss:
        print(f"saving best model to {checkpoint_path}")
        best_loss = epoch_loss
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved=True

    time_elapsed=time.time() - since
    print('{:.0f}m {:.0f}s' .format(time_elapsed //60,time_elapsed % 60))

  print('Best val loss: {:.4f}'.format(best_loss))

  if checkpoint_saved:
    model.load_state_dict(torch.load(checkpoint_path))
  return model
  

# Training process

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time

num_class=6

# train_model moves every batch to `device`, so it has to be defined here
# and the model has to live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(3,num_class).to(device)

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# Multiply the learning rate by 0.1 every 8 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,step_size=8, gamma=0.1)

model = train_model(model,optimizer_ft,exp_lr_scheduler,num_epochs=5)
