In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import os
import numpy as np
import time
import copy
import pandas as pd
import math
import matplotlib.pyplot as plt
import pickle
import nibabel as nib
import random
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    """2D FLAIR slice / binary mask pairs read from per-case NIfTI folders.

    Every entry of ``root`` (except known junk entries) is treated as one case
    folder ``<case>`` containing ``<case>_flair.nii.gz`` and
    ``<case>_seg.nii.gz``. ``num_slices`` slices are taken per case along the
    last volume axis: every second slice starting at ``start_slice``.

    Args:
        transforms: callable invoked as ``transforms(image=..., mask=...)`` with
            two uint8 arrays; it must return a mapping whose ``"image"`` and
            ``"mask"`` entries are torch tensors.
        root: directory that contains one sub-folder per case.
    """

    # Directory entries that are not case folders.
    _SKIP_ENTRIES = frozenset({'.DS_Store', '.ipynb_checkpoints'})

    def __init__(self, transforms, root='BraTS2021_Training_Data/'):
        self.root = root
        # Drop junk entries wherever they sort; checking only the first entry
        # would leave a second junk entry in the list as a bogus "case".
        self.folder_mris = sorted(
            name for name in os.listdir(self.root)
            if name not in self._SKIP_ENTRIES
        )
        self.transforms = transforms
        self.num_slices = 40
        self.start_slice = 30

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float tensors for flat sample index ``idx``."""
        # Flat index -> (case, slice): consecutive indices walk the slices of one case.
        folder_idx = idx // self.num_slices
        slice_idx = ((idx % self.num_slices) * 2) + self.start_slice
        case_name = self.folder_mris[folder_idx]
        case_dir = os.path.join(self.root, case_name)

        image = nib.load(os.path.join(case_dir, case_name + '_flair.nii.gz')).get_fdata()[:, :, slice_idx]
        mask = nib.load(os.path.join(case_dir, case_name + '_seg.nii.gz')).get_fdata()[:, :, slice_idx]

        # Collapse every non-zero label into a single foreground class.
        mask[mask > 0] = 1

        # Rescale intensities into 0..255 before the uint8 cast: casting raw
        # float intensities above 255 straight to uint8 wraps them around.
        peak = image.max()
        if peak > 0:
            image = image * (255.0 / peak)
        image = np.clip(image, 0, 255)

        transformed = self.transforms(image=image.astype(np.uint8),
                                      mask=mask.astype(np.uint8))

        image = transformed["image"].float()
        # Normalise to [0, 1]; skip all-zero slices, which would otherwise
        # become NaN through a 0 / 0 division.
        peak = torch.max(image)
        if peak > 0:
            image /= peak
        # Add a leading axis to the mask (presumably so it matches the
        # channel axis the transform adds to the image -- TODO confirm).
        mask = transformed["mask"].float().unsqueeze(0)
        return image, mask

    def __len__(self):
        """Total number of slices: cases times slices per case."""
        return int(len(self.folder_mris) * self.num_slices)

In [None]:
# Albumentations pipelines keyed by phase. Both end with a 224x224 resize and
# tensor conversion; only the 'train' pipeline applies random augmentation.
data_transforms = dict(
    train=A.Compose([
        # Each random transform fires independently with its own probability p.
        A.ShiftScaleRotate(p=0.3, shift_limit=0.05, scale_limit=0.05, rotate_limit=15),
        A.RandomResizedCrop(224, 224, p=0.3, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        # Unconditional resize so the output is 224x224 even when the crop is skipped.
        A.Resize(224, 224),
        A.RandomBrightnessContrast(p=0.2),
        ToTensorV2(),
    ]),
    val=A.Compose([
        A.Resize(224, 224),
        ToTensorV2(),
    ]),
)

In [None]:
# Two views of the same folders: they differ only in the transform pipeline.
dataset_train = ImageDataset(data_transforms['train'])
dataset_val = ImageDataset(data_transforms['val'])

# Take the split sizes from the dataset instead of hard-coding them, so the
# indices stay valid if the number of case folders or slices per case changes.
num_cases = len(dataset_train.folder_mris)
slices_per_case = dataset_train.num_slices

torch.manual_seed(123)  # for reproducibility
indices = torch.randperm(num_cases).tolist()
t = int(0.8 * num_cases)


def _slice_indices(case_ids):
    """Expand case indices into the flat sample indices of all their slices."""
    return [case * slices_per_case + s
            for case in case_ids
            for s in range(slices_per_case)]


# Split by case (not by slice) so slices of one case never land on both sides.
train_indices = _slice_indices(indices[:t])
test_indices = _slice_indices(indices[t:])

dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
dataset_val = torch.utils.data.Subset(dataset_val, test_indices)


dataloaders = {'train': torch.utils.data.DataLoader(dataset_train, batch_size=6, shuffle=True, num_workers=3),
               'val': torch.utils.data.DataLoader(dataset_val, batch_size=6, shuffle=False, num_workers=3)}

dataset_sizes = {'train': len(dataset_train), 'val': len(dataset_val)}

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` with a validation pass after every epoch; keep the best weights.

    NOTE: a later cell of this notebook defines another ``train_model`` with a
    different signature; once that cell has run, it replaces this function.

    Reads the notebook-level names ``dataloaders``, ``dataset_sizes``,
    ``device`` and ``dice_score`` (the last one is not defined in the cells
    visible here -- TODO confirm where it comes from).

    Args:
        model: network to optimise; must already live on ``device``.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: optimiser that updates ``model``'s parameters.
        scheduler: stepped once per epoch with the training Dice score.
        num_epochs: number of train + validation passes.

    Returns:
        Tuple ``(model, best_model_wts)``: the model with the best validation
        weights loaded, and that state dict.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Accumulators are reset per phase so the validation metrics do
            # not start from the training phase's values.
            running_loss = 0.0
            running_dice = 0.0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Per-sample Dice on rounded outputs; detached because the
                # metric needs no autograd graph.
                hard_outputs = torch.round(outputs.detach())
                for i in range(len(inputs)):
                    running_dice += dice_score(hard_outputs[i], labels[i]).item()

                # Weight the batch loss by the batch size so that dividing by
                # the number of samples below gives a per-sample mean
                # (assumes criterion returns a per-batch mean -- TODO confirm).
                running_loss += loss.item() * inputs.size(0)

            # statistics
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_dice = running_dice / dataset_sizes[phase]

            if phase == 'train':
                # NOTE(review): the scheduler is driven by the *training* Dice;
                # confirm this is intended rather than the validation Dice.
                scheduler.step(epoch_dice)

            print('{} Loss: {:.4f} Dice: {:.4f}'.format(
                phase, epoch_loss, epoch_dice))

            # Snapshot and save the weights whenever validation Dice improves.
            if phase == 'val' and epoch_dice > best_dice:
                best_dice = epoch_dice
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, f'resunet_{epoch}.pth')

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # The label says "Acc", but the value printed is the best validation Dice.
    print('Best val Acc: {:4f}'.format(best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_model_wts

In [None]:
# IPython shell escape: run the nvidia-smi command-line tool and show its output.
!nvidia-smi

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook does not fail at the first `.to(device)` on a CPU-only host.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: (BatchNorm -> ReLU -> Conv) twice, plus a skip connection.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions. The skip addition requires the
            convolutions to preserve the spatial size (as the defaults do).

    When ``in_channels != out_channels`` the skip path goes through a 1x1
    convolution so the two tensors can be added. When they are equal the skip
    path is an identity, which adds no parameters, so the state-dict keys are
    exactly ``bn1``, ``conv1``, ``bn2``, ``conv2``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.act1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # The second stage consumes conv1's output, so it is sized by
        # out_channels (not in_channels).
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # Project the skip path only when the channel counts differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``conv2(act(bn2(conv1(act(bn1(x)))))) + shortcut(x)``."""
        out = self.bn1(x)
        out = self.act1(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.act2(out)
        out = self.conv2(out)

        out += self.shortcut(x)
        return out


class ResUNet(nn.Module):
    """U-shaped network built from ``ResidualBlock`` stacks with additive skips.

    Layout: an input stage, three strided-convolution downsampling stages,
    three transposed-convolution upsampling stages and an output stage. Encoder
    features pass through a 1x1 convolution and are *added* to the decoder
    features at the same level. The result is passed through a sigmoid.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output map.
        features: channel width of the first level; doubled at each level below.
        dropout: when true, ``nn.Dropout(0.1)`` is used at the dropout
            positions; otherwise those positions hold ``nn.Identity``.
        pooling_size: kernel size and stride of every down/up-sampling convolution.
    """

    def __init__(self, in_channels=4, out_channels=1, features=32, dropout=False, pooling_size=2):
        super().__init__()

        # One shared module instance is placed at every dropout position.
        drop = nn.Dropout(0.1) if dropout else nn.Identity()
        width0, width1, width2, width3 = features, features * 2, features * 4, features * 8

        def res_stack(channels):
            # Three channel-preserving 3x3 residual blocks.
            return [ResidualBlock(channels, channels, kernel_size=3, padding=1) for _ in range(3)]

        def down_stage(c_in, c_out, trailing_dropout=False):
            # BN -> strided conv -> ReLU -> dropout -> residual stack.
            layers = [
                nn.BatchNorm2d(c_in),
                nn.Conv2d(c_in, c_out, kernel_size=pooling_size, stride=pooling_size, bias=False),
                nn.ReLU(),
                drop,
                *res_stack(c_out),
            ]
            if trailing_dropout:
                layers.append(drop)
            return nn.Sequential(*layers)

        def up_stage(c_in, c_out):
            # Residual stack -> BN -> transposed conv -> ReLU -> dropout.
            return nn.Sequential(
                *res_stack(c_in),
                nn.BatchNorm2d(c_in),
                nn.ConvTranspose2d(c_in, c_out, kernel_size=pooling_size, stride=pooling_size, bias=False),
                nn.ReLU(),
                drop,
            )

        # Modules are created in the same order as the attribute list below, so
        # parameter order and state-dict keys stay as before.
        self.init_path = nn.Sequential(
            nn.Conv2d(in_channels, width0, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            *res_stack(width0),
        )
        self.shortcut0 = nn.Conv2d(width0, width0, kernel_size=1)

        self.down1 = down_stage(width0, width1)
        self.shortcut1 = nn.Conv2d(width1, width1, 1)

        self.down2 = down_stage(width1, width2)
        self.shortcut2 = nn.Conv2d(width2, width2, 1)

        # The deepest stage has an extra dropout after its residual stack.
        self.down3 = down_stage(width2, width3, trailing_dropout=True)

        self.up3 = up_stage(width3, width2)
        self.up2 = up_stage(width2, width1)
        self.up1 = up_stage(width1, width0)

        self.out_path = nn.Sequential(
            ResidualBlock(width0, width0, kernel_size=1, padding=0),
            nn.BatchNorm2d(width0),
            nn.ReLU(),
            nn.Conv2d(width0, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Run the encoder, then the decoder with additive skips, then a sigmoid."""
        enc0 = self.init_path(x)
        enc1 = self.down1(enc0)
        enc2 = self.down2(enc1)
        bottom = self.down3(enc2)

        dec = self.up3(bottom)
        dec = self.up2(dec + self.shortcut2(enc2))
        dec = self.up1(dec + self.shortcut1(enc1))
        pre_sigmoid = self.out_path(dec + self.shortcut0(enc0))
        return torch.sigmoid(pre_sigmoid)


In [None]:
# Build the network with a single input channel (the class default is 4) and
# move it to the selected device.
model_name = "resunet"
model = ResUNet(in_channels=1)
model = model.to(device)

In [None]:
def dice_coef_metric(pred, label):
    """Dice coefficient ``2 * sum(pred * label) / (sum(pred) + sum(label))``.

    The sums run over every element of ``pred`` and ``label``, so a batch is
    scored as a whole. Returns ``1.`` when both inputs sum to zero.
    """
    pred_total = pred.sum()
    label_total = label.sum()
    # Both empty: score as perfect agreement rather than dividing 0 by 0.
    if pred_total == 0 and label_total == 0:
        return 1.
    overlap = (pred * label).sum()
    return 2.0 * overlap / (pred_total + label_total)

def dice_coef_loss(pred, label):
    """Soft Dice loss: ``1 - (2 * sum(pred * label) + 1) / (sum(pred) + sum(label) + 1)``.

    The constant 1 added to numerator and denominator keeps the ratio defined
    when both inputs are all zero.
    """
    smooth = 1.0
    overlap = (pred * label).sum()
    total = pred.sum() + label.sum()
    smoothed_dice = (2.0 * overlap + smooth) / (total + smooth)
    return 1 - smoothed_dice

def bce_dice_loss(pred, label):
    """Unweighted sum of the soft Dice loss and mean binary cross-entropy.

    ``pred`` is passed to binary cross-entropy as probabilities, not logits.
    """
    return dice_coef_loss(pred, label) + F.binary_cross_entropy(pred, label)

def train_loop(model, loader, loss_func):
    """Run one optimisation pass over ``loader``.

    Uses the notebook-level ``device`` and ``optimizer`` (neither is a
    parameter of this function).

    Args:
        model: network to train; switched to training mode here.
        loader: iterable of ``(image, mask)`` batches.
        loss_func: callable ``loss_func(outputs, mask)`` returning a scalar loss.

    Returns:
        Tuple ``(train_dices, train_losses)``: one Dice value and one loss
        value per batch, in iteration order.
    """
    model.train()
    batch_losses, batch_dices = [], []

    for image, mask in tqdm(loader):
        image, mask = image.to(device), mask.to(device)
        outputs = model(image)

        # Threshold a detached copy at 0.5 to get a hard 0/1 prediction; this
        # copy is used for the metric only, never for the loss.
        hard_pred = np.copy(outputs.data.cpu().numpy())
        hard_pred[hard_pred < 0.5] = 0.0
        hard_pred[hard_pred >= 0.5] = 1.0
        batch_dices.append(dice_coef_metric(hard_pred, mask.data.cpu().numpy()))

        loss = loss_func(outputs, mask)
        batch_losses.append(loss.item())

        # Update step, then clear gradients ready for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return batch_dices, batch_losses

def eval_loop(model, loader, loss_func, epoch, training=True):
    """Evaluate ``model`` on ``loader`` and write a checkpoint for this epoch.

    Uses the notebook-level ``device`` and, when ``training`` is true, the
    notebook-level ``scheduler``.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(image, mask)`` batches.
        loss_func: callable ``loss_func(outputs, mask)`` returning a scalar loss.
        epoch: used only to name the checkpoint file ``resunet_{epoch}.pth``.
        training: when true, step the scheduler with the mean Dice.

    Returns:
        Tuple ``(val_mean_dice, val_mean_loss)`` of Python floats, each
        averaged over the number of batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    dice_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        # Wrap the loader itself (not enumerate) so tqdm can read its length.
        for image, mask in tqdm(loader):
            image = image.to(device)
            mask = mask.to(device)

            outputs = model(image)
            loss = loss_func(outputs, mask)

            # Hard 0/1 prediction at threshold 0.5, used for the metric only.
            hard_pred = np.copy(outputs.detach().cpu().numpy())
            hard_pred[hard_pred < 0.5] = 0.0
            hard_pred[hard_pred >= 0.5] = 1.0
            dice = dice_coef_metric(hard_pred, mask.detach().cpu().numpy())

            # Accumulate plain floats so the returned means are not device tensors.
            loss_sum += loss.item()
            dice_sum += float(dice)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("eval_loop received an empty loader")

    # Divide by the number of batches; the last enumerate index is one less
    # than that and is zero for a single-batch loader.
    val_mean_dice = dice_sum / num_batches
    val_mean_loss = loss_sum / num_batches

    # A checkpoint is written on every call, whether or not the score improved.
    torch.save(copy.deepcopy(model.state_dict()), f'resunet_{epoch}.pth')

    if training:
        scheduler.step(val_mean_dice)

    return val_mean_dice, val_mean_loss

def train_model(train_loader, val_loader, loss_func, optimizer, scheduler, num_epochs):
    """Alternate ``train_loop`` and ``eval_loop`` for ``num_epochs`` epochs.

    NOTE(review): ``optimizer`` and ``scheduler`` are accepted but never
    forwarded. ``train_loop`` and ``eval_loop`` are called with the
    notebook-level ``model`` and read the notebook-level ``optimizer`` and
    ``scheduler`` themselves, so the arguments passed here have no effect
    unless they are those same objects.

    This definition shares its name with an earlier ``train_model`` in the
    notebook and replaces it once this cell runs.

    Returns:
        Tuple ``(train_loss_history, train_dice_history, val_loss_history,
        val_dice_history)`` with one entry per epoch.
    """
    train_loss_history, train_dice_history = [], []
    val_loss_history, val_dice_history = [], []

    for epoch in range(num_epochs):
        batch_dices, batch_losses = train_loop(model, train_loader, loss_func)
        # Per-epoch training figures are the mean of the per-batch values.
        train_mean_dice = np.mean(batch_dices)
        train_mean_loss = np.mean(batch_losses)
        val_mean_dice, val_mean_loss = eval_loop(model, val_loader, loss_func, epoch)

        train_loss_history.append(train_mean_loss)
        train_dice_history.append(train_mean_dice)
        val_loss_history.append(val_mean_loss)
        val_dice_history.append(val_mean_dice)

        summary = 'Epoch: {}/{} |  Train Loss: {:.3f}, Val Loss: {:.3f}, Train DICE: {:.3f}, Val DICE: {:.3f}'
        print(summary.format(epoch + 1, num_epochs,
                             train_mean_loss, val_mean_loss,
                             train_mean_dice, val_mean_dice))

    return train_loss_history, train_dice_history, val_loss_history, val_dice_history

In [None]:
# Adam with a plateau scheduler in 'max' mode: the learning rate is reduced
# when the monitored score stops increasing for more than `patience` steps.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)

# Train for 20 epochs with the combined BCE + Dice loss and keep the histories.
(train_loss_history,
 train_dice_history,
 val_loss_history,
 val_dice_history) = train_model(
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    loss_func=bce_dice_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=20,
)
