#  ISLES

In [None]:
import torch
import os
from utils.data_setup import ISLESDataSet, CreateDataSets
from utils.plots import *
print(torch.__version__)


## 1. Data

In [None]:
# Quick sanity check: build a dataset with max_size=1 and look at its first item.
# (max_size presumably limits how many samples are loaded - confirm in utils.data_setup.)
check = ISLESDataSet(max_size=1)
# Each item unpacks into four values: two objects exposing .shape (image data and
# segmentation, judging by the names) plus two identifiers (patid, fid).
batch, seg, patid, fid = check[0]
print('Look at a sample dataset: ', batch.shape, seg.shape, patid, fid)

In [None]:
# Create the training and validation datasets (train_ds, val_ds).
# Both are reused below for plotting and for building the DataLoaders.
train_ds, val_ds = CreateDataSets()

In [None]:
# Plot FLAIR, DWI, T1, T2, and the segmentation for one stratum, and output the
# patient ID and stratum number of that sample.
# Index 60 is just an example; NOTE(review): it must be a valid index into train_ds.
plot_images(train_ds[60])

## 2. Training

In [None]:
def run_iteration(dataloader, ll, do_backprob=True):
    """Run one full pass over ``dataloader`` and return the mean batch loss.

    The function neither switches the model between train/eval mode nor
    disables gradient tracking; the caller is responsible for both
    (e.g. ``model.eval()`` plus ``torch.inference_mode()`` for validation).

    Args:
        dataloader: Sized iterable yielding ``(x, y, pid, sid)`` batches.
            Only ``x`` and ``y`` are used; ``pid`` and ``sid`` are ignored.
        ll: Sequence whose first four entries are
            ``model, optimizer, loss_fn, device``.
        do_backprob: If True (default), back-propagate the loss and take an
            optimizer step after every batch. Pass False for evaluation.

    Returns:
        float: Sum of the per-batch losses divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: If ``dataloader`` has length 0.
    """
    model, optimizer, loss_fn, device = ll[:4]

    loss_sum = 0.0
    for x, y, _pid, _sid in dataloader:
        # Move the batch to the target device. The target loses its axis 1
        # (presumably a singleton channel axis - TODO confirm) and is cast to
        # int64, the class-index format nn.CrossEntropyLoss expects.
        x = x.to(device)
        y = y.squeeze(dim=1).long().to(device)

        # 1. Forward pass. The loss is computed on the raw logits, so no
        #    softmax/argmax is needed here.
        y_logits = model(x)

        # 2. Calculate the loss
        loss = loss_fn(y_logits, y)

        if do_backprob:
            # 3. Clear gradients left over from the previous batch
            optimizer.zero_grad()
            # 4. Back-propagate
            loss.backward()
            # 5. Update the parameters
            optimizer.step()

        # .item() extracts a Python float so the running sum does not keep
        # the autograd graph alive.
        loss_sum += loss.item()

    return loss_sum / len(dataloader)

In [None]:
# Training
from torch import optim
from torch.utils.data import DataLoader
import time
from model_builder import *

# Select the compute device: CUDA if available, otherwise Apple MPS, otherwise the CPU.
# getattr guards against PyTorch builds that do not ship torch.backends.mps.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Device: {device}")

# define training hyperparameters
batch_size = 32
num_epochs = 20  # later use 100 - 500 Epochs
lr = 0.1

# initializations
model = ISLESSegNet().to(device)
#model = UNet(4,2).to(device)
#model = torch.compile(model, backend="aot_eager") # aot_eager for mps

# uncomment for retraining previous model
#chkpt_file = '/content/checkpoints/isles01.pt'
#model.load_state_dict(torch.load(chkpt_file))

# CrossEntropyLoss works fine
loss_fn = nn.CrossEntropyLoss()

# BCEWithLogitsLoss
#loss_fn = nn.BCEWithLogitsLoss()

# DiceLoss achieved minimal better results and has less overfitting: Training 80.84%, Testing 72.44% (see Figure in last cell)
#loss_fn = DiceLoss(include_background=False)

optimizer = optim.SGD(model.parameters(), lr=lr)

# Adam optimizer much better testing Dice???: Training 69.01%, Testing 75.28% (see Figure in last cell)
#optimizer = optim.Adam(model.parameters(), lr=lr)

# run_iteration expects [model, optimizer, loss_fn, device] in exactly this order
ll = [model, optimizer, loss_fn, device]
# shuffle only the training data; validation order stays fixed
tdl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
vdl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# best (training, validation) losses seen so far, kept as plain Python floats
is_better = True
prev_loss = [float('inf'), float('inf')]

# per-epoch loss history (used for plotting after training)
epoch_loss = torch.zeros(num_epochs)
val_loss = torch.zeros(num_epochs)

# checkpoint location; the file is written relative to the current working directory
save_path = "checkpoints"
os.makedirs(save_path, exist_ok=True)
best_ckpt_path = os.path.join(save_path, 'isles01.pt')

# start timer for full training
t_train_start = time.time()
# training loop
for epoch in range(num_epochs):
    # start timer for epoch
    t_start = time.time()
    print('Epoch {} from {}'.format(epoch+1, num_epochs))

    # Training pass (with backpropagation)
    model.train()
    train_l = run_iteration(tdl, ll)
    epoch_loss[epoch] = train_l

    # Validation pass: no parameter updates, no autograd bookkeeping
    model.eval()
    with torch.inference_mode():
        val_l = run_iteration(vdl, ll, do_backprob=False)
    val_loss[epoch] = val_l

    delta_epoch = time.time() - t_start

    # print the current epoch's training and validation mean loss
    print('[{}] Training loss: {:.4f}'.format(epoch+1, train_l))
    print('[{}] Validation Loss: {:.4f}\t Time: {:.2f}s'.format(epoch+1, val_l, delta_epoch))

    # check if current epoch's losses are better than the best saved ones
    is_better = train_l < prev_loss[0] and val_l <= prev_loss[1]
    if is_better:
        # update best training and validation losses
        prev_loss[0] = train_l
        prev_loss[1] = val_l
        # Save only on improvement, so the checkpoint on disk always holds the
        # best model seen so far and is never overwritten by a worse epoch.
        torch.save(model.state_dict(), best_ckpt_path)
        print("\033[91m {}\033[00m" .format("Saved best model"))

# total wall-clock training time in seconds
t_end = time.time() - t_train_start
print('Finished Training in {:.2f} seconds'.format(t_end))

In [None]:
# Save manually: write the weights currently held by `model` to the checkpoint file.
# NOTE(review): this is the same path the training loop writes to, so running this
# cell overwrites that checkpoint with the in-memory weights (typically those of the
# last epoch run), even though the message below says "best" - confirm this is intended.
torch.save(model.state_dict(), './checkpoints/isles01.pt')
print("\033[91m {}\033[00m" .format("Saved best model"))

In [None]:
# Visualize the per-epoch training (epoch_loss) and validation (val_loss) losses
# recorded by the training loop.
plot_loss(epoch_loss, val_loss)

In [None]:
# Plot prediction and ground truth for one training sample (index 55 is just an example).
plot_pred(model, train_ds[55], device)