In [9]:
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from pathlib import Path
import os

import numpy as np
import pandas as pd
from dpipe.io import load_numpy

import torch.optim as optim
import pandas as pd
import torch.nn as nn
from torch.nn import functional
import torch

from data_loader import BraTSDataset
from unet import Unet
from training_utils import train_step

import warnings
warnings.filterwarnings("ignore")

## Create the slice metadata table (meta.csv)

In [2]:
# Root folder of the pre-sliced BraTS data: one sub-folder per subject, each holding the
# per-slice image and mask files (mask files have 'mask' in their file name).
# Set the BRATS_SLICES_DIR environment variable to point the notebook at another copy.
data_folder = Path(os.environ.get('BRATS_SLICES_DIR',
                                  '/Users/alisakugusheva/Desktop/HSE/Project/brats_slices_2/'))


def build_metadata(data_folder):
    """Scan ``data_folder`` recursively and return one metadata row per file.

    Columns of the returned DataFrame:
      relative_path   -- ``<subject_folder>/<file>``, i.e. the path relative to
                         ``data_folder`` for files one level below it
      sample_id       -- ``"<subject_id>_<slice_id>"``; shared by a slice image and its mask
      is_mask         -- True when the file name contains ``'mask'``
      subject_id      -- last ``_``-separated token of the containing folder name
      is_nonzero_mask -- for mask files, whether any mask value is non-zero; NaN for images

    NOTE: files placed directly inside ``data_folder`` (e.g. a ``meta.csv`` saved there)
    are scanned as well, and their ``relative_path`` is then not relative to it.
    Mask files are loaded with ``allow_pickle=True``, so only run this on trusted data.
    """
    rows = []
    for path, _, files in tqdm(os.walk(data_folder)):
        folder = Path(path)
        # Path.name instead of splitting on '/', so this also works with non-POSIX separators.
        subject_id = folder.name.split('_')[-1]
        for file in files:
            slice_id = file.split('.')[0].split('_')[0]
            is_mask = 'mask' in file
            if is_mask:
                is_nonzero_mask = np.any(load_numpy(folder / file, allow_pickle=True))
            else:
                is_nonzero_mask = np.nan
            rows.append([Path(folder.stem) / file, f"{subject_id}_{slice_id}",
                         is_mask, subject_id, is_nonzero_mask])

    return pd.DataFrame(rows, columns=['relative_path', 'sample_id', 'is_mask',
                                       'subject_id', 'is_nonzero_mask'])


# Run once to (re)generate the metadata file; the next cell only reads an existing meta.csv.
# meta_df = build_metadata(data_folder)
# print(meta_df.is_nonzero_mask.value_counts())
# meta_df.to_csv(data_folder / 'meta.csv')

In [3]:
# Folder holding meta.csv (columns: relative_path, sample_id, is_mask, subject_id,
# is_nonzero_mask, as built by the metadata scan in the previous cell). Note that the scan
# writes meta.csv into `data_folder`, while this cell reads it from `df_folder`.
# Set the BRATS_PROJECT_DIR environment variable to use a different location.
df_folder = Path(os.environ.get('BRATS_PROJECT_DIR', '/Users/alisakugusheva/Desktop/HSE/Project'))
df = pd.read_csv(df_folder / 'meta.csv', index_col=0)
assert len(df) > 0, f"{df_folder / 'meta.csv'} contains no rows"

# Split by subject rather than by row, so that no patient has slices in both sets:
# sort rows by subject, take the subject found at the 80% row mark as the border, and
# send every subject before it to training and the rest to validation.
df = df.sort_values(by=['subject_id'], ignore_index=True)

train_size = int(0.8 * df.shape[0])
val_size = df.shape[0] - train_size  # nominal row counts; the subject-wise split only approximates them

border_id = df['subject_id'].iloc[train_size]

# .copy() hands the datasets independent frames rather than filtered slices of `df`;
# pandas may emit SettingWithCopyWarning on writes to slices, and warnings are silenced above.
train_df = df[df['subject_id'] < border_id].copy()
val_df = df[df['subject_id'] >= border_id].copy()

assert set(train_df['subject_id']).isdisjoint(val_df['subject_id'])

## Datasets, augmentations and data loaders

In [4]:
from torchvision import transforms
import numpy as np
from skimage.transform import rotate

In [5]:
def random_crop(sample, size=(120, 120), margin=(20, 20)):
    """Cut the same random window out of an image and its mask.

    Args:
        sample: ``(image, mask)`` pair of arrays whose last two axes are (H, W) and match;
            any leading axes (e.g. channels) are kept as they are.
        size: ``(height, width)`` of the crop.
        margin: ``(vertical, horizontal)`` border, in pixels, that the crop never enters,
            applied on both sides of the image.

    Returns:
        ``(image_crop, mask_crop)`` whose last two axes have shape ``size``.

    Raises:
        ValueError: if the crop plus a margin on each side does not fit in the image.
    """
    image, mask = sample
    h, w = image.shape[-2:]
    new_h, new_w = size
    margin_h, margin_w = margin

    # Valid top-left corners keep the crop at least `margin` pixels from every edge
    # (top/left as well as bottom/right). randint's upper bound is exclusive, hence +1.
    max_top = h - margin_h - new_h
    max_left = w - margin_w - new_w
    if max_top < margin_h or max_left < margin_w:
        raise ValueError(f"cannot take a {new_h}x{new_w} crop with margin {margin} "
                         f"from a {h}x{w} image")

    top = np.random.randint(margin_h, max_top + 1)
    left = np.random.randint(margin_w, max_left + 1)

    image = image[..., top:top + new_h, left:left + new_w]
    mask = mask[..., top:top + new_h, left:left + new_w]
    return image, mask

def random_rotate(sample):
    """Rotate an image and its mask together by the same random multiple of 90 degrees.

    Uses ``np.rot90`` over the last two axes. Quarter turns are exact index permutations,
    so nothing is interpolated and the dtype is preserved (a bool/int mask stays bool/int,
    a float32 image stays float32). ``skimage.transform.rotate``, by contrast,
    interpolates and returns floating-point output. For non-square inputs an odd number
    of quarter turns swaps H and W (for image and mask alike).

    Args:
        sample: ``(image, mask)`` pair of arrays whose last two axes are (H, W).

    Returns:
        ``(image, mask)`` rotated counter-clockwise by 0, 90, 180 or 270 degrees, as
        C-contiguous arrays (``torch.from_numpy`` rejects the negative strides of a raw
        ``np.rot90`` view).
    """
    image, mask = sample
    k = np.random.randint(4)  # number of quarter turns, uniform over {0, 1, 2, 3}
    rotated_image = np.ascontiguousarray(np.rot90(image, k, axes=(-2, -1)))
    rotated_mask = np.ascontiguousarray(np.rot90(mask, k, axes=(-2, -1)))
    return rotated_image, rotated_mask
    
def to_tensor(sample):
    """Convert an ``(image, mask)`` pair of (H, W) arrays into ``(1, H, W)`` float32 tensors.

    Both outputs get a leading channel axis, so a DataLoader batch is ``(B, 1, H, W)``.
    That is the layout the validation cells index (``x[0, ...]``, ``y[0, ...]``) and,
    presumably, that of the Unet output which ``BCEWithLogitsLoss`` compares against
    (TODO confirm against ``unet.Unet``). The mask stays pixel-wise because it is the
    segmentation target; it is not reduced to a per-slice "has tumour" boolean.

    float32 matches PyTorch's default parameter dtype, whatever dtype the augmentations
    produced. ``np.ascontiguousarray`` also makes strided views (e.g. ``np.rot90``
    output) safe for ``torch.from_numpy``.
    """
    image, mask = sample
    image = np.ascontiguousarray(image, dtype=np.float32)[np.newaxis]
    mask = np.ascontiguousarray(mask, dtype=np.float32)[np.newaxis]
    return torch.from_numpy(image), torch.from_numpy(mask)

In [6]:
# Training samples are augmented (random crop, then a random rotation) before being
# converted to tensors; validation samples are only converted, i.e. evaluated on the
# full, unaugmented slice.
train_transform = transforms.Compose([random_crop, random_rotate, to_tensor])
val_transform = to_tensor

In [10]:
 # '/home/anvar/work/data/brats_slices/'
# An earlier setup read '.gz'-compressed slices from a `brats_slices/` folder (it appended
# '.gz' to `relative_path` in meta.csv) and passed the transforms defined above.
# NOTE: `train_transform` / `val_transform` are currently NOT applied; pass
# `transform=train_transform` / `transform=val_transform` to BraTSDataset to enable them.
# `nonzero_mask=True` presumably keeps only slices whose mask is non-empty (cf. the
# `is_nonzero_mask` column of meta.csv) -- confirm in data_loader.BraTSDataset.
train_dataset = BraTSDataset(train_df, data_folder, nonzero_mask=True)
val_dataset = BraTSDataset(val_df, data_folder, nonzero_mask=True)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)

## Train model

In [11]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Unet().to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the network's raw outputs are
# treated as logits (use torch.sigmoid to turn predictions into probabilities).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
N_EPOCHS = 5

for epoch in range(N_EPOCHS):
    # The validation cells below call model.eval(); switch back to training mode every
    # epoch so re-running this cell afterwards does not train with eval-mode layers.
    model.train()
    epoch_loss = 0.0
    for X_batch, y_batch in tqdm(train_loader, desc=f'epoch {epoch + 1}/{N_EPOCHS}'):
        loss = train_step(X_batch, y_batch, model, criterion, optimizer)
        # float() accepts a Python number or a 0-d tensor; for a tensor it also drops the
        # autograd graph that `epoch_loss += loss` would otherwise keep alive all epoch.
        epoch_loss += float(loss)

    mean_loss = epoch_loss / max(len(train_loader), 1)  # guard against an empty loader
    # Epochs are reported 1-based (Epoch 001 ... Epoch N_EPOCHS).
    print(f'Epoch {epoch + 1:03}: | Loss: {mean_loss:.5f}')

## Save model parameters

See the PyTorch tutorial [Saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html).

In [None]:
# Persist only the learned parameters (the state_dict), not the whole pickled module.
# NOTE: the validation cell loads '/Users/alisakugusheva/Desktop/HSE/Project/unet.pth',
# not './unet.pth' -- keep the two paths in sync when re-enabling this line.
# torch.save(model.state_dict(), './unet.pth')

## Validate model quality

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np

_ = model.load_state_dict(torch.load('/Users/alisakugusheva/Desktop/HSE/Project/unet.pth'))
_ = model.eval()

x, y = val_dataset[1000]
y_pred = np.exp(model(x.reshape(1,1,240,240).to('cuda'))[0].to('cpu').detach().numpy()) > 0.5

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

ax[0].imshow(x[0, ...], cmap='gray')
ax[0].set_title('Image')
ax[1].imshow(y[0, ...], cmap='gray')
ax[1].set_title('Ground truth mask')
ax[2].imshow(y_pred[0, ...], cmap='gray')
ax[2].set_title('Predicted mask');

## Slice-wise Dice score

In [None]:
from dpipe.im.metrics import dice_score

In [None]:
# Slice-wise Dice over the whole validation set. Predictions threshold sigmoid
# probabilities, consistent with the BCEWithLogitsLoss used for training. The former
# `np.exp(logits) > 0.1` test amounted to `sigmoid(logits) > 0.1 / 1.1 (about 0.09)`;
# set DICE_THRESHOLD to that value to reproduce the earlier numbers.
DICE_THRESHOLD = 0.5

_ = model.eval()  # don't depend on an earlier cell having left the model in eval mode
dice = []
with torch.no_grad():  # inference only: don't build an autograd graph for every slice
    for i in tqdm(range(len(val_dataset))):
        x, y = val_dataset[i]
        y = y.detach().numpy().astype(bool)
        # Add a batch axis, predict on the model's device, keep the single output item.
        logits = model(x.unsqueeze(0).to(device))[0]
        y_pred = (torch.sigmoid(logits) > DICE_THRESHOLD).cpu().numpy()
        dice.append(dice_score(y, y_pred))

In [None]:
# Mean and standard deviation of the slice-wise Dice scores collected above.
(np.mean(dice), np.std(dice))

# TODO

1. Compute the patient-wise Dice score
2. Do cross-validation and compute the Dice score on the test set
3. Train a larger model for longer and reach a Dice score of at least 0.5 on the test set