## Train on Patches of Confocal tyx-Image Series

**Author**: Prisca Dotti

**Last Edit**: 18.06.2024

This notebook runs experiments with the U-Net model on patches obtained from either real or simulated confocal imaging data.

The training dataset can be one of the following:
- patches extracted in a meaningful way from confocal imaging recordings
- whole recordings, which are split into patches before being passed through the U-Net (to limit memory usage) and recombined into whole movies after inference
- simulated confocal imaging patches (similar to the approach used so far)
- a combination of real and simulated patches of confocal imaging data
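
The second option (splitting a recording into patches and recombining them after inference) can be sketched with NumPy as follows. This is only an illustration under stated assumptions, not the implementation used by `PatchSparksDataset`: the function names `split_into_patches` and `recombine_patches` are hypothetical, the tiles do not overlap, the movie is zero-padded to a multiple of the patch size, and the 64x64x64 patch shape is borrowed from the configuration file name.

```python
import numpy as np


def split_into_patches(movie, patch_shape=(64, 64, 64)):
    """Zero-pad a (t, y, x) movie and cut it into non-overlapping patches."""
    # Pad each axis at its end so that its length is a multiple of the patch size
    pad = [(0, (-dim) % size) for dim, size in zip(movie.shape, patch_shape)]
    padded = np.pad(movie, pad, mode="constant")
    pt, py, px = patch_shape
    patches, origins = [], []
    for t in range(0, padded.shape[0], pt):
        for y in range(0, padded.shape[1], py):
            for x in range(0, padded.shape[2], px):
                patches.append(padded[t : t + pt, y : y + py, x : x + px])
                origins.append((t, y, x))
    return patches, origins, padded.shape


def recombine_patches(patches, origins, padded_shape, original_shape):
    """Place patches back at their origins and crop away the padding."""
    out = np.zeros(padded_shape, dtype=patches[0].dtype)
    for patch, (t, y, x) in zip(patches, origins):
        pt, py, px = patch.shape
        out[t : t + pt, y : y + py, x : x + px] = patch
    t, y, x = original_shape
    return out[:t, :y, :x]


# Round trip on a small random movie (stand-in for a real recording)
movie = np.random.default_rng(0).random((100, 64, 130))
patches, origins, padded_shape = split_into_patches(movie)
restored = recombine_patches(patches, origins, padded_shape, movie.shape)
```

In a real pipeline the network output for each patch would take the place of the patch itself before recombination, and overlapping tiles are often used to reduce border artifacts.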

UPDATES:
- 18.06.2024: this notebook now supports only training and validation.

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [2]:
import logging
import os
import random

import numpy as np
import torch
import wandb
from torch import nn, optim

# from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from config import TrainingConfig, config
from data.datasets import PatchSparksDataset
from models.UNet import unet
from utils.training_inference_tools import (
    MyTrainingManager,
    TransformedSparkDataset,
    random_flip,
    random_flip_noise,
    sampler,
    test_function_patches,
    training_step,
    weights_init,
)
from utils.training_script_utils import (
    get_sample_ids,
    init_criterion,
    init_model,
)


logger = logging.getLogger(__name__)

In [3]:
logger.setLevel(logging.DEBUG)

In [4]:
##################### Get training-specific parameters #####################

# Initialize training-specific parameters
# (the configuration file path is set manually here instead of being
# parsed from the command line)
config_filename = os.path.join(
    "config_files", "config_sparks_patches_64x64x64_nll_loss.ini"
)
# config_filename = os.path.join("config_files", "config_final_model.ini")

params = TrainingConfig(training_config_file=config_filename)

# Print parameters to console if needed
params.print_params()

######################### Initialize random seeds ##########################

# Fix the random seeds so that the results are reproducible

torch.manual_seed(0)  # <--------------------------------------------------!
random.seed(0)  # <--------------------------------------------------------!
np.random.seed(0)  # <-----------------------------------------------------!

[02:16:38] [  INFO  ] [   config   ] <318 > -- Loading C:\Users\dotti\Code\sparks_project\config_files\config_sparks_patches_64x64x64_nll_loss.ini
[02:16:38] [ ERROR  ] [wandb.jupyter] <224 > -- Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mdottip[0m. Use [1m`wandb login --relogin`[0m to force relogin


[02:16:45] [  INFO  ] [   config   ] <570 > --     training_config_file: C:\Users\dotti\Code\sparks_project\config_files\config_sparks_patches_64x64x64_nll_loss.ini
[02:16:45] [  INFO  ] [   config   ] <570 > --              dataset_dir: C:\Users\dotti\Code\sparks_project\data\sparks_dataset
[02:16:45] [  INFO  ] [   config   ] <570 > --                        c: <configparser.ConfigParser object at 0x00000179D007F940>
[02:16:45] [  INFO  ] [   config   ] <570 > --                 run_name: sparks_patches_64x64x64_nll_loss
[02:16:45] [  INFO  ] [   config   ] <570 > --            load_run_name: 
[02:16:45] [  INFO  ] [   config   ] <570 > --               load_epoch: 0
[02:16:45] [  INFO  ] [   config   ] <570 > --             train_epochs: 100000
[02:16:45] [  INFO  ] [   config   ] <570 > --                criterion: nll_loss
[02:16:45] [  INFO  ] [   config   ] <570 > --                 lr_start: 0.0001
[02:16:45] [  INFO  ] [   config   ] <570 > --       ignore_frames_loss: 6
[02:1
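
The three seeding calls in the cell above can be grouped into one helper, sketched below. The name `seed_everything` is hypothetical and not part of the project code. The PyTorch import is optional so that the snippet also runs where PyTorch is not installed.

```python
import random

import numpy as np


def seed_everything(seed=0):
    """Seed Python's and NumPy's RNGs, plus PyTorch's if it is installed."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available: only the two RNGs above are seeded
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


seed_everything(0)
first_draw = (random.random(), np.random.rand())
seed_everything(0)
second_draw = (random.random(), np.random.rand())
```

After reseeding, `second_draw` equals `first_draw`. Note that seeding alone may not make GPU training bit-for-bit reproducible, since some CUDA operations are nondeterministic.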

In [5]:
# params.set_device("cpu")
params.display_device_info()

[02:16:46] [  INFO  ] [   config   ] <566 > -- Using cuda


In [6]:
############################ Configure datasets ############################

# Select samples for training and testing based on dataset size
train_sample_ids = get_sample_ids(
    train_data=True,
    dataset_size=params.dataset_size,
)
test_sample_ids = get_sample_ids(
    train_data=False,
    dataset_size=params.dataset_size,
)

# Initialize training dataset
dataset = PatchSparksDataset(
    params=params,
    base_path=params.dataset_dir,
    sample_ids=train_sample_ids,
    load_instances=True,  # this is needed to detect patches wrt spark peaks
    inference=None,
)

# Apply transforms based on noise_data_augmentation setting
# (transforms are applied when getting a sample from the dataset)
transforms = random_flip_noise if params.noise_data_augmentation else random_flip
dataset = TransformedSparkDataset(dataset, transforms)

logger.info(f"Samples in dataset (patches): {len(dataset)}")

# Initialize testing datasets
testing_dataset = PatchSparksDataset(
    params=params,
    base_path=params.dataset_dir,
    sample_ids=test_sample_ids,
    load_instances=True,  # this is needed to detect patches wrt spark peaks
    inference=None,
)

logger.info(f"Samples in test dataset (patches): {len(testing_dataset)}")

[02:20:29] [  INFO  ] [  __main__  ] < 27 > -- Samples in dataset (patches): 528
[02:21:07] [  INFO  ] [  __main__  ] < 38 > -- Samples in test dataset (patches): 127
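
The transforms are applied each time a sample is requested from the dataset. The sketch below shows what a random-flip augmentation on a (t, y, x) patch and its labels might look like. It is an assumption-laden illustration: the actual `random_flip` in `utils.training_inference_tools` is not shown in this notebook, the name `random_flip_pair` is hypothetical, and flipping only the two spatial axes (leaving time untouched) is a choice made here, not a documented property of the project code.

```python
import numpy as np


def random_flip_pair(data, labels, rng):
    """Flip a (t, y, x) patch and its labels together along y and/or x."""
    for axis in (1, 2):  # spatial axes only (assumption); time order is kept
        if rng.random() < 0.5:
            # Apply the same flip to both arrays so they stay aligned
            data = np.flip(data, axis=axis)
            labels = np.flip(labels, axis=axis)
    # np.flip returns views; copy to get contiguous arrays
    return data.copy(), labels.copy()


rng = np.random.default_rng(0)
patch = rng.random((4, 5, 6))
mask = (patch > 0.5).astype(np.uint8)
flipped_patch, flipped_mask = random_flip_pair(patch, mask, rng)
```

The key point is that the input and its annotation receive identical geometric transforms, so every labelled voxel still corresponds to the same signal after augmentation.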


In [7]:
# import napari

# viewer = napari.Viewer()

# train_sample_id = 9
# train_sample = dataset[train_sample_id]
# train_data = train_sample["data"].numpy()
# train_labels = train_sample["labels"].numpy()
# viewer.add_image(train_data)
# viewer.add_labels(train_labels)

# test_sample_id = 0
# test_sample = testing_dataset[test_sample_id]
# test_data = test_sample["data"].numpy()
# test_labels = test_sample["labels"].numpy()
# viewer.add_image(test_data)
# viewer.add_labels(test_labels)

In [8]:
# Initialize data loaders
dataset_loader = DataLoader(
    dataset,
    batch_size=params.batch_size,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

In [9]:
############################## Configure UNet ##############################

# Initialize the UNet model
network = init_model(params=params)

# Move the model to the GPU if available
if params.device.type != "cpu":
    network = nn.DataParallel(network).to(params.device, non_blocking=True)
    # cudnn.benchmark = True

# Watch the model with wandb for logging if enabled
if params.wandb_log:
    wandb.watch(network)

# Initialize UNet weights if required
if params.initialize_weights:
    logger.info("Initializing UNet weights...")
    network.apply(weights_init)

# The following line is commented out because it does not work on Windows
# (torch.compile returns the compiled module, so the result must be assigned)
# network = torch.compile(network, mode="default", backend="inductor")

In [10]:
########################### Initialize training ############################

# Initialize the optimizer based on the specified type
if params.optimizer == "adam":
    optimizer = optim.Adam(network.parameters(), lr=params.lr_start)
elif params.optimizer == "adadelta":
    optimizer = optim.Adadelta(network.parameters(), lr=params.lr_start)
elif params.optimizer == "sgd":
    optimizer = optim.SGD(network.parameters(), lr=params.lr_start)
else:
    logger.error(f"{params.optimizer} is not a valid optimizer.")
    raise ValueError(f"{params.optimizer} is not a valid optimizer.")

# Initialize the learning rate scheduler if specified
if params.scheduler == "step":
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=params.scheduler_step_size,
        gamma=params.scheduler_gamma,
    )
else:
    scheduler = None

# Define the output directory path
output_path = os.path.join(config.output_dir, params.run_name)
logger.info(f"Output directory: {os.path.realpath(output_path)}")

# Initialize the summary writer for TensorBoard logging
summary_writer = SummaryWriter(os.path.join(output_path, "summary"), purge_step=0)

# Check if a pre-trained model should be loaded
if params.load_run_name != "":
    load_path = os.path.join(config.output_dir, params.load_run_name)
    logger.info(f"Model loaded from directory: {os.path.realpath(load_path)}")
else:
    load_path = None

# Initialize the loss function
criterion = init_criterion(params=params, dataset=dataset)

# Create a directory to save predicted class movies
preds_output_dir = os.path.join(output_path, "predictions")
os.makedirs(preds_output_dir, exist_ok=True)

# Create a dictionary of managed objects
managed_objects = {"network": network, "optimizer": optimizer}
if scheduler is not None:
    managed_objects["scheduler"] = scheduler

# Create a training manager with the specified training and testing functions
trainer = MyTrainingManager(
    # Training parameters
    training_step=lambda _: training_step(
        dataset_loader=dataset_loader,
        params=params,
        sampler=sampler,
        network=network,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        # scaler=GradScaler(),
    ),
    save_every=params.c.getint("training", "save_every", fallback=5000),
    load_path=load_path,
    save_path=output_path,
    managed_objects=unet.managed_objects(managed_objects),
    # Testing parameters
    test_function=lambda _: test_function_patches(
        network=network,
        device=params.device,
        criterion=criterion,
        params=params,
        testing_dataset=testing_dataset,
        training_name=params.run_name,
        output_dir=preds_output_dir,
        training_mode=True,
    ),
    test_every=params.c.getint("training", "test_every", fallback=1000),
    plot_every=params.c.getint("training", "test_every", fallback=1000),
    summary_writer=summary_writer,
)

# Load the model if a specific epoch is provided
if params.load_epoch != 0:
    trainer.load(params.load_epoch)

[02:21:08] [  INFO  ] [  __main__  ] < 26 > -- Output directory: C:\Users\dotti\Code\sparks_project\models\saved_models\sparks_patches_64x64x64_nll_loss
[02:21:08] [  INFO  ] [utils.training_script_utils] <361 > -- Using class weights: 0.5046204328536987, 54.60780715942383
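
The two class weights logged above compensate for the strong imbalance between background and spark voxels. A common way to obtain such weights is inverse-frequency ("balanced") weighting, sketched below. Whether `init_criterion` uses exactly this formula is an assumption, since its implementation is not shown here; the function name `balanced_class_weights` and the voxel counts in the example are hypothetical.

```python
def balanced_class_weights(class_counts):
    """Return the weight N / (K * n_c) for each class c.

    N is the total number of samples, K the number of classes and
    n_c the number of samples in class c.
    """
    total = sum(class_counts)
    n_classes = len(class_counts)
    return [total / (n_classes * count) for count in class_counts]


# Hypothetical counts: 990 background voxels and 10 spark voxels
weights = balanced_class_weights([990, 10])
```

With these counts the rare class receives a weight of 50.0 and the frequent class a weight close to 0.505. The logged values (about 0.5046 and 54.61) are consistent with this formula, since 1 / (2 * 0.5046) + 1 / (2 * 54.61) is approximately 1; under that reading, roughly 0.9% of the counted voxels would belong to the spark class.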


In [11]:
# Validate the network before training if resuming from a checkpoint
# if params.load_epoch > 0:
#     logger.info("Validate network before training")
#     trainer.run_validation(wandb_log=params.wandb_log)

In [12]:
# Set the network in training mode
network.train()

# Train the model for the specified number of epochs
logger.info("Starting training")
trainer.train(
    params.train_epochs,
    print_every=params.c.getint("training", "print_every", fallback=100),
    wandb_log=params.wandb_log,
)

[02:21:08] [  INFO  ] [  __main__  ] < 5  > -- Starting training
[02:21:10] [  INFO  ] [utils.training_inference_tools] <1626> -- Iteration 0...
[02:21:10] [  INFO  ] [utils.training_inference_tools] <1627> -- 	Training loss: 0.6692
[02:21:10] [  INFO  ] [utils.training_inference_tools] <1628> -- 	Time elapsed: 3.38s
[02:21:54] [  INFO  ] [utils.training_inference_tools] <1626> -- Iteration 100...
[02:21:54] [  INFO  ] [utils.training_inference_tools] <1627> -- 	Training loss: 0.5769
[02:21:54] [  INFO  ] [utils.training_inference_tools] <1628> -- 	Time elapsed: 212.36s
[02:22:39] [  INFO  ] [utils.training_inference_tools] <1626> -- Iteration 200...
[02:22:39] [  INFO  ] [utils.training_inference_tools] <1627> -- 	Training loss: 0.4238
[02:22:39] [  INFO  ] [utils.training_inference_tools] <1628> -- 	Time elapsed: 422.88s
[02:23:25] [  INFO  ] [utils.training_inference_tools] <1626> -- Iteration 300...
[02:23:25] [  INFO  ] [utils.training_inference_tools] <1627> -- 	Training loss: 0.
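
The log above suggests that `train_epochs` counts iterations and that the training manager triggers printing, validation and checkpointing at fixed intervals (`print_every`, `test_every`, `save_every`). The sketch below shows one way such a schedule could be organized. It is a hypothetical stand-in, not the code of `MyTrainingManager`: the function `run_training` and the exact iterations at which each hook fires are assumptions.

```python
def run_training(n_iterations, training_step, test_function,
                 print_every=100, test_every=1000, save_every=5000):
    """Run training_step once per iteration and fire periodic hooks."""
    events = []
    for iteration in range(n_iterations):
        loss = training_step(iteration)
        # Report the loss at iterations 0, print_every, 2 * print_every, ...
        if iteration % print_every == 0:
            events.append(("print", iteration, loss))
        completed = iteration + 1
        # Validate and save after a fixed number of completed iterations
        if completed % test_every == 0:
            events.append(("test", completed, test_function(completed)))
        if completed % save_every == 0:
            events.append(("save", completed, None))
    return events


# Dummy step and test functions stand in for the real ones
events = run_training(
    10,
    training_step=lambda i: 1.0 / (i + 1),
    test_function=lambda i: {"validation_loss": 0.0},
    print_every=5,
    test_every=4,
    save_every=10,
)
```

Here the events are only recorded in a list; a real manager would log the loss, run the validation function and write checkpoints at those points.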

In [13]:
logger.info("Starting final validation")
# Run the final validation/testing procedure
trainer.run_validation(wandb_log=params.wandb_log)

[15:59:10] [  INFO  ] [  __main__  ] < 1  > -- Starting final validation
[15:59:10] [  INFO  ] [utils.training_inference_tools] <1495> -- Validating network at iteration 100000...
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1598> -- Metrics:
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	validation_loss: 48.49
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/precision: 0.1059
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/recall: 0.3205
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/correctly_classified: 1
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/detected: 0.3205
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/f1-score: 0.1592
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	sparks/labeled: 0.1059
[15:59:42] [  INFO  ] [utils.training_inference_tools] <1604> -- 	total/precision: 0.1059
[15:59:42] [  INFO  ] 
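
The F1-score reported above is the harmonic mean of precision and recall. The short sketch below recomputes it from the logged precision and recall; the helper names are hypothetical, and the event counts passed to `precision_recall` are made up for illustration.

```python
def precision_recall(true_positives, false_positives, false_negatives):
    """Compute precision and recall from event counts."""
    predicted = true_positives + false_positives
    annotated = true_positives + false_negatives
    precision = true_positives / predicted if predicted else 0.0
    recall = true_positives / annotated if annotated else 0.0
    return precision, recall


def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Values taken from the run summary of this notebook
f1 = f1_score(0.10588, 0.32046)

# Made-up counts: 3 matched events, 1 spurious prediction, 2 missed events
example_precision, example_recall = precision_recall(3, 1, 2)
```

Plugging in the logged precision (0.10588) and recall (0.32046) gives approximately 0.1592, which matches the reported `sparks/f1-score`. In this log, `sparks/labeled` has the same value as the precision and `sparks/detected` the same value as the recall.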

In [14]:
# Close the summary writer
summary_writer.close()

# Close the wandb run
if params.wandb_log:
    wandb.finish()

Run history:
U-Net training loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
average/correctly_classified,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
average/detected,██▄▃▃▃▃▂▃▂▂▂▂▂▃▅▂▂▂▂▁▂▂▁▁▂▁▁▂▁▁▁▁▂▁▁▂▁▂▂
average/f1-score,█▅▂▁▁▃▂▂▃▂▂▂▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▃▂▂▁▂▂▂▂▁▂▂▂▂
average/labeled,█▅▁▁▁▃▃▃▃▂▃▂▄▃▄▃▄▃▃▂▂▃▃▂▂▃▂▃▃▂▂▂▂▂▂▁▂▂▂▂
average/precision,█▅▁▁▁▃▃▃▃▂▃▂▄▃▄▃▄▃▃▂▂▃▃▂▂▃▂▃▃▂▂▂▂▂▂▁▂▂▂▂
average/recall,██▄▃▃▃▃▂▃▂▂▂▂▂▃▅▂▂▂▂▁▂▂▁▁▂▁▁▂▁▁▁▁▂▁▁▂▁▂▂
segmentation/average_IoU,▇█▇▆▄▃▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁▂▂▂▂▃
segmentation/sparks_IoU,▇█▇▆▄▃▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁▂▂▂▂▃
sparks/correctly_classified,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
U-Net training loss,0.0059
average/correctly_classified,1.0
average/detected,0.32046
average/f1-score,0.15917
average/labeled,0.10588
average/precision,0.10588
average/recall,0.32046
segmentation/average_IoU,0.02579
segmentation/sparks_IoU,0.02579
sparks/correctly_classified,1.0
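
The `segmentation/sparks_IoU` value in the summary is an intersection-over-union score between predicted and annotated spark masks. A minimal NumPy sketch of the standard binary IoU is given below. How the project handles ignored frames or empty masks is not shown in this notebook, so the function name `binary_iou` and the convention for two empty masks are assumptions.

```python
import numpy as np


def binary_iou(pred_mask, true_mask):
    """Intersection over union of two binary masks of equal shape."""
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        # Both masks are empty: treated as perfect agreement (assumed convention)
        return 1.0
    intersection = np.logical_and(pred, true).sum()
    return float(intersection / union)


# One overlapping pixel out of three covered pixels
iou = binary_iou([[1, 1], [0, 0]], [[1, 0], [1, 0]])
```

A low IoU such as the 0.02579 reported here means that predicted and annotated spark regions overlap very little at the voxel level, even when some events are detected.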
