<a href="https://colab.research.google.com/github/dachanh/30-days-coding/blob/master/unet_baseline1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip3 install torch===1.3.1 torchvision===0.4.2 -f https://download.pytorch.org/whl/torch_stable.html
!pip3 install easycolab

In [0]:
!pip install opencv-python

In [0]:
# Mount Google Drive inside the Colab runtime at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')
# NOTE(review): easycolab's mount() presumably mounts Drive as well and may be
# what makes the relative 'cds/...' paths used later resolvable - confirm.
import easycolab as ec
ec.mount()

In [0]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset , DataLoader
from torchvision import transforms, datasets, models 
from torchsummary import summary
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
from os.path import splitext
from os import listdir
import numpy as np
from torch import optim
from tqdm import tqdm
import os
from glob import glob

UNET

In [0]:
class DoubleConv(nn.Module):
  """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

  Both convolutions use padding=1 with the default stride, so the spatial
  size of the input is kept; only the channel count changes from
  `in_channel` to `out_channel`.
  """

  def __init__(self,in_channel,out_channel):
    super().__init__()
    # Modules are created in execution order so the Sequential indices
    # (and therefore the state_dict keys double_conv.0 .. double_conv.5)
    # stay the same.
    first_conv = nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1)
    first_norm = nn.BatchNorm2d(out_channel)
    first_act = nn.ReLU(inplace=True)
    second_conv = nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1)
    second_norm = nn.BatchNorm2d(out_channel)
    second_act = nn.ReLU(inplace=True)
    self.double_conv = nn.Sequential(
      first_conv,first_norm,first_act,
      second_conv,second_norm,second_act,
    )

  def forward(self,x):
    """Run `x` through conv-BN-ReLU twice and return the feature map."""
    features = self.double_conv(x)
    return features

In [0]:
class ContractLayer(nn.Module):
  """Encoder stage: 2x2 max pooling followed by a DoubleConv block.

  The pooling halves the spatial size; the DoubleConv then maps
  `in_channel` channels to `out_channel` channels.
  """

  def __init__(self,in_channel,out_channel):
    super().__init__()
    pool = nn.MaxPool2d(2)
    convs = DoubleConv(in_channel,out_channel)
    self.downsampling = nn.Sequential(pool,convs)

  def forward(self,x):
    """Downsample `x` and return the convolved feature map."""
    reduced = self.downsampling(x)
    return reduced

In [0]:
class Expansivelayer(nn.Module):
  """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

  Args:
    in_channel: channel count of the concatenated tensor fed to the
      DoubleConv. The upsampled tensor and the skip tensor are each
      expected to contribute `in_channel // 2` channels.
    out_channel: channel count produced by the stage.
    bilinear: if True, upsample with parameter-free bilinear interpolation;
      otherwise use a learned stride-2 transposed convolution.
  """

  def __init__(self,in_channel,out_channel,bilinear=True):
    super().__init__()
    if bilinear:
      self.upsampling = nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
    else:
      # The transposed convolution must preserve the channel count
      # (in_channel // 2): after concatenation with the skip tensor the
      # result has to carry exactly `in_channel` channels for self.conv.
      # Mapping to out_channel // 2 here made the non-bilinear path fail
      # with a channel mismatch.
      self.upsampling = nn.ConvTranspose2d(in_channel // 2,in_channel // 2,kernel_size=2,stride=2)
    self.conv = DoubleConv(in_channel,out_channel)

  def forward(self,x1,x2):
    """Upsample `x1`, pad it to the size of `x2`, and fuse both tensors.

    Args:
      x1: feature map from the deeper (lower-resolution) stage.
      x2: skip-connection feature map from the encoder.

    Returns:
      The DoubleConv output computed on `cat([x2, x1], dim=1)`.
    """
    x1 = self.upsampling(x1)
    # Odd input sizes lose a pixel when pooled, so the upsampled map can be
    # smaller than the skip tensor; pad the difference as evenly as possible.
    diff_y = x2.size(2) - x1.size(2)
    diff_x = x2.size(3) - x1.size(3)
    # F.pad takes the last dimension first: [left, right, top, bottom].
    x1 = F.pad(x1,[diff_x // 2,diff_x - diff_x // 2,diff_y // 2,diff_y - diff_y // 2])

    x = torch.cat([x2,x1],dim=1)
    return self.conv(x)

In [0]:
class Finallayer(nn.Module):
  """Output head: a 1x1 convolution mapping features to per-pixel scores.

  No activation is applied, so the result is raw scores with
  `out_channel` channels and the same spatial size as the input.
  """

  def __init__(self,in_channel,out_channel):
    super().__init__()
    self.conv = nn.Conv2d(in_channel,out_channel,kernel_size=1)

  def forward(self,x):
    """Project `x` to `out_channel` channels."""
    scores = self.conv(x)
    return scores

In [0]:
class UNet(nn.Module):
  """U-Net: a four-stage encoder, a four-stage decoder and skip connections.

  Args:
    n_channels: number of channels of the input image.
    n_classes: number of output channels produced by the final 1x1 conv.
    bilinear: forwarded to every decoder stage to select the upsampling mode.

  The forward pass returns the raw output of the final convolution; no
  activation is applied inside the network.
  """

  def __init__(self,n_channels,n_classes,bilinear=True):
    super().__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.bilinear = bilinear

    # Encoder: channel width grows 64 -> 128 -> 256 -> 512 and stays at 512
    # in the deepest stage.
    self.input_layer = DoubleConv(n_channels,64)
    self.downLayer_1 = ContractLayer(64,128)
    self.downLayer_2 = ContractLayer(128,256)
    self.downLayer_3 = ContractLayer(256,512)
    self.downLayer_4 = ContractLayer(512,512)

    # Decoder: the first argument of each stage is the channel count after
    # concatenating the upsampled map with the matching encoder output.
    self.uplayer_1 = Expansivelayer(1024,256,bilinear)
    self.uplayer_2 = Expansivelayer(512,128,bilinear)
    self.uplayer_3 = Expansivelayer(256,64,bilinear)
    self.uplayer_4 = Expansivelayer(128,64,bilinear)

    self.output = Finallayer(64,n_classes)

  def forward(self,x):
    """Encode `x`, decode it with skip connections, and return the scores."""
    skip_1 = self.input_layer(x)
    skip_2 = self.downLayer_1(skip_1)
    skip_3 = self.downLayer_2(skip_2)
    skip_4 = self.downLayer_3(skip_3)
    bottleneck = self.downLayer_4(skip_4)

    # Walk back up, pairing each decoder stage with the encoder output of
    # the same resolution.
    decoded = self.uplayer_1(bottleneck,skip_4)
    decoded = self.uplayer_2(decoded,skip_3)
    decoded = self.uplayer_3(decoded,skip_2)
    decoded = self.uplayer_4(decoded,skip_1)
    return self.output(decoded)

In [0]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 3 input channels, 1 output channel; build the network directly on `device`.
model = UNet(n_channels=3,n_classes=1).to(device)
# Print a per-layer summary for a 3x224x224 input.
summary(model,input_size=(3,224,224))

DATA LOADER

In [0]:
class BasicDataset(Dataset):
  """Image/mask pairs stored as '<name>.jpg' files in two directories.

  Args:
    dir_image: directory holding the input images.
    dir_mask: directory holding the masks; a mask has the same file name
      as its image.

  Each item is a dict with keys 'input' and 'output', both channel-first
  tensors resized to 224x224.
  """

  def __init__(self,dir_image,dir_mask):
    self.dir_image = dir_image
    self.dir_mask = dir_mask
    # Keep only '.jpg' entries: __getitem__ rebuilds every path as
    # name + '.jpg', so anything else found by listdir (sub-directories,
    # notebook checkpoints, other formats) could never be loaded.
    # Sorting makes the sample order reproducible across runs.
    self.idx = sorted(
      stem for stem,ext in map(splitext,listdir(dir_image)) if ext == '.jpg'
    )
    print(self.idx)

  def __len__(self):
    """Return the number of usable image files."""
    return len(self.idx)

  def preprocess(self,path):
    """Load `path`, resize it to 224x224 and return a channel-first array.

    Values are divided by 255 only when the image maximum exceeds 1;
    otherwise the array is returned unscaled.

    Raises:
      FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path)
    if image is None:
      # cv2.imread signals failure by returning None instead of raising.
      raise FileNotFoundError('cannot read image: {}'.format(path))
    image = cv2.resize(image,(224,224))
    image_nd = np.array(image)
    if image_nd.ndim == 2:
      # Single-channel image: add the channel axis expected by transpose.
      image_nd = np.expand_dims(image_nd,axis=2)
    image_transpose = image_nd.transpose((2,0,1))
    if image_transpose.max() > 1:
      image_transpose = image_transpose/255
    return image_transpose

  def __getitem__(self,i):
    """Return the i-th sample as {'input': image, 'output': mask}."""
    file_name = self.idx[i] + '.jpg'
    # os.path.join works whether or not the directories end with a slash.
    image_file = os.path.join(self.dir_image,file_name)
    mask_file = os.path.join(self.dir_mask,file_name)
    img = self.preprocess(image_file)
    mask = self.preprocess(mask_file)
    return {'input': torch.from_numpy(img),'output':torch.from_numpy(mask)}

LOSS FUNCTION

In [0]:
def soft_dice_loss(y_target,y_pred):
  """Return the soft Dice loss between target masks and raw model scores.

  Args:
    y_target: target tensor with the spatial axes at dims 2 and 3.
    y_pred: raw (pre-sigmoid) scores, broadcastable against `y_target`.

  Returns:
    A scalar tensor: one minus the Dice coefficient, averaged over the
    remaining (batch and channel) dimensions.
  """
  eps = 1e-4  # keeps the ratio defined when both masks are empty
  # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
  predict = torch.sigmoid(y_pred)
  intersection = (y_target*predict).sum(dim=(2,3))
  # Sum of both mask "areas" (the Dice denominator, not a set union).
  total = predict.sum(dim=(2,3)) + y_target.sum(dim=(2,3))
  dice = (2*intersection.float() + eps)/(total.float() + eps)
  return (1 - dice).mean()

In [0]:
def dice_cofficient(y_target,y_pred):
  """Return the soft Dice coefficient between target masks and raw scores.

  Args:
    y_target: target tensor with the spatial axes at dims 2 and 3.
    y_pred: raw (pre-sigmoid) scores, broadcastable against `y_target`.

  Returns:
    A scalar tensor: the Dice coefficient averaged over the remaining
    (batch and channel) dimensions.
  """
  eps = 1e-4  # keeps the ratio defined when both masks are empty
  # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
  predict = torch.sigmoid(y_pred)
  intersection = (y_target*predict).sum(dim=(2,3))
  # Sum of both mask "areas" (the Dice denominator, not a set union).
  total = predict.sum(dim=(2,3)) + y_target.sum(dim=(2,3))
  cofficident = ((2*intersection.float() + eps)/(total.float() + eps)).mean()
  return cofficident

In [0]:
!ls cds/datasheet

CONFIG

In [0]:
# Dataset root directory (relative path); the data, checkpoint and output
# directories used by the later cells are all built from it.
root = 'cds/datasheet'

In [0]:
# Image ("X_*") and mask ("y_*") folders for the train and test splits.
# Every path keeps its trailing '/', because file names are appended to
# these strings by plain concatenation.
dirImage_train = '{}/X_train/'.format(root)
dirMask_train = '{}/y_train/'.format(root)
dirImage_test = '{}/X_test/'.format(root)
dirMask_test = '{}/y_test/'.format(root)

In [0]:
# Training hyper-parameters.
batch_size = 4  # samples per batch for both loaders
epochs = 50  # full passes over the training set
lr = 1e-3  # learning rate handed to the optimizer
# RMSprop over the parameters of the `model` object that exists when this
# cell runs, with a very small L2 penalty (weight_decay).
optimizer = optim.RMSprop(model.parameters(),lr=lr,weight_decay=1e-8)

In [0]:
# Build the train and test datasets from their image/mask directories.
# Constructing a BasicDataset lists its image directory and prints the names found.
traindata = BasicDataset(dir_image=dirImage_train,dir_mask=dirMask_train)
testdata = BasicDataset(dir_image=dirImage_test,dir_mask=dirMask_test)

In [0]:
# Shuffled mini-batch loader for training: 8 worker processes load samples,
# and pin_memory=True requests page-locked host tensors.
train_loader = DataLoader(traindata,batch_size=batch_size,shuffle=True,num_workers=8,pin_memory=True)
# NOTE(review): the test loader is shuffled as well, which evaluation does
# not need, and it is not consumed by any of the visible cells - confirm intent.
test_loader = DataLoader(testdata,batch_size=batch_size,shuffle=True,num_workers=8,pin_memory=True)

In [0]:
# Select the device again (same expression as in the model cell).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lowest epoch loss seen so far; starts huge so the first finite epoch loss replaces it.
best_loss = 1e32
# Number of training samples (not read by the visible training loop).
n_train =  int(len(traindata))
# print(n_train)
# Directory where the best model weights are written.
dir_checkpoint = root + '/checkpoint/'

In [0]:
# Train for `epochs` passes; after each pass report the sample-weighted mean
# Dice coefficient and loss, and checkpoint the weights whenever the mean
# training loss improves on the best value seen so far.
for epoch in range(epochs):
  model.train()
  epoch_loss = 0.0
  epoch_cofficient = 0.0
  epoch_sample = 0
  for batch in train_loader:
    image = batch['input']
    mask = batch['output']
    image = image.to(device=device,dtype=torch.float32)
    target = mask.to(device=device,dtype=torch.float32)
    batch_count = image.size(0)

    predict = model(image)
    loss = soft_dice_loss(target,predict)
    # The coefficient is only reported, so keep it out of the autograd graph.
    with torch.no_grad():
      cofficient = dice_cofficient(target,predict)

    # Weight by batch size so a smaller final batch does not skew the mean.
    epoch_loss += loss.item()*batch_count
    epoch_cofficient += cofficient.item()*batch_count
    epoch_sample += batch_count

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  if epoch_sample == 0:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise RuntimeError('train_loader produced no samples')
  print('epoch {}: cofficient {} , loss {}'.format(epoch+1,epoch_cofficient/epoch_sample,epoch_loss/epoch_sample))
  epoch_loss = epoch_loss/epoch_sample
  if best_loss > epoch_loss:
    best_loss = epoch_loss
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(dir_checkpoint,exist_ok=True)
    torch.save(model.state_dict(),dir_checkpoint+'model.pth')

Predict

In [0]:
# Rebuild the network with the same configuration used for training and
# restore the weights saved by the training loop.
model =UNet(n_channels=3,n_classes=1)
model = model.to(device)
# map_location lets a checkpoint saved on one device load on the current one.
model.load_state_dict(torch.load(dir_checkpoint+'model.pth', map_location=device))
# Inference mode: BatchNorm layers use their running statistics.
model.eval()

In [0]:
!ls root

In [0]:
def preprocess(path):
  """Load an image, resize it to 224x224 and return a channel-first array.

  Values are divided by 255 only when the image maximum exceeds 1;
  otherwise the array is returned unscaled.

  Args:
    path: path of the image file to read with OpenCV.

  Returns:
    A NumPy array with the channel axis first.

  Raises:
    FileNotFoundError: if OpenCV cannot read the file.
  """
  image = cv2.imread(path)
  if image is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError('cannot read image: {}'.format(path))
  image = cv2.resize(image,(224,224))
  image_nd = np.array(image)
  if image_nd.ndim == 2:
    # Single-channel image: add the channel axis expected by transpose.
    image_nd = np.expand_dims(image_nd,axis=2)
  image_transpose = image_nd.transpose((2,0,1))
  if image_transpose.max() > 1:
    image_transpose = image_transpose/255
  return image_transpose

In [0]:
# Run the restored model on every test image and save the predicted
# probability map as an 8-bit grayscale JPEG with the same base name.
output_dir = root + '/output_image_valid/'
# Create the output folder once, up front, instead of checking per image.
os.makedirs(output_dir,exist_ok=True)
for it in glob(dirImage_test+'*.jpg'):
  img = torch.from_numpy(preprocess(it)).unsqueeze(0)  # add the batch axis
  img = img.to(device=device, dtype=torch.float32)
  with torch.no_grad():
    output = model(img)
  # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
  probs = torch.sigmoid(output).squeeze(0)
  # Scale probabilities straight to 0..255. The former ToPILImage -> ToTensor
  # round trip only quantised to 8 bits and back, and could drop a grey
  # level through float round-off.
  full_mask = (probs.squeeze().cpu().numpy()*255).astype(np.uint8)
  cv2.imwrite(output_dir+splitext(os.path.basename(it))[0]+'.jpg',full_mask)
