In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the project's `utils` and `models` packages) are reloaded before each
# cell executes and code edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
import os

import numpy as np
import torch
import wandb
from torch import nn, optim
import random

# from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.training_inference_tools import (
    sampler,
    test_function,
    training_step,
    weights_init,
    myTrainingManager,
)
from utils.training_script_utils import (
    init_param_config_logs,
    init_training_dataset,
    init_testing_dataset,
    init_model,
    init_criterion,
)
from models.UNet import unet

# Logger used by every cell of this notebook; it is named after `__name__`.
logger = logging.getLogger(__name__)
# Select the "high" float32 matrix-multiplication precision mode in PyTorch.
# NOTE(review): presumably chosen to trade some precision for speed on GPUs
# that support it -- confirm this is intended for the final model.
torch.set_float32_matmul_precision("high")



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [3]:
############################# fixed parameters #############################
# Constants that are not read from the configuration file. Later cells refer
# to them by these exact names.

# General params
logfile = None  # TODO: change this when publishing the finished project on GitHub
# wandb_project_name = "sparks2"
output_relative_path = "runs"  # base directory; each run writes its outputs
# (summaries, predictions, saved objects) to a sub-directory named after the run

# Dataset parameters
ignore_index = 4  # label value ignored during training
num_classes = 4  # i.e., background (BG), sparks, waves, puffs
ndims = 3  # using 3D data

# When True: configuration is loaded with verbosity 3 instead of 2, the number
# of trainable parameters is logged, and `debug=True` is passed to the
# testing function.
debug_mode = True

In [4]:
############################## get parameters ##############################
# Load the run configuration and set up logging. The call returns three
# objects that the rest of the notebook relies on:
#   c         -- configuration object queried with `getint`/`getboolean` for
#                the "training" and "general" sections (presumably a
#                ConfigParser -- confirm in utils.training_script_utils)
#   params    -- mapping of run parameters, indexed by string keys below
#   wandb_log -- flag that guards every Weights & Biases call in this notebook

c, params, wandb_log = init_param_config_logs(
    basedir=None,
    verbosity=3 if debug_mode else 2,
    wandb_project_name="TEST",
    config_file="config_final_model_blur_frames_nll_loss.ini"
    # config_file="config_final_model.ini"
)

[15:43:09] [  INFO  ] [training_script_utils] < 97 > -- Loading config_files\config_final_model_blur_frames_nll_loss.ini
[15:43:09] [  INFO  ] [training_script_utils] <203 > -- Command parameters:
[15:43:09] [  INFO  ] [training_script_utils] <205 > --                 run_name: final_model_blur_frames_nll_loss_2
[15:43:09] [  INFO  ] [training_script_utils] <205 > --            load_run_name: None
[15:43:09] [  INFO  ] [training_script_utils] <205 > --               load_epoch: 0
[15:43:09] [  INFO  ] [training_script_utils] <205 > --             train_epochs: 100000
[15:43:09] [  INFO  ] [training_script_utils] <205 > --                criterion: nll_loss
[15:43:09] [  INFO  ] [training_script_utils] <205 > --                 lr_start: 0.0001
[15:43:09] [  INFO  ] [training_script_utils] <205 > --       ignore_frames_loss: 6
[15:43:09] [  INFO  ] [training_script_utils] <205 > --                     cuda: True
[15:43:09] [  INFO  ] [training_script_utils] <205 > --                sche

In [5]:
############################ configure datasets ############################
# Picks the train/test sample IDs, selects the torch device, logs the chosen
# normalisation mode and builds the training dataset.
#
# Names defined here and used by later cells: train_sample_ids,
# test_sample_ids, device, pin_memory, n_gpus, dataset_path, dataset.

# select samples that are used for training and testing
if params["dataset_size"] == "full":
    train_sample_ids = [
        "01", "02", "03", "04", "06", "07", "08", "09",
        "11", "12", "13", "14", "16", "17", "18", "19",
        "21", "22", "23", "24", "27", "28", "29", "30",
        "33", "35", "36", "38", "39",
        "41", "42", "43", "44", "46",
    ]
    test_sample_ids = ["05", "10", "15", "20", "25", "32", "34", "40", "45"]
elif params["dataset_size"] == "minimal":
    train_sample_ids = ["01"]
    test_sample_ids = ["34"]
else:
    # Raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down, which discards the session state instead of reporting a
    # readable error in the cell output.
    error_msg = f"{params['dataset_size']} is not a valid dataset size."
    logger.error(error_msg)
    raise ValueError(error_msg)

# detect CUDA devices
if params["cuda"]:
    # Fall back to the CPU when CUDA was requested but is not available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory is only useful when batches are copied to a CUDA
    # device, so do not request it after a CPU fallback.
    pin_memory = device.type == "cuda"
else:
    device = "cpu"
    pin_memory = False
n_gpus = torch.cuda.device_count()
logger.info(f"Using torch device {device}, with {n_gpus} GPUs")

# set if temporal reduction is used
if params["temporal_reduction"]:
    logger.info(f"Using temporal reduction with {params['num_channels']} channels")

# normalize whole videos or chunks individually
# (this only logs the selected mode; any other value is not reported here)
if params["norm_video"] == "chunk":
    logger.info("Normalizing each chunk using min and max")
elif params["norm_video"] == "movie":
    logger.info("Normalizing whole video using min and max")
elif params["norm_video"] == "abs_max":
    logger.info("Normalizing whole video using 16-bit absolute max")

# Resolve the configured dataset location to a canonical path.
dataset_path = os.path.realpath(f"{params['relative_path']}")

# initialize training dataset
dataset = init_training_dataset(
    params=params,
    train_sample_ids=train_sample_ids,
    ignore_index=ignore_index,
    dataset_path=dataset_path,
)

[15:43:09] [  INFO  ] [  __main__  ] < 27 > -- Using torch device cuda, with 1 GPUs
[15:43:09] [  INFO  ] [  __main__  ] < 39 > -- Normalizing whole video using 16-bit absolute max
[15:43:09] [  INFO  ] [  datasets  ] <178 > -- Applying 2d gaussian blur to videos...
[15:43:09] [ DEBUG  ] [  datasets  ] <252 > -- Added padding of 12 frames to video with unsuitable duration


In [6]:
# train with only one batch
# ids = np.arange(0, params["batch_size"], 1, dtype=np.int64)
# dataset = torch.utils.data.Subset(dataset, ids)

In [7]:
# Report how many samples the training dataset holds.
logger.info(f"Samples in training dataset: {len(dataset)}")

# Build the testing datasets for the held-out sample IDs, using the same
# parameters, ignore label and data location as the training dataset.
testing_datasets = init_testing_dataset(
    dataset_path=dataset_path,
    ignore_index=ignore_index,
    test_sample_ids=test_sample_ids,
    params=params,
)

# Log the size of every testing dataset, in order.
for test_idx, test_dataset in enumerate(testing_datasets):
    logger.info(f"Testing dataset {test_idx} contains {len(test_dataset)} samples")

# Training batches are shuffled; batch size and worker count come from the
# run parameters, memory pinning from the device selection made earlier.
loader_options = {
    "batch_size": params["batch_size"],
    "shuffle": True,
    "num_workers": params["num_workers"],
    "pin_memory": pin_memory,
}
dataset_loader = DataLoader(dataset, **loader_options)

[15:43:10] [  INFO  ] [  __main__  ] < 1  > -- Samples in training dataset: 9
[15:43:10] [ DEBUG  ] [  datasets  ] <304 > -- Computing spark peaks...
[15:43:13] [ DEBUG  ] [  datasets  ] <311 > -- Sample 34 contains 16 sparks.
[15:43:13] [  INFO  ] [  datasets  ] <178 > -- Applying 2d gaussian blur to videos...
[15:43:14] [ DEBUG  ] [  datasets  ] <252 > -- Added padding of 24 frames to video with unsuitable duration
[15:43:14] [  INFO  ] [  __main__  ] < 10 > -- Testing dataset 0 contains 22 samples


In [8]:
############################## configure UNet ##############################
# Builds the network, moves it to the selected device, optionally registers
# it with Weights & Biases and optionally re-initialises its weights.

# Create the model from the run parameters and the fixed dataset constants.
network = init_model(params=params, num_classes=num_classes, ndims=ndims)

# Wrap the model in DataParallel, move it to `device` and enable the cuDNN
# auto-tuner.
# NOTE(review): when params["cuda"] is set, `device` is a torch.device object
# rather than the string "cpu"; comparing it with a string may not behave
# like a device-type comparison, so this branch may also run after a CPU
# fallback -- confirm. Changing it would alter whether the model is wrapped
# (and hence the parameter names in saved checkpoints).
if device != "cpu":
    network = nn.DataParallel(network).to(device)
    torch.backends.cudnn.benchmark = True

# Let Weights & Biases track the network when W&B logging is enabled.
if wandb_log:
    wandb.watch(network)

# Apply the project's `weights_init` function to every sub-module.
if params["initialize_weights"]:
    logger.info("Initializing UNet weights...")
    network.apply(weights_init)

# torch.compile(network, mode="default", backend="inductor")
# does not work on windows

In [9]:
if debug_mode:
    # Count the trainable parameters of the network.
    # Tensor.numel() always returns a Python int, so the total stays an
    # integer even if the model contains 0-dim (scalar) parameters; taking
    # np.prod() of an empty size yields the float 1.0 instead.
    model_parameters = sum(
        p.numel() for p in network.parameters() if p.requires_grad
    )
    logger.debug(f"Number of trainable parameters: {model_parameters}")

[15:43:14] [ DEBUG  ] [  __main__  ] < 4  > -- Number of trainable parameters: 22631764


In [10]:
########################### initialize training ############################
# Creates the optimizer, the optional LR scheduler, the TensorBoard writer,
# the loss function and the training manager (`trainer`) used by the cells
# below.

# Select the optimizer named in the run parameters.
if params["optimizer"] == "adam":
    optimizer = optim.Adam(network.parameters(), lr=params["lr_start"])
elif params["optimizer"] == "adadelta":
    optimizer = optim.Adadelta(network.parameters(), lr=params["lr_start"])
else:
    # Raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down rather than reporting a readable error in the cell output.
    error_msg = f"{params['optimizer']} is not a valid optimizer."
    logger.error(error_msg)
    raise ValueError(error_msg)

# Only the "step" scheduler is supported; any other value disables scheduling.
if params["scheduler"] == "step":
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=params["scheduler_step_size"],
        gamma=params["scheduler_gamma"],
    )
else:
    scheduler = None

# Put the network in training mode.
network.train()

# All outputs of this run go below <output_relative_path>/<run_name>.
output_path = os.path.join(output_relative_path, params["run_name"])
logger.info(f"Output directory: {output_path}")

summary_writer = SummaryWriter(os.path.join(output_path, "summary"), purge_step=0)

# Directory of a previous run to load saved objects from, if one is configured.
if params["load_run_name"] is not None:
    load_path = os.path.join(output_relative_path, params["load_run_name"])
    logger.info(f"Model loaded from directory: {load_path}")
else:
    load_path = None

# initialize loss function
criterion = init_criterion(
    params=params, dataset=dataset, ignore_index=ignore_index, device=device
)

# directory where predicted class movies are saved
preds_output_dir = os.path.join(output_path, "predictions")
os.makedirs(preds_output_dir, exist_ok=True)

# generate dict of managed objects (the scheduler is included only if used)
managed_objects = {"network": network, "optimizer": optimizer}
if scheduler is not None:
    managed_objects["scheduler"] = scheduler

# The manager receives the training and testing routines as callables that
# ignore their single argument and close over the objects created above.
trainer = myTrainingManager(
    # training items
    training_step=lambda _: training_step(
        sampler=sampler,
        network=network,
        optimizer=optimizer,
        # scaler=GradScaler(),
        scheduler=scheduler,
        device=device,
        criterion=criterion,
        dataset_loader=dataset_loader,
        ignore_frames=params["ignore_frames_loss"],
    ),
    save_every=c.getint("training", "save_every", fallback=5000),
    load_path=load_path,
    save_path=output_path,
    managed_objects=unet.managed_objects(managed_objects),
    # testing items
    test_function=lambda _: test_function(
        network=network,
        device=device,
        criterion=criterion,
        testing_datasets=testing_datasets,
        ignore_frames=params["ignore_frames_loss"],
        training_name=params["run_name"],
        output_dir=preds_output_dir,
        batch_size=params["batch_size"],
        training_mode=True,
        debug=debug_mode,
    ),
    test_every=c.getint("training", "test_every", fallback=1000),
    # NOTE(review): `plot_every` reads the "test_every" option, so plotting
    # always follows the testing interval -- confirm this is intended and not
    # a copy-paste of the line above.
    plot_every=c.getint("training", "test_every", fallback=1000),
    summary_writer=summary_writer,
)

[15:43:14] [  INFO  ] [  __main__  ] < 23 > -- Output directory: runs\final_model_blur_frames_nll_loss_2
[15:43:14] [  INFO  ] [training_script_utils] <429 > -- Using class weights: 0.2512798607349396, 312.7631530761719, 0.0, 58.220359802246094


In [11]:
############################ init random seeds #############################

# Seed PyTorch, Python's `random` module and NumPy's global generator, in
# that order, with the same fixed value.
# NOTE(review): this cell runs after the cells that initialise the network
# weights and create the shuffling DataLoader, so those steps may not be
# covered by the seed -- confirm whether that is intended.
SEED = 0
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(SEED)

In [12]:
# for load_epoch in [10000,20000,30000,40000,50000,60000,70000,80000,90000,100000]:
# for load_epoch in [100000]:
#     trainer.load(load_epoch)
#     logger.info("Starting final validation")
#     trainer.run_validation(wandb_log=wandb_log)
# if wandb_log:
#     wandb.finish()

In [13]:
############################## load model ##############################
if params["load_epoch"] != 0:
    # trainer.load(10000)
    trainer.load(params["load_epoch"])

In [14]:
############################## start training ##############################

# Training runs only when the "training" option of the [general] config
# section is enabled; it is treated as disabled when the option is missing.
if c.getboolean("general", "training", fallback=False):
    # Optional sanity check before training (kept disabled):
    # logger.info("Validate network before training")
    # trainer.run_validation(wandb_log=wandb_log)
    logger.info("Starting training")
    # Train for the configured number of epochs; progress is printed every
    # "print_every" iterations (100 when the option is not set).
    trainer.train(
        params["train_epochs"],
        wandb_log=wandb_log,
        print_every=c.getint("training", "print_every", fallback=100),
    )

[15:43:16] [  INFO  ] [  __main__  ] < 6  > -- Starting training
[15:43:40] [  INFO  ] [training_inference_tools] <102 > -- Iteration 0...
[15:43:40] [  INFO  ] [training_inference_tools] <103 > -- 	Training loss: 1.41
[15:43:40] [  INFO  ] [training_inference_tools] <104 > -- 	Time elapsed: 25.67s


KeyboardInterrupt: 

In [15]:
############################## run final validation ##############################
# Runs the trainer's validation once, but only when the "testing" option of
# the [general] config section is enabled (disabled when the option is
# missing). The W&B flag is forwarded to the validation call.

if c.getboolean("general", "testing", fallback=False):
    logger.info("Starting final validation")
    trainer.run_validation(wandb_log=wandb_log)

[15:43:51] [  INFO  ] [  __main__  ] < 4  > -- Starting final validation
[15:43:51] [  INFO  ] [training_inference_tools] < 46 > -- Validating network at iteration 6...
[15:43:51] [ DEBUG  ] [training_inference_tools] <801 > -- Testing function: running sample 34 in UNet
[15:43:56] [  INFO  ] [training_inference_tools] <828 > -- DEBUG: max and min difference between preds[1] and preds[3]: 0.11100940406322479, 0.11004564166069031
[15:43:56] [ DEBUG  ] [in_out_tools] <281 > -- Writing videos on directory c:\Users\dotti\sparks_project\sparks\runs\final_model_blur_frames_nll_loss_2\predictions ..
(904, 64, 512)
runs\final_model_blur_frames_nll_loss_2\predictions\final_model_blur_frames_nll_loss_2_34_xs.tif
[15:43:57] [ DEBUG  ] [training_inference_tools] <870 > -- Time to run sample 34 in UNet: 5.55 s
[15:43:57] [ DEBUG  ] [training_inference_tools] <877 > -- Testing function: re-organising annotations
[15:43:59] [ DEBUG  ] [training_inference_tools] <906 > -- Time to re-organise annotatio

KeyboardInterrupt: 

In [None]:
# Close the Weights & Biases run, if W&B logging was enabled for this session.
if wandb_log:
    wandb.finish()

# Visualize UNet architecture

In [None]:
# # get number of trainable parameters
# num_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
# logger.debug(f"Number of trainable parameters: {num_params}")
# # get dummy unet input
# batch = next(iter(dataset_loader))
# x = batch[0].to(device)
# yhat = network(x[:,None]) # Give dummy batch to forward()
# from torchviz import make_dot
# make_dot(yhat, params=dict(list(network.named_parameters()))).render("unet_model", format="png")
# a = [0,1,2,3,4,5,6,7,8,9,10,11,12,13]

# len(a[0:4])