In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

!pip install albumentations

import numpy as np
from PIL import Image
import glob
import cv2
import os
from os.path import join as pjoin
from pdb import set_trace
import copy
import scipy

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset, WeightedRandomSampler
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage, CenterCrop, RandomResizedCrop
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import alexnet, resnet18, inception_v3
from torchvision.models.alexnet import model_urls
try:
  from torchvision.models.utils import load_state_dict_from_url
except ImportError:
  from torch.hub import load_state_dict_from_url

import albumentations as albu
from albumentations.pytorch import ToTensorV2

# Detect if we have a GPU available
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def make_train_step_fn(model, loss_fn, optimizer):
    # Builds function that performs a step in the train loop
    def perform_train_step_fn(x, y):
        # Sets model to TRAIN mode
        model.train()

        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model(x)

        # Step 2 - Computes the loss
        loss = loss_fn(yhat, y)

        # Step 3 - Computes gradients for both "a" and "b" parameters
        loss.backward()

        # Step 4 - Updates parameters using gradients and the learning rate
        optimizer.step()
        optimizer.zero_grad()

        # Returns the loss
        return loss.item()

    # Returns the function that will be called inside the train loop
    return perform_train_step_fn

def mini_batch(device, data_loader, step_fn):
    mini_batch_losses = []
    for x_batch, y_batch in data_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        mini_batch_loss = step_fn(x_batch, y_batch)
        mini_batch_losses.append(mini_batch_loss)

    loss = np.mean(mini_batch_losses)
    return loss

def make_val_step_fn(model, loss_fn):
    # Builds function that performs a step in the validation loop
    def perform_val_step_fn(x, y):
        # Sets model to EVAL mode
        model.eval()
        
        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model(x)

        # Step 2 - Computes the loss
        loss = loss_fn(yhat, y)

        # There is no need to compute Steps 3 and 4, since we don't update parameters during evaluation
        return loss.item()
    
    return perform_val_step_fn

def multi_acc(y_pred, y_test):
    """ Function to calculate multi-class accuracy
    """
    y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    acc = torch.round(acc)
#    acc = torch.round(acc * 100)
    return acc

class SegmentationDataset(Dataset):
    def __init__(self, folder_path, transform=None, subset="train"):
        super(SegmentationDataset, self).__init__()

        if subset=="train": 
            self.img_files = [pjoin(folder_path, "image", "train-volume%02d.jpg" % (idx)) for idx in range(23+1)]
            self.mask_files = [pjoin(folder_path, "labels", "train-labels%02d.jpg" % (idx)) for idx in range(23+1)]
        elif subset=="val":
            self.img_files = [pjoin(folder_path, "image", "train-volume%02d.jpg" % (idx)) for idx in range(24, 26+1)]
            self.mask_files = [pjoin(folder_path, "labels", "train-labels%02d.jpg" % (idx)) for idx in range(24, 26+1)]    
        else:
            self.img_files = [pjoin(folder_path, "image", "train-volume%02d.jpg" % (idx)) for idx in range(27, 30)]
            self.mask_files = [pjoin(folder_path, "labels", "train-labels%02d.jpg" % (idx)) for idx in range(27, 30)]
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]

        data = cv2.imread(img_path)
        label = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        try:
            label[label < 127] = 0
        except:
            set_trace()
        label[label >= 127] = 1
        label1 = np.zeros((label.shape[0], label.shape[1], 2))
        label1 = np.eye(2)[label]
        label = label1
        if self.transform is not None:
            # Albumentation handles the synchronizing between image and mask
            transformed = self.transform(image=data, mask=label)
            data = transformed["image"]
            label = transformed["mask"]
        label = torch.permute(label, (2, 0, 1))
        return data, label

    def __len__(self):
        return len(self.img_files)

class UNet(nn.Module):
    def __init__(self, num_class=2, num_channel=3):
        super(UNet, self).__init__()

        psize = 1
        dropout_rate = 0.1
        # Contracting part
        self.conv1_1 = nn.Conv2d(in_channels=num_channel, out_channels=64, kernel_size=3, padding=psize)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=psize)
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=psize)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=psize)
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=psize)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=psize)
        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=psize)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=psize)
        self.mpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=psize)
        self.conv5_2 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=psize) # No max-pooling in this part
        
        # Expanding part
        self.upconv6 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, padding=psize, stride=2, output_padding=1)
        self.conv6_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=psize)
        self.conv6_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=psize)

        self.upconv7 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=psize, stride=2, output_padding=1)
        self.conv7_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=psize)
        self.conv7_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=psize)

        self.upconv8 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=psize, stride=2, output_padding=1)
        self.conv8_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=psize)
        self.conv8_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=psize)

        self.upconv9 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=psize, stride=2, output_padding=1)
        self.conv9_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=psize)
        self.conv9_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=psize)

        # Last layer is 1x1 convolution to map to number of classes
        self.conv_last = nn.Conv2d(in_channels=64, out_channels=num_class, kernel_size=1, padding=0)
        # self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

        self.dropout = nn.Dropout(p=dropout_rate)


    def forward(self, x1):
        # Contracing part
        x1_1 = self.conv1_1(x1)
        x1_1 = F.relu(x1_1)
        x1_1 = self.dropout(x1_1)
        x1_2 = self.conv1_2(x1_1)        
        x1_2 = F.relu(x1_2)
        x1_2 = self.dropout(x1_2)

        x2 = self.mpool1(x1_2)
        x2_1 = self.conv2_1(x2)
        x2_1 = F.relu(x2_1)
        x2_1 = self.dropout(x2_1)
        x2_2 = self.conv2_2(x2_1)        
        x2_2 = F.relu(x2_2)
        x2_2 = self.dropout(x2_2)

        x3 = self.mpool2(x2_2)
        x3_1 = self.conv3_1(x3)
        x3_1 = F.relu(x3_1)
        x3_1 = self.dropout(x3_1)
        x3_2 = self.conv3_2(x3_1)        
        x3_2 = F.relu(x3_2)

        x3_2 = self.dropout(x3_2)
        x4 = self.mpool3(x3_2)
        x4_1 = self.conv4_1(x4)
        x4_1 = F.relu(x4_1)

        x4_1 = self.dropout(x4_1)
        x4_2 = self.conv4_2(x4_1)        
        x4_2 = F.relu(x4_2)
        x4_2 = self.dropout(x4_2)

        x5 = self.mpool4(x4_2)
        x5_1 = self.conv5_1(x5)
        x5_1 = F.relu(x5_1)
        x5_1 = self.dropout(x5_1)

        x5_2 = self.conv5_2(x5_1)        
        x6 = F.relu(x5_2) # No max-pooling in this part
 
        # Expanding part
        x6_up = self.upconv6(x6)
        x6_0 = torch.cat([x4_2, x6_up], dim=1)
        x6_1 = self.conv6_1(x6_0)
        x6_1 = F.relu(x6_1)
        x6_1 = self.dropout(x6_1)
        x6_2 = self.conv6_2(x6_1)        
        x7 = F.relu(x6_2)

        x7 = self.dropout(x7)
        x7_up = self.upconv7(x7)  
        x7_0 = torch.cat([x3_2, x7_up], dim=1)
        x7_1 = self.conv7_1(x7_0)
        x7_1 = F.relu(x7_1)
        x7_1 = self.dropout(x7_1)
        x7_2 = self.conv7_2(x7_1)  
        
        x8 = F.relu(x7_2)
        x8 = self.dropout(x8)
        x8_up = self.upconv8(x8)
        x8_0 = torch.cat([x2_2, x8_up], dim=1)
        x8_1 = self.conv8_1(x8_0)
        x8_1 = F.relu(x8_1)
        x8_1 = self.dropout(x8_1)
        x8_2 = self.conv8_2(x8_1)  
        
        x9 = F.relu(x8_2)
        x9 = self.dropout(x9)
        x9_up = self.upconv9(x9)
        x9_0 = torch.cat([x1_2, x9_up], dim=1)
        x9_1 = self.conv9_1(x9_0)
        x9_1 = F.relu(x9_1)
        x9_1 = self.dropout(x9_1)
        x9_2 = self.conv9_2(x9_1)  
        
        x10 = F.relu(x9_2)
        x10 = self.dropout(x10)
        x_last = self.conv_last(x10)
        return x_last

# def dice_loss(target, pred):
#     numerator = 2 * torch.sum(pred * target)
#     denominator = torch.sum(pred + target)
#     return 1 - (numerator + 0.0001) / (denominator + 0.0001)

def dice_loss(target, pred):
    numerator = 2 * np.sum(pred * target)
    denominator = np.sum(pred + target)
    return 1 - (numerator + 0.0001) / (denominator + 0.0001)

def soft_dice_loss(y_true, y_pred, epsilon=1e-6): 
    ''' 
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.
  
    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) 
        epsilon: Used for numerical stability to avoid divide by zero errors
    
    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation 
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation 
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
        
        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''
    
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1)) 
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon)) # average over classes and batch
    # thanks @mfernezir for catching a bug in an earlier version of this implementation!

# dataset_path = '/content/drive/My Drive/Colab Notebooks/UNet_Pytorch/ISBI_2012_dataset_ver2'
#dataset_path = './ISBI_2012_dataset_ver2'
dataset_path = "/kaggle/input/isbi-2012/ISBI_2012_dataset_ver2"
n_cls = 2
n_epoch = 2000
batch_size = 2
img_size = 512
learning_rate = 0.001
# learning_rate = 3e-5

train_path = pjoin(dataset_path, "train")
val_path = pjoin(dataset_path, "val")
test_path = pjoin(dataset_path, "test")

# Use albumentations
train_transform = albu.Compose([
    albu.Resize(height=img_size, width=img_size),
    albu.HorizontalFlip(p=0.5),
    albu.VerticalFlip(p=0.5),
    albu.ShiftScaleRotate(scale_limit=0.20, rotate_limit=30, shift_limit=0.1, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
#    albu.CLAHE(),
#    albu.Downscale(p=0.5),
#    albu.GridDistortion(p=0.5),
#    albu.Equalize(p=0.5),
#    albu.GaussNoise(p=0.5),
#    albu.RandomBrightnessContrast(p=0.5),
#    albu.Blur(blur_limit=3),
#    albu.OpticalDistortion(),
#    albu.HueSaturationValue(),
#    albu.RandomRotate90(p=0.5),
#    albu.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
    albu.ElasticTransform(sigma=4, alpha=34),
    # albu.Resize(height=img_size, width=img_size),
    # albu.RandomCrop(*crop_size),
    # albu.ChannelShuffle(),
    # albu.InvertImg(),
    albu.ToGray(),
    # albu.ToTensor(),    
    albu.Normalize(),
    # albu.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

val_transform = albu.Compose([
    albu.Resize(height=img_size, width=img_size),
    albu.ToGray(),
    # albu.ToTensor(),
    albu.Normalize(),
    # albu.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

test_transform = albu.Compose([
    albu.Resize(height=img_size, width=img_size),
    albu.ToGray(),
    # albu.ToTensor(),
    albu.Normalize(),
    # albu.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

dataset_train = SegmentationDataset(train_path, transform=train_transform, subset="train")
dataset_val = SegmentationDataset(val_path, transform=val_transform, subset="val")
dataset_test = SegmentationDataset(test_path, transform=test_transform, subset="test")

train_loader = DataLoader(dataset=dataset_train, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=dataset_val, shuffle=False, batch_size=1)
test_loader = DataLoader(dataset=dataset_test, shuffle=False, batch_size=1)

image, label = dataset_train[0]

model = UNet(num_class=n_cls)
model.to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train_step_fn = make_train_step_fn(model, loss_fn, optimizer)
val_step_fn = make_val_step_fn(model, loss_fn)

dice_stats = {
    'train': [],
    "val": []
}
acc_stats = {
    'train': [],
    
    "val": [],
}
loss_stats = {
    'train': [],
    "val": []
}
best_dice = -1
best_loss = 100000
best_model = None
for epoch in range(0, n_epoch):
    # TRAINING
    # (train_epoch_acc, train_epoch_loss) = mini_batch(DEVICE, train_loader, train_step_fn)
    train_epoch_loss = mini_batch(DEVICE, train_loader, train_step_fn)
    loss_stats['train'].append(train_epoch_loss/len(train_loader))
    # acc_stats['train'].append(train_epoch_acc)
    # dice_stats['train'].append(1 - train_epoch_loss)

    with torch.no_grad():
        # (val_epoch_acc, val_epoch_loss) = mini_batch(DEVICE, val_loader, val_step_fn)
        val_epoch_loss = mini_batch(DEVICE, val_loader, val_step_fn)
        loss_stats['val'].append(val_epoch_loss/len(val_loader))
        # acc_stats['val'].append(val_epoch_acc)
        # dice_stats['val'].append(1 - val_epoch_loss)

    # print(f'Epoch {epoch+0:02}: | Train Loss: {train_epoch_loss/len(train_loader):.5f} | Val Loss: {val_epoch_loss/len(val_loader):.5f} | Train Acc: {train_epoch_acc:.3f}| Val Acc: {val_epoch_acc:.3f}')
    print(f'Epoch {epoch+0:02}: | Train Loss: {train_epoch_loss/len(train_loader):.5f} | Val Loss: {val_epoch_loss/len(val_loader):.5f}')
    if val_epoch_loss < best_loss:
        # best_acc = val_epoch_acc
        best_loss = val_epoch_loss
        best_model = copy.deepcopy(model.state_dict())

# Plot loss
plt.plot(loss_stats['train'], label='Training loss')
plt.plot(loss_stats['val'], label='Validation loss')
plt.legend()
plt.title('Loss')
plt.show()
plt.savefig("Loss.png")
plt.close()

# Test
gt_list = []
pred_list = []
model.load_state_dict(best_model)
for idx, (img, mask) in enumerate(test_loader):
    with torch.no_grad():
        model.load_state_dict(best_model)
        model.eval()
        pred = model(img.to(DEVICE).float())
    fig, axarr = plt.subplots(1,2)

    img_show = img.squeeze()
    img_show = np.transpose(img_show, (1, 2, 0))
    img_show_np = img_show.cpu().detach().numpy()

    mask_pred = mask.cpu().detach().numpy()
    mask_pred = mask_pred.squeeze()
    mask_pred = np.argmax(mask_pred, axis=0)
    mask_show_np = mask_pred

    pred_show_np = pred.cpu().detach().numpy()
    pred_show_np = pred_show_np.squeeze()
    pred_show_np = np.transpose(pred_show_np, (1, 2, 0))
    pred_show_np = scipy.special.softmax(pred_show_np, axis=2)
    pred_show_np = np.argmax(pred_show_np, axis=2)

    gt_list.append(mask_show_np)
    pred_list.append(pred_show_np)

    # set_trace()
    fig, axes = plt.subplots(3)
    axes[0].imshow(img_show_np)
    axes[1].imshow(mask_show_np)
    axes[2].imshow(pred_show_np)

    plt.show()
    plt.savefig("Compare_%d.png" % (idx))
    plt.close()

n_test = len(gt_list)
gt_np = np.zeros((n_test, img_size, img_size, 2))
pred_np = np.zeros((n_test, img_size, img_size, 2))
for idx_img in range(n_test):
    gt = gt_list[idx_img]
    pred = pred_list[idx_img]

    for idx1 in range(img_size):
        for idx2 in range(img_size):
            if gt[idx1, idx2] == 1:
                gt_np[idx_img, idx1, idx2, 1] = 1
            else:
                gt_np[idx_img, idx1, idx2, 0] = 1
            if pred[idx1, idx2] == 1:
                pred_np[idx_img, idx1, idx2, 1] = 1
            else:
                pred_np[idx_img, idx1, idx2, 0] = 1

    # gt_np[idx_img, :, :, 0] = gt_list[idx_img]
    # pred_np[idx_img, :, :, 0] = pred_list[idx_img]
dice_loss = soft_dice_loss(gt_np, pred_np)
dice_coeff = 1 - dice_loss
print("Dice coefficient: %f" % (dice_coeff))
# Dice coefficent best: 0.87188

from sklearn.metrics import accuracy_score
acc = accuracy_score(gt_np.flatten(), pred_np.flatten())
print("Accuracyt: %f" % (acc))

[0mEpoch 00: | Train Loss: 0.27724 | Val Loss: 0.20347
Epoch 01: | Train Loss: 0.04996 | Val Loss: 0.20008
Epoch 02: | Train Loss: 0.04663 | Val Loss: 0.18765
Epoch 03: | Train Loss: 0.04184 | Val Loss: 0.13926
Epoch 04: | Train Loss: 0.02927 | Val Loss: 0.12010
Epoch 05: | Train Loss: 0.02758 | Val Loss: 0.11906
Epoch 06: | Train Loss: 0.02794 | Val Loss: 0.11148
Epoch 07: | Train Loss: 0.02736 | Val Loss: 0.11044
Epoch 08: | Train Loss: 0.02529 | Val Loss: 0.10650
Epoch 09: | Train Loss: 0.02612 | Val Loss: 0.10213
Epoch 10: | Train Loss: 0.02694 | Val Loss: 0.11163
Epoch 11: | Train Loss: 0.07737 | Val Loss: 0.11443
Epoch 12: | Train Loss: 0.02913 | Val Loss: 0.11318
Epoch 13: | Train Loss: 0.02705 | Val Loss: 0.11169
Epoch 14: | Train Loss: 0.02739 | Val Loss: 0.11106
Epoch 15: | Train Loss: 0.03069 | Val Loss: 0.11137
Epoch 16: | Train Loss: 0.02646 | Val Loss: 0.11021
Epoch 17: | Train Loss: 0.02650 | Val Loss: 0.11092
Epoch 18: | Train Loss: 0.02761 | Val Loss: 0.11095
Epoch 19

In [12]:
# Plot loss
plt.plot(loss_stats['train'], label='Train loss')
plt.plot(loss_stats['val'], label='Validation loss')
# plt.legend()
plt.title('Loss')
plt.show()
plt.savefig("abc.png")
plt.close()