In [1]:
# Import libraries

import os
import sys
import glob

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import numpy as np
from datetime import datetime

import wandb  # Import Weights and Biases for tracking model training

# Include src directory in path to import custom modules
if '..\\src' not in sys.path:
    sys.path.append('..\\src')

from models.model_unet import UNet3D
from utils.utils import read_hyperparams
from data.datasets import KneeSegDataset3D
from models.evaluation import bce_dice_loss #, dice_coefficient, batch_dice_coeff
from models.train import train_loop, validation_loop 


In [2]:
sys.path

['c:\\Users\\james\\Documents\\projects\\pred-knee-replacement-oai\\notebooks',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\python312.zip',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\DLLs',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\Lib',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai',
 '',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\Lib\\site-packages',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\Lib\\site-packages\\win32',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\Lib\\site-packages\\win32\\lib',
 'c:\\Users\\james\\miniconda3\\envs\\pred-knee-replacement-oai\\Lib\\site-packages\\Pythonwin',
 '..\\src']

In [3]:
os.getcwd()

'c:\\Users\\james\\Documents\\projects\\pred-knee-replacement-oai\\notebooks'

In [4]:
# Define data directory
DATA_DIRECTORY = 'C:/Users/james/OneDrive - University of Leeds/1. Projects/1.1 PhD/1.1.1 Project/Data/OAI Subset'
DATA_TRAIN_DIRECTORY = 'C:/Users/james/OneDrive - University of Leeds/1. Projects/1.1 PhD/1.1.1 Project/Data/OAI Subset/train'
DATA_VALID_DIRECTORY = 'C:/Users/james/OneDrive - University of Leeds/1. Projects/1.1 PhD/1.1.1 Project/Data/OAI Subset/valid'

DATA_RAW_DIRECTORY = '../data/raw'
DATA_PROCESSED_DIRECTORY = '../data/processed'
DATA_INTERIM_DIRECTORY = '../data/processed'

RESULTS_PATH = '../results'
MODELS_PATH = '../models'
MODELS_CHECKPOINTS_PATH = '../models/checkpoints'


In [5]:
# Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Read in hyperparams
hyperparams = read_hyperparams('..\src\models\hyperparams_unet.txt')
print(hyperparams)


{'run_name': 'high_lr_40_epoch', 'l_rate': 0.001, 'num_epochs': 40.0, 'batch_size': 4.0, 'threshold': 0.5, 'transforms': 'True'}


  hyperparams = read_hyperparams('..\src\models\hyperparams_unet.txt')


In [7]:
# Get paths for training and and validation data
# Get the paths

# Return file name from filepath
train_paths = np.array([os.path.basename(i).split('.')[0] for i in glob.glob(f'{DATA_TRAIN_DIRECTORY}/*.im')])
val_paths = np.array([os.path.basename(i).split('.')[0] for i in glob.glob(f'{DATA_VALID_DIRECTORY}/*.im')])

In [8]:
# Set transforms

if hyperparams['transforms'] == "True":
    # Let's try a horizontal flip transform
    transform = transforms.functional.hflip
else:
    transform = None

In [9]:
# Define PyTorch datasets and dataloader

# Define datasets
train_dataset = KneeSegDataset3D(train_paths, DATA_DIRECTORY, transform=transform)
validation_dataset = KneeSegDataset3D(val_paths, DATA_DIRECTORY, split='valid')

# Define dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=int(hyperparams['batch_size']), num_workers = 1, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=2, num_workers = 1, shuffle=False)

In [10]:
a = iter(train_dataloader)

In [11]:
test = next(a)

In [12]:
test[0].shape

torch.Size([4, 1, 200, 256, 160])

In [13]:
# Create model
model = UNet3D(1, 1, 16)

In [14]:
# Specifiy criterion and optimiser
loss_fn = bce_dice_loss
l_rate = hyperparams['l_rate']
optimizer = optim.Adam(model.parameters(), lr=l_rate)

In [15]:
# How long to train for?
num_epochs = int(hyperparams['num_epochs'])

In [29]:
# Threshold for predicted segmentation mask
pred_threshold = hyperparams['threshold']

In [18]:
# start a new wandb run to track this script - LOG IN ON CONSOLE BEFORE RUNNING
wandb.init(
    # set the wandb project where this run will be logged
    project="oai_subset_knee_seg_unet",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": l_rate,
    "architecture": "3D UNet",
    "kernel_num": 16,
    "dataset": "IWOAI",
    "epochs": num_epochs,
    "threshold": threshold,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mjamesbattye0[0m ([33mjamesbattye0-university-of-leeds7616[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
model.to(device)

# use multiple gpu in parallel if available
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

In [30]:
# Define model training fucntion using previously defined training and validation loops

# Capture training start time for output data files
train_start = str(datetime.now())
train_start_file = train_start.replace(" ", "-").replace(".","").replace(":","_")

# # Initialise early stopping criteria
# early_stopper = EarlyStopper(patience=4, min_delta=0.02)

# Initialise training stats 
# stats = {"epoch": [], "train_loss": [], "validation_loss": [], "train_accuracy": [], "validation_accuracy": []}

min_validation_loss = float('inf')

print(f"TRAINING MODEL \n-------------------------------")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")

    train_loss, avg_train_dice = train_loop(train_dataloader, device, model, loss_fn, optimizer, pred_threshold)
    validation_loss, avg_validation_dice = validation_loop(validation_dataloader, device, model, loss_fn, pred_threshold)

    # log to wandb
    wandb.log({"Train Loss": train_loss, "Train Dice Score": avg_train_dice,
                  "Val Loss": validation_loss, "Val Dice Score": avg_validation_dice})
    
    # save as best if val loss is lowest so far
    if validation_loss < min_validation_loss:
        print(f'Validation Loss Decreased({min_validation_loss:.6f}--->{validation_loss:.6f}) \t Saving The Model')
        model_path = f"{MODELS_CHECKPOINTS_PATH}/{hyperparams['run_name']}_best_E.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Best epoch yet: {epoch}")
        
        # reset min as current
        min_validation_loss = validation_loss


# Once training is done, save final model
model_path = f"{MODELS_CHECKPOINTS_PATH}/{hyperparams['run_name']}.pth"
torch.save(model.state_dict(), model_path)

wandb.finish()


print("Done!")


TRAINING MODEL 
------------------------
Epoch 1
-------------------------------
