In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import os
import numpy as np
import time
import copy
import pandas as pd
import math
import matplotlib.pyplot as plt
import pickle
import nibabel as nib
import random
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

In [2]:
class ImageDatasetBrats(torch.utils.data.Dataset):
    def __init__(self, transforms):
        self.root = 'burdenko_numpy/'
        unnecessary_files = {'.DS_Store', '.ipynb_checkpoints'}
        self.df = pd.read_csv(f'burdenko_numpy/burdenko_slices.csv')
        self.group1 = {'1019_18','1034_18_4','1036_18','1043_18_4','1056_18_4','1072_19','1112_19_4','1159_18_4',
                  '1164_18','1170_18_4','1184_18','1185_18_4','1214_18','1257_18','1267_18_4','1275_19_4', '1333_18',
                  '1357_19_4','1484_18_4','1541_18_4','1546_18','1733_18','1734_18','1781_18','1795_18_', '255_18',
                  '608_18_4','664_18_4','672_18_4','705_18_4','746_19_4','788_18','826_18_4','856_19_4', '923_18',
                  '925_18_4','946_18','979_18_4'}

        self.group2 = {'1028_18_4','1096_18','1216_18','1254_18','1255_18','1258_18','1302_18_4','1326_18','1354_18_4',
                  '1360_18','1362_18_4','1421_18','1463_18_4','1470_18_4','1501_18_4','1515_18_4','1539_18',
                  '1566_18','1573_18_4','1635_18','1646_18','1685_18_4','1702_18','1743_18_4','1746_18_4',
                  '1764_18_4','1769_18','1770_18_4','322_18_4','349_18_4','351_18','541_18','558_18_4','573_18_4',
                  '575_18_4','593_18','644_19_4','660_18_4','745_18_4','770_18','875_18_4','971_18_4','990_18_4',
                  'Patient_1000314','Patient_1000815','Patient_1001316','Patient_102117','Patient_103717',
                  'Patient_104514','Patient_105215','Patient_107017','Patient_109017','Patient_109414',
                  'Patient_110014','Patient_110816','Patient_111016','Patient_120115','Patient_12115',
                  'Patient_12214','Patient_122315','Patient_123816','Patient_12417','Patient_127916',
                  'Patient_129316','Patient_129415','Patient_129816','Patient_130514','Patient_131416',
                  'Patient_132216','Patient_133916','Patient_135915','Patient_136415','Patient_136715',
                  'Patient_136915','Patient_137315','Patient_138316','Patient_138516','Patient_140316',
                  'Patient_146716','Patient_15215','Patient_15817','Patient_158716','Patient_161316',
                  'Patient_1815','Patient_20717','Patient_22117','Patient_24117','Patient_24717', 'Patient_24815',
                  'Patient_28514','Patient_2914','Patient_33217','Patient_43316', 'Patient_43515','Patient_45217',
                  'Patient_48417','Patient_48517','Patient_49617', 'Patient_5117','Patient_51815','Patient_52315',
                  'Patient_54317','Patient_56717', 'Patient_59315','Patient_59817','Patient_61715','Patient_61916',
                  'Patient_62315', 'Patient_62817','Patient_65516','Patient_66615','Patient_69515','Patient_70614',
                  'Patient_716','Patient_72715','Patient_74417','Patient_75116','Patient_76516', 'Patient_8017',
                  'Patient_83217','Patient_83714','Patient_84116','Patient_87114', 'Patient_88817','Patient_88917',
                  'Patient_89117','Patient_90517','Patient_90616', 'Patient_92114','Patient_9315','Patient_95717',
                  'Patient_98814','Patient_98817','Patient_99715'}

        self.group3 = {'1029_18_4','1744_18','1765_18_4','1788_18_4','423_18','607_18','668_18_4','688_18'}
        self.transforms = transforms
        
    def __getitem__(self, idx):
        folder, slice_ = self.df.iloc[idx][['patient', 'slice']]
        if (folder in self.group1) or (folder in self.group3):
            image = np.load(f'{self.root}{folder}/flair.npz')['arr_0'][:, slice_, :]
        if folder in self.group2:
            image = np.load(f'{self.root}{folder}/flair.npz')['arr_0'][:, :, slice_]
        image = image / np.max(image) * 255
        
        if (folder in self.group1) or (folder in self.group3):
            mask = np.load(f'{self.root}{folder}/mask.npz')['arr_0'][:, slice_, :]
        if folder in self.group2:
            mask = np.load(f'{self.root}{folder}/mask.npz')['arr_0'][:, :, slice_]
        mask[mask > 1] = 1
        transformed = self.transforms(image=np.array(image, dtype = np.uint8),
                                      mask=np.array(mask, dtype = np.uint8))
        image = transformed["image"].float()
        mask = transformed["mask"].float().unsqueeze(0)
        return image, mask


    def __len__(self):
        return len(self.df)

In [3]:
data_transforms = {
    'train': A.Compose(
        [
        A.RandomResizedCrop(256, 256, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=0.3),
        A.Resize(256, 256),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=0, std=1),
        ToTensorV2(),
        ]
    ),
    'val': A.Compose(
        [
        A.Resize(256, 256),
        A.Normalize(mean=0, std=1),
        ToTensorV2(),
        ]
    )
}

In [4]:
dataset_train = ImageDatasetBrats(data_transforms['train'])
dataset_test = ImageDatasetBrats(data_transforms['val'])

torch.manual_seed(123) #для воспроизводимости
indices = torch.randperm(180).tolist()
t = int(180*0.7)

df_for_training = pd.read_csv("burdenko_numpy/burdenko_slices.csv")

train_indices = df_for_training[df_for_training.patient_index.isin(indices[:t])]['index'].tolist()
test_indices = df_for_training[df_for_training.patient_index.isin(indices[t:])]['index'].tolist()

dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
dataset_test = torch.utils.data.Subset(dataset_test, test_indices)

dataloaders = {'train': torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=4),
               'test': torch.utils.data.DataLoader(dataset_test, batch_size=16, shuffle=False, num_workers=4)}

dataset_sizes = {'train': len(dataset_train), 'val': len(dataset_test)}

In [5]:
!nvidia-smi

Sat May 21 23:05:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.54       Driver Version: 510.54       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:06:00.0 Off |                  N/A |
| 32%   65C    P2    90W / 250W |   2213MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [6]:
device = torch.device("cuda:0")

In [7]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.act1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        self.bn2 = nn.BatchNorm2d(in_channels)
        self.act2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.act1(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.act2(out)
        out = self.conv2(out)

        out += identity
        return out


class ResUNet(nn.Module):
    def __init__(self, in_channels=4, out_channels=1, features=32, dropout=False, pooling_size=2):
        super(ResUNet, self).__init__()

        if dropout:
            dropout_layer = nn.Dropout(0.1)
        else:
            dropout_layer = nn.Identity()

        self.init_path = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            ResidualBlock(features, features, kernel_size=3, padding=1),
            ResidualBlock(features, features, kernel_size=3, padding=1),
            ResidualBlock(features, features, kernel_size=3, padding=1)
        )
        self.shortcut0 = nn.Conv2d(features, features, kernel_size=1)

        self.down1 = nn.Sequential(
            nn.BatchNorm2d(features),
            nn.Conv2d(features, features * 2, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1)
        )
        self.shortcut1 = nn.Conv2d(features * 2, features * 2, 1)

        self.down2 = nn.Sequential(
            nn.BatchNorm2d(features * 2),
            nn.Conv2d(features * 2, features * 4, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1)
        )
        self.shortcut2 = nn.Conv2d(features * 4, features * 4, 1)

        self.down3 = nn.Sequential(
            nn.BatchNorm2d(features * 4),
            nn.Conv2d(features * 4, features * 8, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            dropout_layer
        )

        self.up3 = nn.Sequential(
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 8),
            nn.ConvTranspose2d(features * 8, features * 4, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.up2 = nn.Sequential(
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.up1 = nn.Sequential(
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.ConvTranspose2d(features * 2, features, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.out_path = nn.Sequential(
            ResidualBlock(features, features, kernel_size=1, padding=0),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        x0 = self.init_path(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        x2_up = self.up3(x3)
        x1_up = self.up2(x2_up + self.shortcut2(x2))
        x0_up = self.up1(x1_up + self.shortcut1(x1))
        x_out = self.out_path(x0_up + self.shortcut0(x0))
        return x_out
#         return torch.sigmoid(x_out)


In [8]:
# model = ResUNet(in_channels=1, out_channels=4)
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=1, out_channels=1, init_features=32, pretrained=False)
model.load_state_dict(torch.load("model_weights/unet_burd_orig.pth", map_location=torch.device('cpu')))
model.to(device)
model_name = "model_weights/unet_burd_orig_2"

Using cache found in /home/i_govorova/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


In [9]:
def dice_coef_metric(pred, label):
    intersection = 2.0 * (pred * label).sum()
    union = pred.sum() + label.sum()
    if pred.sum() == 0 and label.sum() == 0:
        return 1.
    return intersection / union

def dice_coef_loss(pred, label):
    smooth = 1.0
    intersection = 2.0 * (pred * label).sum() + smooth
    union = pred.sum() + label.sum() + smooth
    return 1 - (intersection / union)

def train_loop(model, loader, loss_func):
    model.train()
    train_losses = []
    train_dices = []
    
    for image, mask in tqdm(loader):
        image = image.to(device)
        mask = mask.to(device)
        outputs = model(image)
        out_cut = np.copy(outputs.data.cpu().numpy())
        out_cut[np.nonzero(out_cut < 0.5)] = 0.0
        out_cut[np.nonzero(out_cut >= 0.5)] = 1.0            

        dice = dice_coef_metric(out_cut, mask.data.cpu().numpy())
        loss = loss_func(outputs, mask)
        train_losses.append(loss.item())
        train_dices.append(dice)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    return train_dices, train_losses

def eval_loop(model, loader, loss_func, epoch, best_dice, training=True):
    model.eval()
    val_loss = 0
    val_dice = 0
    with torch.no_grad():
        for step, (image, mask) in tqdm(enumerate(loader)):
            image = image.to(device)
            mask = mask.to(device)
    
            outputs = model(image)
            loss = loss_func(outputs, mask)
            
            out_cut = np.copy(outputs.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0
            dice = dice_coef_metric(out_cut, mask.data.cpu().numpy())
            
            val_loss += loss
            val_dice += dice
        
        val_mean_dice = val_dice / step
        val_mean_loss = val_loss / step
        
        if val_mean_dice > best_dice:
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, f'{model_name}.pth')      
        
        if training:
            scheduler.step(val_mean_dice)
        
    return val_mean_dice, val_mean_loss

def train_model(train_loader, val_loader, loss_func, optimizer, scheduler, num_epochs):
    train_loss_history = []
    train_dice_history = []
    val_loss_history = []
    val_dice_history = []
    val_mean_dice = 0
    
    for epoch in range(num_epochs):
        train_dices, train_losses = train_loop(model, train_loader, loss_func)
        train_mean_dice = np.array(train_dices).mean()
        train_mean_loss = np.array(train_losses).mean()
        val_mean_dice, val_mean_loss = eval_loop(model, val_loader, loss_func, epoch, val_mean_dice)
        
        train_loss_history.append(np.array(train_losses).mean())
        train_dice_history.append(np.array(train_dices).mean())
        val_loss_history.append(val_mean_loss)
        val_dice_history.append(val_mean_dice)
        
        print('Epoch: {}/{} |  Train Loss: {:.3f}, Val Loss: {:.3f}, Train DICE: {:.3f}, Val DICE: {:.3f}'.format(epoch+1, num_epochs,
                                                                                                                 train_mean_loss,
                                                                                                                 val_mean_loss,
                                                                                                                 train_mean_dice,
                                                                                                                 val_mean_dice))
        

    return train_loss_history, train_dice_history, val_loss_history, val_dice_history

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)
train_loss_history, train_dice_history, val_loss_history, val_dice_history = train_model(dataloaders['train'],
                                                                                         dataloaders['test'],
                                                                                         dice_coef_loss, 
                                                                                         optimizer,
                                                                                         scheduler, 11)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 896/896 [28:06<00:00,  1.88s/it]
234it [05:49,  1.50s/it]


Epoch: 1/11 |  Train Loss: 0.203, Val Loss: 0.413, Train DICE: 0.797, Val DICE: 0.618


 56%|██████████████████████████████████████████████████████████████████▍                                                    | 500/896 [15:16<09:30,  1.44s/it]