In [1]:
import glob
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import albumentations as A
from albumentations.pytorch import ToTensor

EPOCHS = 150
BATCH_SIZE = 64
LEARNING_RATE = 0.0005
PAITIENCE = 10

IM_HEIGHT = 256
IM_WIDTH = 256
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_files = []
mask_files = glob.glob('data/lgg-mri-segmentation/kaggle_3m/*/*_mask*')

for i in mask_files:
    train_files.append(i.replace('_mask',''))

In [2]:
train_transforms = A.Compose([
    A.augmentations.Resize(width=IM_HEIGHT, height=IM_WIDTH, p=1.0),
    A.augmentations.HorizontalFlip(p=0.5),
    A.augmentations.VerticalFlip(p=0.5),
    A.augmentations.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=45),
    A.augmentations.Normalize(p=1.0),
    ToTensor(),
])

valid_transforms = A.Compose([
    A.augmentations.Resize(width=IM_HEIGHT, height=IM_WIDTH, p=1.0),
    A.augmentations.Normalize(p=1.0),
    ToTensor(),
])

In [4]:
class BrainMriDataset(Dataset):
    def __init__(self, df, transforms=None):
        
        self.df = df
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image = cv2.imread(self.df.iloc[idx, 0])
        mask = cv2.imread(self.df.iloc[idx, 1], 0)
        
        augmented = self.transforms(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']   
        
        return image, mask

    
df = pd.DataFrame(data={"filename": train_files, 'mask' : mask_files})
df_train, df_test = train_test_split(df, test_size=0.1, random_state=1234)
df_train, df_valid = train_test_split(df_train, test_size=0.2, random_state=1234)

print(f"Train: {df_train.shape} \nVal: {df_valid.shape} \nTest: {df_test.shape}")

train_dataset = BrainMriDataset(df=df_train, transforms=train_transforms)
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

valid_dataset = BrainMriDataset(df=df_valid, transforms=valid_transforms)
valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

test_dataset = BrainMriDataset(df=df_test, transforms=valid_transforms)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=False)

Train: (2828, 2) 
Val: (708, 2) 
Test: (393, 2)


In [5]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        
        mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    
    
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super(UNet, self).__init__()
        
        self.n_channels = n_channels
        self.n_classes  = n_classes
        self.bilinear   = bilinear

        self.inc   = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        
        factor     = 2 if bilinear else 1
        
        self.down4 = Down(512, 1024 // factor)
        self.up1   = Up(1024, 512 // factor, bilinear)
        self.up2   = Up(512, 256 // factor, bilinear)
        self.up3   = Up(256, 128 // factor, bilinear)
        self.up4   = Up(128, 64, bilinear)
        self.outc  = OutConv(64, n_classes)

    def forward(self, x):
        x1     = self.inc(x)
        x2     = self.down1(x1)
        x3     = self.down2(x2)
        x4     = self.down3(x3)
        x5     = self.down4(x4)
        x      = self.up1(x5, x4)
        x      = self.up2(x, x3)
        x      = self.up3(x, x2)
        x      = self.up4(x, x1)
        logits = self.outc(x)
        
        return logits

In [6]:
def dice_coef(y_true, y_pred, smooth=100):
    y_truef = torch.flatten(y_true)
    y_predf = torch.flatten(y_pred)
    And = torch.sum(y_truef * y_predf)
    
    return((2* And + smooth) / (torch.sum(y_truef) + torch.sum(y_predf) + smooth))


def dice_coef_loss(y_true, y_pred):
    return - dice_coef(y_true, y_pred)


def iou(y_true, y_pred, smooth=100):
    intersection = torch.sum(y_true * y_pred)
    sum_ = torch.sum(y_true + y_pred)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    
    return jac


def jac_distance(y_true, y_pred):
    y_truef = torch.flatten(y_true)
    y_predf = torch.flatten(y_pred)

    return - iou(y_true, y_pred)

In [7]:
def train(model, iterator, criterion, optimizer, device=device):  
    model.train()
    epoch_loss = 0
        
    for data, target in tqdm(iterator):
        optimizer.zero_grad()
        
        data = data.to(device)
        target = target.to(device)

        output = model(data)           

        pred   = F.sigmoid(output)
        pred   = (pred > 0.5).float()

        loss = criterion(output, target)
        loss.backward()
        
        optimizer.step()
            
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion, device):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for data, target in iterator:
            data   = data.to(device)
            target = target.to(device)
            
            output = model(data)
            
            pred   = F.sigmoid(output)
            pred   = (pred > 0.5).float()
            
            loss = criterion(pred, target)
            
            epoch_loss += loss.item()
            
    return epoch_loss / len(iterator)


def print_log(epoch_num, train_loss, valid_loss):
    print(f"EPOCH: {epoch_num}")
    print(f"Train loss: {train_loss}\nValid loss: {valid_loss}")

In [None]:
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, 'min')
# criterion = nn.BCEWithLogitsLoss()
criterion = dice_coef_loss

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()

for epoch_num in range(EPOCHS):
    train_loss = train(model, train_iterator, criterion, optimizer, device)
    valid_loss = evaluate(model, valid_iterator, criterion, device)
    
    scheduler.step(valid_loss)
    
    print_log(epoch_num, train_loss, valid_loss)
    
    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(model, 'weights/unet_best_pytorch')
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        model = torch.load('weights/unet_best_pytorch')
        break

  0%|          | 0/45 [00:00<?, ?it/s]