In [42]:
from torchsummary import summary
from torchgeometry.losses import one_hot
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import time
import imageio
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Resize, PILToTensor, ToPILImage, Compose, InterpolationMode
from collections import OrderedDict
import wandb

In [43]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [44]:
import segmentation_models_pytorch as smp

In [45]:
# segmentation_models_pytorch UNet++ with a ResNet-50 encoder ('imagenet' encoder weights).
# NOTE(review): `unet` is not referenced by any later cell (training uses the custom
# `UNet` class below); constructing it presumably fetches the pretrained encoder weights.
unet = smp.UnetPlusPlus(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=3)

# Hyperparameters

In [46]:
# Training configuration
num_classes = 3        # number of segmentation classes (matches UNet's default)
epochs = 50
learning_rate = 1e-4
batch_size = 64
display_step = 50      # train() prints the batch loss every `display_step` batches

# Per-epoch metric containers.
# NOTE(review): none of these lists is populated by any visible cell.
loss_epoch_array, train_accuracy, test_accuracy, valid_accuracy = [], [], [], []

# DataLoader

In [47]:
class DataClass(Dataset):
    """Image/mask segmentation dataset (no augmentation).

    The i-th file of ``images_path`` is paired with the i-th file of ``masks_path``
    after sorting both listings by name, so matching files must sort the same way
    (presumably identical file names in both directories — TODO confirm).

    Args:
        images_path: directory containing the input images.
        masks_path: directory containing the colour-coded masks.
        transform: callable mapping a PIL image to a CxHxW tensor with 0-255 values
            (e.g. Resize + PILToTensor); the result is divided by 255 here.
        augmentation: stored but not applied by this class.

    Raises:
        ValueError: if the two directories hold a different number of files.
    """

    def __init__(self, images_path, masks_path, transform = None, augmentation = None):
        super(DataClass, self).__init__()

        # os.listdir order is arbitrary; sort so image i and mask i describe the same sample.
        images_list = [os.path.join(images_path, name) for name in sorted(os.listdir(images_path))]
        masks_list = [os.path.join(masks_path, name) for name in sorted(os.listdir(masks_path))]
        if len(images_list) != len(masks_list):
            raise ValueError(
                "Found {} images but {} masks".format(len(images_list), len(masks_list)))

        self.images_list = images_list
        # Backward-compatible alias for the previously misspelled attribute name.
        self.iamges_list = images_list
        self.masks_list = masks_list
        self.transform = transform
        self.augmentation = augmentation

    def __getitem__(self, index):
        """Return ``(image, label)``: float image in [0, 1] and an int64 HxW class map."""
        img_path = self.images_list[index]
        mask_path = self.masks_list[index]

        # Decode inside `with` so the file handles are closed; convert to RGB because
        # the model takes 3 input channels and the label encoding indexes channel 2.
        with Image.open(img_path) as img:
            data = img.convert("RGB")
        with Image.open(mask_path) as msk:
            label = msk.convert("RGB")

        # Normalize to [0, 1]
        data = self.transform(data) / 255
        label = self.transform(label) / 255

        # Binarise each channel, then give channel 2 a tiny constant so argmax yields
        # 0 where channel 0 is set, 1 where only channel 1 is set, 2 where neither is.
        label = torch.where(label > 0.65, 1.0, 0.0)
        label[2, :, :] = 0.0001
        label = torch.argmax(label, 0).type(torch.int64)
        return data, label

    def __len__(self):
        return len(self.images_list)


In [48]:
# Dataset locations. The defaults are the original local paths; set POLYP_TRAIN_IMAGES /
# POLYP_TRAIN_MASKS to run elsewhere. os.path.join(..., "") guarantees a trailing separator,
# so plain string concatenation with a file name still yields a valid path.
images_path = os.path.join(os.environ.get("POLYP_TRAIN_IMAGES", "D:/Polyp_segmentation/model/train/train/"), "")
masks_path = os.path.join(os.environ.get("POLYP_TRAIN_MASKS", "D:/Polyp_segmentation/model/train_gt/train_gt/"), "")

**Transform**

In [49]:
# Shared image/mask preprocessing: resize to 256x256, then PIL -> uint8 CxHxW tensor.
# NOTE(review): masks are resized bilinearly as well; the 0.65 threshold in the
# dataset classes re-binarises them afterwards.
transform = Compose(
    [
        Resize((256, 256), interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
    ]
)

In [50]:
unet_dataset = DataClass(images_path, masks_path, transform=transform)  # plain (non-augmented) dataset

In [51]:
# Sanity check: number of image/mask pairs found on disk.
images_list = list(os.listdir(images_path))  # NOTE(review): not referenced by any later cell
len(unet_dataset)

1000

In [52]:
# Split the plain dataset into train/validation parts (90/10). The validation part takes
# the remainder so the two lengths always add up to len(unet_dataset), and a seeded
# generator makes the split reproducible across kernel restarts.
train_size = 0.9
valid_size = 0.1
SPLIT_SEED = 42

n_samples = len(unet_dataset)
n_train = int(train_size * n_samples)
train_set, valid_set = random_split(
    unet_dataset,
    [n_train, n_samples - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [53]:
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# The validation loss is averaged over all batches, so shuffling adds nothing; keep the order deterministic.
valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

# Data Augmentation

In [54]:
from albumentations import (
    Compose,
    RandomRotate90,
    Flip,
    Transpose,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    RandomBrightnessContrast,
    HorizontalFlip,
    VerticalFlip,
    RandomGamma,
    RGBShift,
)

In [55]:
# Training-time augmentation, applied jointly to image and mask by SegmentDataClass.
# `Compose` here is albumentations.Compose (imported in the previous cell).
# `eps=None` and `always_apply=False` were dropped: they restate RandomGamma's defaults,
# and the warning printed by this cell pointed at that call.
augmentation = Compose([
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    RandomGamma(gamma_limit=(70, 130), p=0.2),
    RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
])

  RandomGamma (gamma_limit=(70, 130), eps=None, always_apply=False, p=0.2),


In [56]:
class SegmentDataClass(Dataset):
    """Image/mask segmentation dataset with optional joint augmentation.

    The i-th file of ``images_path`` is paired with the i-th file of ``masks_path``
    after sorting both listings by name (presumably identical file names in both
    directories — TODO confirm).

    Args:
        images_path: directory containing the input images.
        masks_path: directory containing the colour-coded masks.
        transform: callable mapping a PIL image to a CxHxW tensor with 0-255 values;
            the result is divided by 255 here.
        augmentation: optional callable invoked as ``augmentation(image=..., mask=...)``
            on numpy arrays, returning a dict with ``'image'`` and ``'mask'``.

    Raises:
        ValueError: if the two directories hold a different number of files.
    """

    def __init__(self, images_path, masks_path, transform=None, augmentation=None):
        super(SegmentDataClass, self).__init__()
        # os.listdir order is arbitrary; sort so image i and mask i describe the same sample.
        images_list = [os.path.join(images_path, name) for name in sorted(os.listdir(images_path))]
        masks_list = [os.path.join(masks_path, name) for name in sorted(os.listdir(masks_path))]
        if len(images_list) != len(masks_list):
            raise ValueError(
                "Found {} images but {} masks".format(len(images_list), len(masks_list)))

        self.images_list = images_list
        self.masks_list = masks_list
        self.transform = transform
        self.augmentation = augmentation

    def __getitem__(self, index):
        """Return ``(image, label)``: float image in [0, 1] and an int64 HxW class map."""
        image_path = self.images_list[index]
        mask_path = self.masks_list[index]

        # The label must come from the mask file (it was previously read from image_path).
        # Decode inside `with` so the file handles are closed; convert to RGB because the
        # model takes 3 input channels and the label encoding indexes channel 2.
        with Image.open(image_path) as img:
            data = img.convert("RGB")
        with Image.open(mask_path) as msk:
            label = msk.convert("RGB")

        if self.augmentation:
            # Augment image and mask together so geometric transforms stay aligned.
            augmented = self.augmentation(image=np.array(data), mask=np.array(label))
            data = Image.fromarray(augmented['image'])
            label = Image.fromarray(augmented['mask'])

        # Normalize to [0, 1]
        data = self.transform(data) / 255
        label = self.transform(label) / 255

        # Binarise each channel, then give channel 2 a tiny constant so argmax yields
        # 0 where channel 0 is set, 1 where only channel 1 is set, 2 where neither is.
        label = torch.where(label > 0.65, 1.0, 0.0)
        label[2, :, :] = 0.0001
        label = torch.argmax(label, 0).type(torch.int64)

        return data, label

    def __len__(self):
        return len(self.images_list)

In [57]:
augment_dataset = SegmentDataClass(images_path, masks_path, transform=transform, augmentation=augmentation)  # augmented copy

In [58]:
from torch.utils.data import Subset

# Reuse the exact index split of `unet_dataset` rather than drawing a second, independent
# random split. Otherwise augmented copies of validation images leak into the training set
# through the ConcatDataset below. This also avoids the int-truncated lengths, which do
# not always add up to len(augment_dataset).
if augment_dataset.masks_list != unet_dataset.masks_list:
    raise RuntimeError("augment_dataset and unet_dataset must list the same files in the same order")
train_augment_set = Subset(augment_dataset, train_set.indices)
valid_augment_set = Subset(augment_dataset, valid_set.indices)

In [59]:
train_dataloader = DataLoader(dataset=train_augment_set, batch_size=batch_size, shuffle=True)  # NOTE(review): replaced by the ConcatDataset loader in the next cell

In [60]:
from torch.utils.data import ConcatDataset

# Train on the plain and the augmented versions of the training samples together.
combined_dataset = ConcatDataset(datasets=[train_set, train_augment_set])
train_dataloader = DataLoader(combined_dataset, shuffle=True, batch_size=batch_size)

# Model

In [61]:
class Residual_Block(nn.Module):
    """Residual block: conv-BN-ReLU-dropout-conv-BN plus a projected shortcut, then ReLU.

    The shortcut is ``bn3(conv1(x))``, so it reuses ``conv1``'s weights as the channel
    projection. NOTE(review): a dedicated 1x1 shortcut conv would be more usual; the
    layer set is kept unchanged so existing state_dicts still load.
    Spatial size is preserved (3x3 kernels, padding 1).
    """

    def __init__(self, in_channels, out_channels):
        super(Residual_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size= 3, padding= 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace= True)

        self.conv2 = nn.Conv2d (out_channels, out_channels, kernel_size = 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.bn3 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(p = 0.3)

    def forward(self, x):
        # conv1(x) feeds both the main path and the shortcut: compute it once
        # (the previous code ran the same convolution twice on the same input).
        projected = self.conv1(x)

        # bn1 returns a new tensor, so the in-place ReLU does not touch `projected`.
        out = self.relu(self.bn1(projected))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))

        out += self.bn3(projected)
        return self.relu(out)
    
        


In [62]:
class encoder_block(nn.Module):
    """U-Net encoder stage: (conv3x3-BN-ReLU) -> dropout -> (conv3x3-BN-ReLU) -> 2x2 max-pool.

    ``forward`` returns ``(pooled, features)``. ``features`` is the full-resolution
    activation returned as the skip tensor, and ``pooled`` (half height/width) feeds the next stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # First convolution stage
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # Second convolution stage
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Regularisation and downsampling
        self.dropout = nn.Dropout(p=0.3)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.bn1(self.conv1(x))))
        features = self.relu(self.bn2(self.conv2(hidden)))
        return self.max_pool(features), features
    



In [69]:
class res_encoder_block(nn.Module):
    """Residual U-Net encoder stage followed by 2x2 max-pooling.

    Main path: conv-BN-ReLU-dropout-conv-BN; shortcut: ``bn3(conv1(x))`` (it reuses
    ``conv1``'s weights as the projection). The sum goes through ReLU to give ``features``.
    ``forward`` returns ``(max_pool(features), features)``.
    """

    def __init__(self, in_channels, out_channels):
        super(res_encoder_block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size= 3,stride =1, padding= 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d (out_channels, out_channels, kernel_size = 3, stride = 1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.bn3 = nn.BatchNorm2d(out_channels)
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.dropout = nn.Dropout(p = 0.3)

    def forward(self, x):
        # conv1(x) feeds both the main path and the shortcut: compute it once
        # (the previous code ran the same convolution twice on the same input).
        projected = self.conv1(x)

        out = self.relu(self.bn1(projected))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))

        out += self.bn3(projected)
        out = self.relu(out)

        return self.max_pool(out), out
    
        


In [64]:
class decoder_block(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, concat with the skip tensor,
    then (conv3x3-BN-ReLU) -> dropout -> (conv3x3-BN-ReLU).

    ``skip_layer`` must have ``out_channels`` channels and twice the spatial size of ``x``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.transpose_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # conv1 sees the upsampled tensor concatenated with the skip tensor (2 * out_channels).
        self.conv1 = nn.Conv2d(2 * out_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x, skip_layer):
        merged = torch.cat((self.transpose_conv(x), skip_layer), dim=1)
        hidden = self.dropout(self.relu(self.bn1(self.conv1(merged))))
        return self.relu(self.bn2(self.conv2(hidden)))

In [72]:
class res_decoder_block(nn.Module):
    """Residual U-Net decoder stage.

    The block upsamples ``x`` 2x with a transposed conv and concatenates the skip tensor.
    The main path is conv-BN-ReLU-dropout-conv-BN; the shortcut is ``bn3(conv1(merged))``,
    reusing ``conv1``'s weights. The sum goes through ReLU.
    ``skip_layer`` must have ``out_channels`` channels and twice the spatial size of ``x``.
    """

    def __init__(self, in_channels, out_channels):
        super(res_decoder_block, self).__init__()
        self.transpose_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)

        self.conv1 = nn.Conv2d(2*out_channels, out_channels, kernel_size= 3, stride = 1, padding = 'same')
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size= 3, stride = 1, padding = 'same')
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = 0.3)

    def forward(self, x, skip_layer):
        merged = torch.cat([self.transpose_conv(x), skip_layer], dim=1)

        # conv1(merged) feeds both the main path and the shortcut: compute it once
        # (the previous code ran the same convolution twice on the same input).
        projected = self.conv1(merged)

        x = self.relu(self.bn1(projected))
        x = self.dropout(x)
        x = self.bn2(self.conv2(x))

        x += self.bn3(projected)
        return self.relu(x)

In [66]:
class BottleNeck_block(nn.Module):
    """U-Net bottleneck: (conv3x3-BN-ReLU) -> dropout -> (conv3x3-BN-ReLU), no resampling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding='same')
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        first = self.relu(self.bn1(self.conv1(x)))
        second = self.conv2(self.dropout(first))
        return self.relu(self.bn2(second))


In [75]:
class UNet(nn.Module):
    """Four-level U-Net producing ``num_classes`` logit maps at the input resolution.

    Channels: 3 -> 64 -> 128 -> 256 (residual stage) -> 512 -> bottleneck 1024, then
    mirrored back down to 64 with one residual decoder stage. Each encoder halves H and W,
    so the input H and W should be divisible by 16 for the skip concatenations to line up.
    """

    def __init__(self, num_classes = 3):
        super().__init__()
        # Encoder
        self.enc1 = encoder_block(3, 64)
        self.enc2 = encoder_block(64, 128)
        self.enc3 = res_encoder_block(128, 256)
        self.enc4 = encoder_block(256, 512)
        # Bottleneck
        self.bottle_neck = BottleNeck_block(512, 1024)
        # Decoder
        self.dec1 = decoder_block(1024, 512)
        self.dec2 = res_decoder_block(512, 256)
        self.dec3 = decoder_block(256, 128)
        self.dec4 = decoder_block(128, 64)
        # 1x1 classification head
        self.out = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding='same')

    def forward(self, image):
        skips = []
        features = image
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features, skip = encoder(features)
            skips.append(skip)

        features = self.bottle_neck(features)

        # The deepest skip pairs with the first decoder stage.
        for decoder, skip in zip((self.dec1, self.dec2, self.dec3, self.dec4), reversed(skips)):
            features = decoder(features, skip)

        return self.out(features)

In [76]:
# Build the custom U-Net and print a per-layer summary for one 3x256x256 input.
model = UNet()
summary(model, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,792
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
           Dropout-4         [-1, 64, 256, 256]               0
            Conv2d-5         [-1, 64, 256, 256]          36,928
       BatchNorm2d-6         [-1, 64, 256, 256]             128
              ReLU-7         [-1, 64, 256, 256]               0
         MaxPool2d-8         [-1, 64, 128, 128]               0
     encoder_block-9  [[-1, 64, 128, 128], [-1, 64, 256, 256]]               0
           Conv2d-10        [-1, 128, 128, 128]          73,856
      BatchNorm2d-11        [-1, 128, 128, 128]             256
             ReLU-12        [-1, 128, 128, 128]               0
          Dropout-13        [-1, 128, 128, 128]               0
           Conv2d-14    

In [None]:
class CEDiceLoss(nn.Module):
    """Class-weighted cross-entropy plus class-weighted soft-Dice loss.

    loss = CE(input, target; weights) + mean_over_batch(1 - sum_c w_c * Dice_c)

    Args:
        weights: per-class weights of shape ``(C,)`` or ``(1, C)`` that sum to 1. They are
            moved to the input's device/dtype on every call.
    """

    def __init__(self, weights) -> None:
        super(CEDiceLoss, self).__init__()
        self.eps: float = 1e-6  # keeps the Dice denominator away from zero
        self.weights: torch.Tensor = weights

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the combined loss.

        Args:
            input: raw logits of shape ``(B, C, H, W)``.
            target: integer class indices of shape ``(B, H, W)`` with values in ``[0, C)``.

        Returns:
            Scalar loss tensor.

        Raises:
            TypeError: if ``input`` is not a tensor.
            ValueError: on a bad input rank, spatial-size or device mismatch, a wrong
                number of weights, or weights that do not sum to 1.
        """
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if input.dim() != 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if input.shape[-2:] != target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if input.device != target.device:
            raise ValueError("input and target must be in the same device. Got: {} and {}"
                             .format(input.device, target.device))

        num_classes = input.shape[1]
        # Accept (C,) or (1, C); follow the input's device/dtype so CPU and GPU both work.
        weights = self.weights.reshape(-1).to(device=input.device, dtype=input.dtype)
        if weights.numel() != num_classes:
            raise ValueError("The number of weights must equal the number of classes")
        # Tolerance instead of exact equality: float sums such as 0.4 + 0.55 + 0.05 may be off by an ulp.
        if abs(weights.sum().item() - 1.0) > 1e-5:
            raise ValueError("The sum of all weights must equal 1")

        target = target.long()

        # cross entropy loss (class weights must be a 1-D tensor of length C)
        celoss = F.cross_entropy(input, target, weight=weights)

        # soft-Dice on the class probabilities
        input_soft = F.softmax(input, dim=1)
        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).to(input.dtype)

        dims = (2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)
        dice_score = 2. * intersection / (cardinality + self.eps)  # (B, C)

        dice_score = torch.sum(dice_score * weights, dim=1)  # weighted sum over classes -> (B,)

        return torch.mean(1. - dice_score) + celoss

In [None]:
def weights_init(model):
    """Xavier-uniform initialise the weight of an ``nn.Linear`` module; ignore anything else.

    Meant for ``module.apply(weights_init)``. NOTE(review): the ``UNet`` in this notebook
    contains no ``nn.Linear`` layers, so applying it there leaves the weights unchanged.
    """
    if not isinstance(model, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(model.weight)

In [None]:
def save_model(model, optimizer, path):
    """Write ``{"model": model state_dict, "optimizer": optimizer state_dict}`` to ``path``."""
    torch.save(dict(model=model.state_dict(), optimizer=optimizer.state_dict()), path)

def load_model(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    Args:
        model: module whose ``state_dict`` layout matches the checkpoint's ``"model"`` entry.
        optimizer: optimizer restored from the ``"optimizer"`` entry.
        path: checkpoint file path.
        map_location: forwarded to ``torch.load``, e.g. ``"cpu"`` to load a GPU-saved
            checkpoint on a CPU-only machine. ``None`` (default) keeps the stored devices.

    Returns:
        ``(model, optimizer)``: the objects passed in, updated in place.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer

In [None]:
# Train function for each epoch
def train(train_dataloader, valid_dataloader,learing_rate_scheduler, epoch, display_step):
    """Train the notebook-level ``model`` for one epoch, then evaluate it on the validation set.

    Uses the globals ``model``, ``optimizer``, ``loss_function`` and ``device``. The
    scheduler is only read here (to print the current learning rate), not stepped.

    Args:
        train_dataloader: yields ``(images, masks)`` training batches.
        valid_dataloader: yields ``(images, masks)`` validation batches.
        learing_rate_scheduler: scheduler whose ``get_last_lr()`` is printed.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        display_step: print the batch loss every ``display_step`` batches.

    Returns:
        ``(train_loss, valid_loss)``: the mean per-batch loss over each loader
        (0.0 for an empty loader).
    """
    print(f"Start epoch #{epoch+1}, learning rate for this epoch: {learing_rate_scheduler.get_last_lr()}")
    start_time = time.time()

    model.train()
    train_loss_sum = 0.0
    num_train_batches = 0
    for i, (data, targets) in enumerate(train_dataloader):
        # Move the batch to the training device
        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(data)

        # Backpropagation, compute gradients
        loss = loss_function(outputs, targets.long())
        loss.backward()

        # Apply gradients
        optimizer.step()

        train_loss_sum += loss.item()
        num_train_batches += 1
        if (i+1) % display_step == 0:
            print('Train Epoch: {} [{}/{} ({}%)]\tLoss: {:.4f}'.format(
                epoch + 1, (i+1) * len(data), len(train_dataloader.dataset), 100 * (i+1) * len(data) / len(train_dataloader.dataset),
                loss.item()))

    print(f"Done epoch #{epoch+1}, time for this epoch: {time.time()-start_time}s")

    # Evaluate the validation set. The average uses the number of *validation* batches;
    # dividing by the training-loop counter would mis-scale the validation loss.
    model.eval()
    valid_loss_sum = 0.0
    num_valid_batches = 0
    with torch.no_grad():
        for data, target in valid_dataloader:
            data, target = data.to(device), target.to(device)
            valid_loss_sum += loss_function(model(data), target.long()).item()
            num_valid_batches += 1

    train_loss_epoch = train_loss_sum / max(num_train_batches, 1)
    test_loss_epoch = valid_loss_sum / max(num_valid_batches, 1)
    return train_loss_epoch, test_loss_epoch

In [None]:
# Test function
def test(dataloader):
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, (data, targets) in enumerate(dataloader):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, pred = torch.max(outputs, 1)
            test_loss += targets.size(0)
            correct += torch.sum(pred == targets).item()
    return 100.0 * correct / test_loss

In [None]:
# Initialise the model: resume from a checkpoint when one is configured, otherwise keep the
# default initialisation and apply `weights_init`.
# NOTE(review): `pretrained_path` is not defined anywhere else in this notebook, so the old
# bare `except:` always fell through to fresh initialisation and hid any real load error.
# It is now read from POLYP_PRETRAINED_PATH (unset or missing file -> fresh start).
pretrained_path = os.environ.get("POLYP_PRETRAINED_PATH", "")

checkpoint = None
base_model = model.module if isinstance(model, nn.DataParallel) else model
if pretrained_path and os.path.isfile(pretrained_path):
    # torch.load unpickles the file: only load checkpoints you trust.
    checkpoint = torch.load(pretrained_path, map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in checkpoint["model"].items()
    )
    base_model.load_state_dict(new_state_dict)
else:
    base_model.apply(weights_init)

# Wrap only once so re-running this cell does not nest DataParallel wrappers.
model = base_model if isinstance(model, nn.DataParallel) else model
model = nn.DataParallel(base_model)
model.to(device)

In [None]:
# Per-class weights for CEDiceLoss (one per output class, summing to 1). They are placed on
# `device` rather than `.cuda()` so the notebook also runs on a CPU-only machine.
weights = torch.Tensor([[0.4, 0.55, 0.05]]).to(device)
loss_function = CEDiceLoss(weights)

# Define the optimizer (Adam optimizer)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Resume the optimizer state only if the model cell actually loaded a checkpoint. This
# replaces a bare `except: pass`, which also hid genuine state-mismatch errors.
resume_checkpoint = globals().get("checkpoint")
if isinstance(resume_checkpoint, dict) and "optimizer" in resume_checkpoint:
    optimizer.load_state_dict(resume_checkpoint["optimizer"])

# Learning rate scheduler: multiply the LR by 0.6 every 4 epochs.
learing_rate_scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.6)

In [None]:
# Checkpoint file; the training loop overwrites it whenever the validation loss improves.
# `checkpoint_path` was never defined in this notebook; override with POLYP_CHECKPOINT_PATH.
checkpoint_path = os.environ.get("POLYP_CHECKPOINT_PATH", "unet_checkpoint.pth")
save_model(model, optimizer, checkpoint_path)

In [None]:
# W&B reads the API key from the WANDB_API_KEY environment variable (or a prior `wandb login`).
# SECURITY NOTE(review): an API key used to be hard-coded in this cell. It remains in the
# notebook's history and must be revoked.
wandb.login()
wandb.init(project="polypsegmentation")

# Training loop: one epoch of training plus validation, keep the checkpoint with the lowest
# validation loss, then step the learning-rate schedule and log both losses.
train_loss_array = []
test_loss_array = []
last_loss = float("inf")
for epoch in range(epochs):
    train_loss_epoch, test_loss_epoch = train(
        train_dataloader, valid_dataloader, learing_rate_scheduler, epoch, display_step
    )

    if test_loss_epoch < last_loss:
        save_model(model, optimizer, checkpoint_path)
        last_loss = test_loss_epoch

    learing_rate_scheduler.step()
    train_loss_array.append(train_loss_epoch)
    test_loss_array.append(test_loss_epoch)
    wandb.log({"Train loss": train_loss_epoch, "Valid loss": test_loss_epoch})

wandb.finish()