In [24]:
import random
from pathlib import Path

import utils_unet_vnc 
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import SimpleITK as sitk
import torchvision.transforms as transforms
import torch.nn as nn
import torch
import os
import unet_vnc

In [25]:
# random seed: ensure reproducible training/validation split
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# directory with data
DATA_DIR = Path.cwd()/'Data_for_model'

# Directory to save the weights of the model
CHECKPOINTS_DIR = Path.cwd() / "segmentation_model_weights"
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)

# Directory where tensorBoard logs will be saved for each run 
TENSORBOARD_LOGDIR = Path.cwd() /"segmentation_runs" 
TENSORBOARD_LOGDIR.mkdir(parents=True, exist_ok=True)
# print(TENSORBOARD_LOGDIR)

# training settings and hyperparameters
NO_VALIDATION_SCANS = 2
IMAGE_SIZE = [64, 64]
BATCH_SIZE = 15
N_EPOCHS = 50
LEARNING_RATE = 1e-4
TOLERANCE = 0.05  # for early stopping

In [26]:
# Data preparation
VNC_scans = [
    location
    for location in DATA_DIR.glob("*")
]
random.shuffle (VNC_scans)

name_VNC_scan = [os.path.basename(os.path.normpath(str(VNC_scan))) for VNC_scan in VNC_scans]

# Split into training/validation dataset after shuffling
partition = {"train": VNC_scans[:-NO_VALIDATION_SCANS],
    "validation": VNC_scans[-NO_VALIDATION_SCANS:],}


# Load training data: create DataLoader with batching and shuffling 
training_dataset = utils_unet_vnc.VNCDataset(partition["train"], IMAGE_SIZE)
training_dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,)

# Load validation data
valid_dataset = utils_unet_vnc.VNCDataset(partition["validation"], IMAGE_SIZE)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,)

In [27]:
# Set model, optimiser and loss function

# Loss function: combination of dice score and binary cross entropy
loss_function = utils_unet_vnc.DiceBCELoss() 

# import unet model for vnc segementation
unet_model = unet_vnc.UNet().to(device)

# Adam optimizer
optimizer = torch. optim.Adam(unet_model.parameters(), lr = LEARNING_RATE)

# minimum_validation_loss = 10
minimum_valid_loss = 10
# keep track of training process
writer = SummaryWriter(log_dir = TENSORBOARD_LOGDIR)


#PS C:\Users\eencinas> tensorboard --logdir="H:\CACS_VNC_2907\segmentation_runs"


# Training loop
for epoch in range (N_EPOCHS):
    current_train_loss = 0.0
    current_valid_loss = 0.0

    for batch in tqdm(training_dataloader):
        inputs, targets, VNCscan_names = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad() # zero gradients
        outputs = unet_model(inputs) # forward pass of the model
        loss = loss_function(outputs.float(), targets.float()) # Compute loss
        loss.backward() # backpropagate
        optimizer.step() # update the parameters of the optimizer using the computed gradients

        current_train_loss += loss.item() # update total loss

        for VNCscan_name in VNCscan_names:
            print(f"Training...processing scan: {VNCscan_name}")
    
    # Evaluate validation loss
    with torch.no_grad():
        unet_model.eval()
        for batch in tqdm(valid_dataloader):
            inputs, targets, VNCscan_names = batch
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = unet_model(inputs)
            loss = loss_function(outputs.float(), targets.float())
            current_valid_loss += loss.item()

            for VNCscan_name in VNCscan_names:
                print(f"Validation...processing scan: {VNCscan_name}")

        unet_model.train()

    # print the results
    print(f'EPOCH: {epoch+1:0>{len(str(N_EPOCHS))}}/{N_EPOCHS}',end=' ')
    # write to tensorboard log
    writer.add_scalar("Loss/train", current_train_loss / len(training_dataloader), epoch)
    writer.add_scalar("Loss/validation", current_valid_loss / len(valid_dataloader), epoch)
    # if validation loss is improving, save model checkpoint # only start saving after 10 epochs
    if (current_valid_loss / len(valid_dataloader)) < minimum_valid_loss + TOLERANCE:
        minimum_valid_loss = current_valid_loss / len(valid_dataloader)
        weights_dict = {k: v.cpu() for k, v in unet_model.state_dict().items()}
        if epoch > 9:
            torch.save(weights_dict, CHECKPOINTS_DIR / f"u_net_{epoch}.pth",)

writer.close()

  0%|          | 0/537 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
# Check how images look when corpped to 64
# check unet in 3D