# U-Net Model Training Script

**Author:** Prisca Dotti

**Last Edit:** 24.10.2023

This Jupyter Notebook contains the code for training a U-Net model on a dataset of sparks videos. The dataset is split into training and testing sets, and the model is trained using the training set. The testing set is used to evaluate the performance of the trained model.

To run the notebook, simply execute each cell in order.

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [2]:
import logging
import os
import torch
import random
import numpy as np

from torch import nn, optim

from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

import wandb
from config import TrainingConfig, config
from models.UNet import unet
from utils.training_inference_tools import (
    MyTrainingManager,
    sampler,
    test_function,
    training_step,
    weights_init,
)
from utils.training_script_utils import (
    get_sample_ids,
    init_criterion,
    init_dataset,
    init_model,
)

logger = logging.getLogger(__name__)

In [3]:
# print logger configuration
logger



Important notes for thesis:
- after trying with and without weight init, I think I will keep U-Net weight init

In [4]:
##################### Get training-specific parameters #####################

# Initialize training-specific parameters.
# In this notebook the configuration file path is hard-coded (it is not read
# from command-line arguments); to train with a different configuration,
# swap the active line with one of the commented alternatives below.
config_filename = os.path.join("config_files", "config_final_model.ini")
# config_filename = os.path.join("config_files", "config_final_model_w_init.ini")
# config_filename = os.path.join("config_files", "config_final_model_TEST.ini")
# config_filename = os.path.join("config_files", "config_sin_channels.ini")
params = TrainingConfig(training_config_file=config_filename)

# Print parameters to console if needed
params.print_params()

######################### Initialize random seeds ##########################

# We used these random seeds to ensure reproducibility of the results.
# PyTorch, Python's `random` module and NumPy each keep their own generator,
# so all three are seeded with the same fixed value.

torch.manual_seed(0)  # <--------------------------------------------------!
random.seed(0)  # <--------------------------------------------------------!
np.random.seed(0)  # <-----------------------------------------------------!

[18:41:42] [  INFO  ] [   config   ] <316 > -- Loading C:\Users\dotti\Code\sparks_project\config_files\config_final_model.ini
[18:41:42] [  INFO  ] [   config   ] <562 > --     training_config_file: C:\Users\dotti\Code\sparks_project\config_files\config_final_model.ini
[18:41:42] [  INFO  ] [   config   ] <562 > --              dataset_dir: C:\Users\dotti\Code\sparks_project\data\sparks_dataset
[18:41:42] [  INFO  ] [   config   ] <562 > --                        c: <configparser.ConfigParser object at 0x000001EE4EB63C40>
[18:41:42] [  INFO  ] [   config   ] <562 > --                 run_name: final_model
[18:41:42] [  INFO  ] [   config   ] <562 > --            load_run_name: 
[18:41:42] [  INFO  ] [   config   ] <562 > --               load_epoch: 100000
[18:41:42] [  INFO  ] [   config   ] <562 > --             train_epochs: 100000
[18:41:42] [  INFO  ] [   config   ] <562 > --                criterion: lovasz_softmax
[18:41:42] [  INFO  ] [   config   ] <562 > --                 lr

In [5]:
# params.set_device("cpu")
# params.display_device_info()

In [6]:
############################ Configure datasets ############################

# Resolve which samples belong to each split for the configured dataset size
train_sample_ids = get_sample_ids(train_data=True, dataset_size=params.dataset_size)
test_sample_ids = get_sample_ids(train_data=False, dataset_size=params.dataset_size)

# Training dataset: data augmentation enabled, instances not loaded
dataset = init_dataset(
    sample_ids=train_sample_ids,
    params=params,
    load_instances=False,
    apply_data_augmentation=True,
)

# Testing dataset: no data augmentation, instances loaded
testing_dataset = init_dataset(
    sample_ids=test_sample_ids,
    params=params,
    load_instances=True,
    apply_data_augmentation=False,
)

[18:41:57] [  INFO  ] [utils.training_script_utils] <149 > -- Samples in dataset: 672
[18:42:38] [  INFO  ] [utils.training_script_utils] <149 > -- Samples in dataset: 158


In [7]:
# # train only with the central chunk of each movie in the dataset
# dataset.source_dataset.set_debug_dataset()
# testing_dataset.set_debug_dataset()

# print("Dataset size: ", len(dataset))
# print("Testing dataset size: ", len(testing_dataset))

# # train with only one batch
# import numpy as np
# from torch.utils.data import Subset
# ids = list(np.arange(3, 3+params.batch_size, 1, dtype=np.int64))
# dataset = Subset(dataset, ids)

In [8]:
# Initialize data loaders
# Only the training dataset is wrapped in a DataLoader; the testing dataset is
# handed to `test_function` directly when the trainer is created.
# Batch size, worker count and memory pinning all come from the training
# configuration; shuffle=True makes the loader draw the samples in a new
# random order on every pass over the dataset.
dataset_loader = DataLoader(
    dataset,
    batch_size=params.batch_size,
    shuffle=True,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

In [9]:
############################## Configure UNet ##############################

# Initialize the UNet model
network = init_model(params=params)

# Move the model to the GPU if available
# (for any non-CPU device the model is first wrapped in nn.DataParallel, so
# from here on `network` refers to the wrapper, not to the bare UNet)
if params.device.type != "cpu":
    network = nn.DataParallel(network).to(params.device, non_blocking=True)
    # cudnn.benchmark = True

# Watch the model with wandb for logging if enabled
if params.wandb_log:
    wandb.watch(network)

# Initialize UNet weights if required
# (nn.Module.apply calls `weights_init` on every submodule recursively)
# NOTE(review): when a checkpoint is loaded later on, the loaded weights
# presumably replace this initialization -- confirm in MyTrainingManager.load.
if params.initialize_weights:
    logger.info("Initializing UNet weights...")
    network.apply(weights_init)

# The following line is commented as it does not work on Windows
# NOTE: torch.compile returns the compiled module; if this line is re-enabled,
# its result must be assigned (network = torch.compile(...)) to have any effect.
# torch.compile(network, mode="default", backend="inductor")

In [10]:
# next(iter(network.parameters())).device

In [11]:
########################### Initialize training ############################

# Initialize the optimizer based on the specified type
if params.optimizer == "adam":
    optimizer = optim.Adam(network.parameters(), lr=params.lr_start)
elif params.optimizer == "adadelta":
    optimizer = optim.Adadelta(network.parameters(), lr=params.lr_start)
elif params.optimizer == "sgd":
    optimizer = optim.SGD(network.parameters(), lr=params.lr_start)
else:
    # Raise instead of calling exit(): exit() is meant for the interactive
    # interpreter and, inside a notebook, asks the kernel to shut down instead
    # of reporting the configuration error in the failing cell.
    invalid_optimizer_msg = f"{params.optimizer} is not a valid optimizer."
    logger.error(invalid_optimizer_msg)
    raise ValueError(invalid_optimizer_msg)

# Initialize the learning rate scheduler if specified
# (only the "step" scheduler is supported; any other value disables it)
if params.scheduler == "step":
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=params.scheduler_step_size,
        gamma=params.scheduler_gamma,
    )
else:
    scheduler = None

# Define the output directory path
output_path = os.path.join(config.output_dir, params.run_name)
logger.info(f"Output directory: {os.path.realpath(output_path)}")

# Initialize the summary writer for TensorBoard logging
summary_writer = SummaryWriter(os.path.join(output_path, "summary"), purge_step=0)

# Check if a pre-trained model should be loaded
# (an empty `load_run_name` means no separate run directory to load from)
if params.load_run_name != "":
    load_path = os.path.join(config.output_dir, params.load_run_name)
    logger.info(f"Model loaded from directory: {os.path.realpath(load_path)}")
else:
    load_path = None

# Initialize the loss function
criterion = init_criterion(params=params, dataset=dataset)

# Create a directory to save predicted class movies
preds_output_dir = os.path.join(output_path, "predictions")
os.makedirs(preds_output_dir, exist_ok=True)

# Create a dictionary of managed objects
managed_objects = {"network": network, "optimizer": optimizer}
if scheduler is not None:
    managed_objects["scheduler"] = scheduler

# Create the AMP gradient scaler ONCE. It keeps the dynamic loss scale as
# internal state that is adjusted from step to step; building a new scaler
# inside the training-step callback would reset that state at every iteration.
grad_scaler = GradScaler()

# Create a training manager with the specified training and testing functions
trainer = MyTrainingManager(
    # Training parameters
    training_step=lambda _: training_step(
        dataset_loader=dataset_loader,
        params=params,
        sampler=sampler,
        network=network,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        scaler=grad_scaler,
    ),
    save_every=params.c.getint("training", "save_every", fallback=5000),
    load_path=load_path,
    save_path=output_path,
    managed_objects=unet.managed_objects(managed_objects),
    # Testing parameters
    test_function=lambda _: test_function(
        network=network,
        device=params.device,
        criterion=criterion,
        params=params,
        testing_dataset=testing_dataset,
        training_name=params.run_name,
        output_dir=preds_output_dir,
        training_mode=True,
    ),
    test_every=params.c.getint("training", "test_every", fallback=1000),
    # NOTE(review): `plot_every` reads the "test_every" key, so plotting always
    # follows the testing period -- confirm this is intended and not a
    # copy-paste slip for a dedicated "plot_every" key.
    plot_every=params.c.getint("training", "test_every", fallback=1000),
    summary_writer=summary_writer,
)

# Load the model if a specific epoch is provided
# NOTE(review): `load_path` may be None here (empty `load_run_name`); the
# manager presumably falls back to `save_path` in that case -- confirm.
if params.load_epoch != 0:
    trainer.load(params.load_epoch)

[18:42:39] [  INFO  ] [  __main__  ] < 26 > -- Output directory: C:\Users\dotti\Code\sparks_project\models\saved_models\final_model
[18:42:39] [  INFO  ] [utils.training_inference_tools] <1360> -- Loading 'C:\Users\dotti\Code\sparks_project\models\saved_models\final_model\network_100000.pth'...
[18:42:39] [  INFO  ] [utils.training_inference_tools] <1360> -- Loading 'C:\Users\dotti\Code\sparks_project\models\saved_models\final_model\optimizer_100000.pth'...


In [12]:
# torch.set_float32_matmul_precision("high")

In [13]:
############################## Start training ##############################

# Resume the W&B run if needed (commented out for now)
# if wandb.run.resumed:
#     checkpoint = torch.load(wandb.restore(checkpoint_path))
#     network.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     epoch = checkpoint['epoch']
#     loss = checkpoint['loss']

In [17]:
# Validate the network before training.
# NOTE: the guard is currently forced to True, so this validation always runs;
# the original condition (validate only when resuming from a checkpoint) is
# kept in the trailing comment so it can be restored.
if True:  # params.load_epoch > 0:
    logger.info("Validate network before training")
    trainer.run_validation(wandb_log=params.wandb_log)

[18:49:13] [  INFO  ] [  __main__  ] < 3  > -- Validate network before training
[18:49:13] [  INFO  ] [utils.training_inference_tools] <1207> -- Validating network at iteration 100000...
[18:49:13] [ DEBUG  ] [utils.training_inference_tools] <923 > -- Test function: running samples in UNet
[18:57:58] [ DEBUG  ] [utils.training_inference_tools] <951 > -- Time to run samples ['05', '10', '15', '20', '25', '32', '34', '40', '45'] in UNet: 524.97 s
[18:57:58] [ DEBUG  ] [utils.training_inference_tools] <955 > -- Test function: computing loss
[18:58:20] [ DEBUG  ] [utils.training_inference_tools] <975 > -- Processing sample 05
[18:58:20] [ DEBUG  ] [utils.training_inference_tools] <984 > -- Test function: saving raw predictions on disk
[18:58:20] [ DEBUG  ] [utils.training_inference_tools] <998 > -- Test function: re-organising annotations
[18:58:22] [ DEBUG  ] [utils.training_inference_tools] <1008> -- Time to re-organise annotations: 1.52 s
[18:58:22] [ DEBUG  ] [utils.training_inference_

In [None]:
# Set the network in training mode
network.train()

# Train the model for the specified number of epochs
# NOTE(review): whether `train_epochs` is an absolute target iteration or a
# number of additional iterations after a loaded checkpoint depends on
# MyTrainingManager.train -- confirm before resuming a run.
# `print_every` is read from the [training] section of the config file and
# defaults to 100 when the key is missing.
logger.info("Starting training")
trainer.train(
    params.train_epochs,
    print_every=params.c.getint("training", "print_every", fallback=100),
    wandb_log=params.wandb_log,
)

In [None]:
# NOTE: only the log message is emitted here; the validation call below is
# currently commented out, so no final validation is actually run.
logger.info("Starting final validation")
# Run the final validation/testing procedure
# trainer.run_validation(wandb_log=params.wandb_log)

[15:31:51] [  INFO  ] [  __main__  ] < 1  > -- Starting final validation


In [None]:
# Close the summary writer
# (the TensorBoard writer created together with the trainer; run this once no
# more training or validation is going to be logged)
summary_writer.close()

In [19]:
# Close the wandb run
# (only when W&B logging is enabled in the training parameters)
if params.wandb_log:
    wandb.finish()

In [None]:
# For debugging purposes
# model_parameters = filter(lambda p: p.requires_grad, network.parameters())
# model_parameters = sum([np.prod(p.size()) for p in model_parameters])
# logger.debug(f"Number of trainable parameters: {model_parameters}")

In [None]:
# for load_epoch in [10000,20000,30000,40000,50000,60000,70000,80000,90000,100000]:
# for load_epoch in [100000]:
#     trainer.load(load_epoch)
#     logger.info("Starting final validation")
#     trainer.run_validation(wandb_log=params.wandb_log)
# if params.wandb_log:
#     wandb.finish()

### Visualize UNet architecture (for debugging)

In [None]:
# # get number of trainable parameters
# num_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
# logger.debug(f"Number of trainable parameters: {num_params}")
# # get dummy unet input
# batch = next(iter(dataset_loader))
# x = batch[0].to(params.device)
# yhat = network(x[:,None]) # Give dummy batch to forward()
# from torchviz import make_dot
# make_dot(yhat, params=dict(list(network.named_parameters()))).render("unet_model", format="png")
# a = [0,1,2,3,4,5,6,7,8,9,10,11,12,13]

# len(a[0:4])