# 2D UNet training for stroke detection in MRI
* One subject fewer than in ATLAS R2.0 (due to an upload problem).

## Importing libraries, dividing annotated data into training, validation and test datasets

In [None]:
import torch
import torchmetrics
from sklearn.utils.class_weight import compute_class_weight 
import matplotlib.pyplot as plt
import numpy as np
import math,os,sys
import random
import copy
from pdb import set_trace
from tqdm import tqdm

# then import my classes
from loaders import ISLES_Dataset_2D
from unet import UNet
from resunet import ResUNet
from loss_fcns import TverskyLoss,CrossEntropyDiceLoss,DiceLoss2D

# %% pick the PyTorch compute device: CUDA when a GPU is available, otherwise the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

## Code for generating the dataset lists
This cell will produce 3 lists of folder names for training, validation and test datasets. It uses the whole ATLAS v2.0 dataset.

In [None]:
#%% list every annotated subject folder and split the list at random
imgPath = r'/mnt/Data/ondrejnantl/DPData/train/derivatives/ATLAS'
imgAnnot = sorted(os.listdir(imgPath))
# Drop the alphabetically first directory entry.
# NOTE(review): presumably a non-subject entry in this folder - confirm.
del imgAnnot[0]

# 15 % validation, 25 % test, the remainder for training (counts rounded down).
subjectTotal = len(imgAnnot)
valCount = math.floor(0.15 * subjectTotal)
testCount = math.floor(0.25 * subjectTotal)
trainCount = subjectTotal - valCount - testCount

# A shuffled copy of all subject names; consecutive slices form the three parts.
randIdx = random.sample(imgAnnot, subjectTotal)
heldOut = valCount + testCount
valIdx = randIdx[:valCount]
testIdx = randIdx[valCount:heldOut]
trainIdx = randIdx[heldOut:]

print('ID lists created')

This cell will produce 3 lists of folder names for training, validation and test datasets. It uses only the ATLAS v2.0 cohorts matched by the regular expressions in `cohorts` (currently cohort R009).

In [None]:
import re

# List the annotated subject folders in sorted order and drop the first entry.
imgPath = r'/mnt/Data/ondrejnantl/DPData/train/derivatives/ATLAS'
imgAnnot = sorted(os.listdir(imgPath))
del imgAnnot[0]  # NOTE(review): presumably a non-subject entry - confirm

# Regular expressions selecting which cohorts' subject folders are used.
cohorts = ["^sub-r009.*$"]
# Keep cohort order first, then the sorted folder order within each cohort.
selImgAnnot = [name for pattern in cohorts for name in imgAnnot if re.match(pattern, name)]

# 15 % validation, 25 % test, the remainder for training (counts rounded down).
subjectTotal = len(selImgAnnot)
valCount = math.floor(0.15 * subjectTotal)
testCount = math.floor(0.25 * subjectTotal)
trainCount = subjectTotal - valCount - testCount

# A shuffled copy of the selected subjects; consecutive slices form the three parts.
randIdx = random.sample(selImgAnnot, subjectTotal)
heldOut = valCount + testCount
valIdx = randIdx[:valCount]
testIdx = randIdx[valCount:heldOut]
trainIdx = randIdx[heldOut:]

print('ID lists created')

## Load datasets' lists
Load the subject IDs of a previously saved training/validation/test split (whole dataset).

In [None]:
# Load a previously saved training/validation/test split of the whole dataset.
# Each text file presumably holds one subject folder name per line (split with splitlines()).
imgPath = r'/mnt/Data/ondrejnantl/DPData/train/derivatives/ATLAS'

# `with` guarantees each file is closed even if reading raises.
with open(r'/mnt/Data/ondrejnantl/DPData/trainNames.txt', 'r', encoding='utf-8') as trainFd:
    trainIdx = trainFd.read().splitlines()

with open(r'/mnt/Data/ondrejnantl/DPData/valNames.txt', 'r', encoding='utf-8') as valFd:
    valIdx = valFd.read().splitlines()

with open(r'/mnt/Data/ondrejnantl/DPData/testNames.txt', 'r', encoding='utf-8') as testFd:
    testIdx = testFd.read().splitlines()

print('IDs loaded')

Load subject IDs for cohort R009

In [None]:
# Load a previously saved training/validation/test split restricted to cohort R009.
# Each text file presumably holds one subject folder name per line (split with splitlines()).
imgPath = r'/mnt/Data/ondrejnantl/DPData/train/derivatives/ATLAS'

# `with` guarantees each file is closed even if reading raises.
with open(r'/mnt/Data/ondrejnantl/DPData/trainNamesSP.txt', 'r', encoding='utf-8') as trainFd:
    trainIdx = trainFd.read().splitlines()

with open(r'/mnt/Data/ondrejnantl/DPData/valNamesSP.txt', 'r', encoding='utf-8') as valFd:
    valIdx = valFd.read().splitlines()

with open(r'/mnt/Data/ondrejnantl/DPData/testNamesSP.txt', 'r', encoding='utf-8') as testFd:
    testIdx = testFd.read().splitlines()

print('IDs loaded')

## Get dataset for 2D

In [None]:
# Build the 2D slice datasets and their loaders; all three share the same dataset settings.
# NOTE(review): only the first 15 / 7 / 10 IDs of the train / val / test lists are used -
# looks like a reduced quick-experiment subset; confirm before a full run.
def _slice_dataset(subject_ids):
    """Return an ISLES_Dataset_2D over ``subject_ids`` with this notebook's fixed settings."""
    return ISLES_Dataset_2D(
        imgPath=imgPath,
        imgAnnot=subject_ids,
        sliceCount=189,
        cropped=True,
        downsample=True,
    )


tr_ds = _slice_dataset(trainIdx[:15])
tr_ds_dl = torch.utils.data.DataLoader(tr_ds, batch_size=64, shuffle=True)

# Validation and test batches keep a fixed order.
val_ds = _slice_dataset(valIdx[:7])
val_ds_dl = torch.utils.data.DataLoader(val_ds, batch_size=64, shuffle=False)

ts_ds = _slice_dataset(testIdx[:10])
ts_ds_dl = torch.utils.data.DataLoader(ts_ds, batch_size=64, shuffle=False)

print('Datasets and dataloaders created')

## Training the net and validation of results

In [None]:
#%% U-Net
# Alternative backbone: net = ResUNet(1,2,filters=[32,64,128,256])
net = UNet(1,2,filters=[32,64,128,256])
net.to(device)

#%% training and validation
# Loss alternatives that were tried (kept as switches):
# loss_f = torch.nn.CrossEntropyLoss(weight = torch.Tensor([0.003, 0.997]).to(device))
# loss_f = CrossEntropyDiceLoss(weights=[0.003,0.997])
loss_f = DiceLoss2D()

# The network is switched to train mode at the start of every epoch, so no mode is set here.
opt = torch.optim.Adam(net.parameters(),lr=4e-5,weight_decay=5e-6)
# scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=15, gamma=0.1)

# Per-epoch histories, plotted and checkpointed by later cells.
tr_losses = []
val_losses = []
dice_scores = []
tr_dice_scores = []
learning_rates = []   # learning rate used in each epoch
best_net = []         # replaced by a deep copy of the model with the best validation Dice
best_dice = 0.0
start_epoch = 1
end_epoch = 201       # exclusive upper bound of the epoch range
preview_every = 5     # show a training prediction preview every N epochs
preview_idx = 32      # slice of the last training batch shown in the preview
# Anomaly detection makes backward() point at the op that produced NaN/Inf, but it is a
# debugging aid that slows training; set to False for long production runs.
detect_anomaly = True

# loading checkpoint when resuming
# checkpoint = torch.load("./UNet2Dclassiccheckpoint1012.pt")
# net.load_state_dict(checkpoint['state_dict'])
# opt.load_state_dict(checkpoint['optimizer'])
# start_epoch = checkpoint['epoch']+ start_epoch
# end_epoch = checkpoint['epoch'] + end_epoch
# tr_losses = checkpoint['tr_losses']
# val_losses = checkpoint['val_losses']
# tr_dice_scores = checkpoint['tr_dice_scores']
# dice_scores = checkpoint['dice_scores']


def _show_training_sample(img, pred, lbl, idx=32):
    """Show one slice of a batch with the predicted mask (left) and the ground-truth mask (right).

    ``idx`` is clamped to the last sample in the batch, so a short final batch
    does not raise IndexError.
    """
    idx = min(idx, img.shape[0] - 1)
    fig, ax = plt.subplots(1, 2)
    # 'Predikce' = prediction, 'Zlatý standard' = gold standard (plot titles kept in Czech)
    for axis, mask, title in ((ax[0], pred, 'Predikce'), (ax[1], lbl, 'Zlatý standard')):
        axis.imshow(torch.rot90(img[idx,0,:,:].cpu()),cmap = 'gray')
        axis.imshow(torch.rot90(mask[idx,:,:].cpu()),alpha=0.5,cmap = 'copper')
        axis.set_aspect(img.shape[2]/img.shape[3])
        axis.set_title(title)
        axis.set_axis_off()
    plt.show()


with torch.autograd.set_detect_anomaly(detect_anomaly):
    for epoch in tqdm(range(start_epoch,end_epoch)):
        tr_loss = 0.0
        val_loss = 0.0
        tr_dice = 0.0
        dice = 0.0

        # --- training: forward pass, backpropagation, optimisation step, metric accumulation
        net.train()
        print('\n Epoch {}: First training batch loaded'.format(epoch))
        for img,lbl in tr_ds_dl:
            img,lbl = img.to(device),lbl.to(device)

            # Per-batch class weights, only needed with the cross-entropy based losses:
#             pos_weight = 1-(lbl.sum()/torch.numel(lbl))
#             loss_weights = [1-pos_weight.item(),pos_weight.item()]
#             if loss_weights == [0.0,1.0]: loss_weights = [0.01,0.99]
#             loss_weights = compute_class_weight(class_weight = "balanced", classes= np.unique(lbl.cpu().numpy().flatten()), y=lbl.cpu().numpy().flatten())
#             loss_f = torch.nn.CrossEntropyLoss(weight = torch.Tensor(loss_weights).to(device))
#             loss_f = CrossEntropyDiceLoss(weights=loss_weights)

            pred = net(img)

            loss = loss_f(pred,lbl)
            opt.zero_grad()
            loss.backward()
            opt.step()

            tr_loss+=loss.item()
            # Hard class map from the per-class outputs along dim 1.
            pred = torch.argmax(torch.softmax(pred,dim=1),dim=1)
            # ignore_index=0 leaves class 0 (presumably background) out of the Dice score
            tr_dice += torchmetrics.functional.dice(pred,lbl,ignore_index = 0)

        print('All training batches used')

        # Periodic visual check on the last training batch of the epoch.
        if epoch>0 and epoch % preview_every == 0:
            _show_training_sample(img, pred, lbl, preview_idx)

        # --- validation: forward pass and metric accumulation without gradients
        net.eval()
        print('Epoch {}: First validation batch loaded'.format(epoch))
        with torch.no_grad():
            for img,lbl in val_ds_dl:
                img,lbl = img.to(device),lbl.to(device)

                # Per-batch class weights, only needed with the cross-entropy based losses:
#                 pos_weight = 1-(lbl.sum()/torch.numel(lbl))
#                 loss_weights = [1-pos_weight.item(),pos_weight.item()]
#                 if loss_weights == [0.0,1.0]: loss_weights = [0.01,0.99]
#                 loss_weights = compute_class_weight(class_weight = "balanced", classes= np.unique(lbl.cpu().numpy().flatten()), y=lbl.cpu().numpy().flatten())
#                 loss_f = torch.nn.CrossEntropyLoss(weight = torch.Tensor(loss_weights).to(device))
#                 loss_f = CrossEntropyDiceLoss(weights=loss_weights)

                pred=net(img)

                loss=loss_f(pred,lbl)

                val_loss+=loss.item()
                pred = torch.argmax(torch.softmax(pred,dim=1),dim=1)
                dice += torchmetrics.functional.dice(pred,lbl,ignore_index = 0)

                # Debug hook: drop into pdb as soon as the validation loss becomes NaN.
                if math.isnan(val_loss):
                    print("Check variables")
                    set_trace()

        print('\n All validation batches used')
        # Metrics are batch averages (mean of per-batch values), not whole-set values.
        tr_loss=tr_loss/len(tr_ds_dl)
        val_loss=val_loss/len(val_ds_dl)
        tr_dice = tr_dice/len(tr_ds_dl)
        dice = dice/len(val_ds_dl)

        tr_losses.append(tr_loss)
        val_losses.append(val_loss)
        tr_dice_scores.append(tr_dice.detach().cpu().numpy())
        dice_scores.append(dice.detach().cpu().numpy())
        learning_rates.append(opt.param_groups[0]['lr'])

        # scheduler.step()

        # Keep a deep copy of the best model so later epochs do not overwrite its weights.
        if dice>best_dice:
            best_net = copy.deepcopy(net)
            best_dice = copy.deepcopy(dice)

        print('Epoch {}; LR: {}; Train Loss: {:.4f}; Valid Loss: {:.4f}; Train Dice coeff: {:.4f}; Valid Dice coeff: {:.4f}'.format(epoch,opt.param_groups[0]['lr'],tr_loss,val_loss,tr_dice,dice))


## Saving the model

Saving checkpoint for resuming the training

In [5]:
# Bundle everything needed to resume training: the last completed epoch, model and
# optimizer state, and the per-epoch loss/Dice histories.
# NOTE(review): the resume code in the training cell loads a differently named file
# ("./UNet2Dclassiccheckpoint1012.pt"); rename one of them when resuming from this checkpoint.
checkpoint = dict(
    epoch=epoch,
    state_dict=net.state_dict(),
    optimizer=opt.state_dict(),
    tr_losses=tr_losses,
    val_losses=val_losses,
    tr_dice_scores=tr_dice_scores,
    dice_scores=dice_scores,
)
torch.save(checkpoint, "./2DUNetcheckpoint.pt")

Save the whole model object (the model with the best validation Dice), not just its state dict

In [7]:
# Save the whole best-validation model object (a pickled nn.Module, not a state_dict).
# NOTE(review): the '.pb' extension is only a name; recent PyTorch versions need
# torch.load(..., weights_only=False) to load a full pickled module.
if isinstance(best_net, torch.nn.Module):
    torch.save(best_net,'./2DUNet.pb')
else:
    # best_net keeps its initial [] when no epoch's validation Dice exceeded 0.0;
    # saving that would write a file that cannot be used as a model.
    print('No best model to save: validation Dice never exceeded 0')

## Plot results of training and validation

In [None]:
# Draw two figures: the training/validation loss curves and the training/validation Dice curves.
# Axis labels and titles are in Czech ('Epochy' = epochs, 'Kriteriální funkce' = loss function).
for curves, legendLabels, yLabel, figTitle in (
    ((tr_losses, val_losses), ['tr_loss','val_loss'], "Kriteriální funkce", 'Kriteriální funkce - ATLAS R2.0 dataset'),
    ((tr_dice_scores, dice_scores), ['tr_dice','val_dice'], "DSC", 'DSC- ATLAS R2.0 dataset'),
):
    for curve in curves:
        plt.plot(curve)
    plt.legend(legendLabels)
    plt.xlabel("Epochy")
    plt.ylabel(yLabel)
    plt.title(figTitle)
    plt.show()

## Evaluate one image

In [None]:
# Predict one training slice and show the prediction next to the ground truth.
sample_idx = 510  # arbitrary slice index within the training dataset
image,label = tr_ds[sample_idx]
image,label = image.to(device),label.to(device)
net.eval()
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    # Add a batch dimension, take the per-pixel argmax over the class dimension.
    prediction = torch.argmax(torch.softmax(net(torch.unsqueeze(image,dim=0)),dim=1),dim=1)

# plotting detection ('Predikce' = prediction, 'Zlatý standard' = gold standard)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(torch.rot90(image[0,:,:].cpu()),cmap = 'gray')
ax[0].imshow(torch.rot90(prediction[0,:,:].cpu()),alpha=0.5,cmap = 'copper')
ax[0].set_aspect(image.shape[1]/image.shape[2])
ax[0].set_title('Predikce')
ax[0].set_axis_off()
ax[1].imshow(torch.rot90(image[0,:,:].cpu()),cmap = 'gray')
ax[1].imshow(torch.rot90(label[:,:].cpu()),alpha=0.5,cmap = 'copper')
ax[1].set_aspect(image.shape[1]/image.shape[2])
ax[1].set_title('Zlatý standard')
ax[1].set_axis_off()
plt.show()

## Evaluate performance on test dataset

In [None]:
# Evaluate the network on the test loader. Both results are batch averages (the mean of
# per-batch loss and per-batch Dice). Reuses `net` and `loss_f` from the training cell.
# NOTE(review): this scores the last-epoch `net`, while the saved model file holds
# `best_net`; confirm which one should be reported.
net.eval()
ts_loss = 0.0
ts_dice = 0.0

with torch.no_grad():
    for img, lbl in ts_ds_dl:
        img, lbl = img.to(device), lbl.to(device)

        # Per-batch class weights, only needed with the cross-entropy based losses:
#         pos_weight = 1-(lbl.sum()/torch.numel(lbl))
#         loss_weights = [1-pos_weight.item(),pos_weight.item()]
#         if loss_weights == [0.0,1.0]: loss_weights = [0.01,0.99]
#         loss_f = torch.nn.CrossEntropyLoss(weight = torch.Tensor(loss_weights).to(device))
#        loss_f = CrossEntropyDiceLoss(weights=loss_weights)

        logits = net(img)
        loss = loss_f(logits, lbl)
        ts_loss += loss.item()

        # Hard class map; class 0 (presumably background) is left out of the Dice score.
        pred = logits.softmax(dim=1).argmax(dim=1)
        ts_dice += torchmetrics.functional.dice(pred, lbl, ignore_index=0)

batchCount = len(ts_ds_dl)
ts_loss = ts_loss / batchCount
ts_dice = ts_dice.detach().cpu().numpy() / batchCount

print('Test Loss: {:.4f}; Test Dice coeff: {:.4f}'.format(ts_loss,ts_dice))