## Visualise Trained Models Results

**Author**: Prisca Dotti

**Last Edit**: 18.06.2024

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

# Only add the parent directory once: re-running this cell (e.g. while
# iterating with autoreload) must not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import logging
import os
import time

import torch
from torch import nn

# from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

from config import TrainingConfig, config
from data.datasets import PatchSparksDataset
from data.data_processing_tools import masks_to_instances_dict, process_raw_predictions
from utils.training_script_utils import (
    get_sample_ids,
    init_model,
)


# Notebook-level logger; its threshold is lowered to DEBUG so that the
# debug messages emitted by the cells below are not filtered out.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

In [3]:
####################### Get training-specific parameters #######################

# Name of the run whose trained model is visualised. When non-empty it
# overrides the run name stored in the loaded configuration (see below).
run_name = "sparks_patches"
# run_name = "TEMP_new_annotated_peaks_physio"  # TEMP local (run on laptop)

# Configuration file used to build the `TrainingConfig` object.
config_filename = os.path.join("config_files", "config_sparks_patches.ini")

# Epoch of the checkpoint to load; used to build `model_filename`.
load_epoch = 100000

# Passed to `get_sample_ids` as `train_data` when selecting the samples.
use_train_data = False

# Explicit list of sample IDs, passed to `get_sample_ids` as `custom_ids`
# (override sample_ids if needed). Uncomment entries to include more samples.
# custom_ids = []
custom_ids = [
    "05",
    # "10",
    # "15",
    # "20",
    # "25",
    # "32",
    # "34",
    # "40",
    # "45",
]

# set to False to only generate unet predictions
# set to True to also compute processed outputs and metrics
# NOTE(review): `testing` is not read by any cell visible in this notebook —
# confirm whether it is still needed.
testing = True

# Initialize general parameters
params = TrainingConfig(training_config_file=config_filename)

if run_name:
    params.run_name = run_name

# Checkpoint file name, with the epoch zero-padded to six digits.
model_filename = f"network_{load_epoch:06d}.pth"

# Print parameters to console if needed
# params.print_params()

# The comparison already yields a bool; no conditional expression is needed.
debug = config.verbosity == 3

[17:56:13] [  INFO  ] [   config   ] <318 > -- Loading C:\Users\prisc\Code\sparks_project\config_files\config_sparks_patches.ini


In [4]:
########################### Detect GPU, if available ###########################

# Ask the configuration object to select the device ("auto"), then have it
# report the selection (the recorded output of this cell shows "Using cuda").
# Swap in the commented line below to force CPU execution instead.
params.set_device(device="auto")
# params.set_device(device="cpu")  # temporary
params.display_device_info()

[17:56:14] [  INFO  ] [   config   ] <566 > -- Using cuda


In [5]:
########################### Configure output folder ############################

output_folder = "results_visualisation"  # Same folder for train and test preds

# Subdirectory of output_folder where predictions are saved.
# Change this to save results for same model with different inference
# approaches.
# output_name = training_name + "_step=2"
output_name = params.run_name

# Results live under the project base directory. `os.makedirs` creates every
# missing intermediate directory, so `output_folder` does not need to be
# created separately. Creating it from its bare relative name would resolve
# against the current working directory and leave a stray empty folder there
# whenever the notebook is not run from `config.basedir`.
save_folder = os.path.join(config.basedir, output_folder, output_name)
os.makedirs(save_folder, exist_ok=True)
logger.info(f"Annotations and predictions will be saved on '{save_folder}'")

[17:56:14] [  INFO  ] [  __main__  ] < 14 > -- Annotations and predictions will be saved on 'C:\Users\prisc\Code\sparks_project\results_visualisation\sparks_patches'


In [6]:
############################ Configure datasets ############################

logger.info(f"Processing training '{params.run_name}'...")

# Resolve the list of sample IDs from the train/test switch, the dataset
# size and the optional custom override defined in the parameters cell.
sample_ids = get_sample_ids(
    train_data=use_train_data, dataset_size=params.dataset_size, custom_ids=custom_ids
)
logger.info(f"Predicting outputs for samples {sample_ids}.")
logger.info(f"Using {params.dataset_dir} as dataset root path.")

# Patch-based dataset built from the selected samples.
dataset_kwargs = {
    "params": params,
    "base_path": params.dataset_dir,
    "sample_ids": sample_ids,
    # instances are needed to detect patches wrt spark peaks
    "load_instances": True,
    "inference": None,
}
dataset = PatchSparksDataset(**dataset_kwargs)

logger.info(f"Samples in dataset (patches): {len(dataset)}")

# Dataloader over the patches; shuffling is disabled, so batches follow the
# dataset order.
loader_kwargs = {
    "batch_size": params.inference_batch_size,
    "shuffle": False,
    "num_workers": params.num_workers,
    "pin_memory": params.pin_memory,
}
dataset_loader = DataLoader(dataset, **loader_kwargs)

[17:56:14] [  INFO  ] [  __main__  ] < 3  > -- Processing training 'sparks_patches'...
[17:56:14] [  INFO  ] [  __main__  ] < 11 > -- Predicting outputs for samples ['05'].
[17:56:14] [  INFO  ] [  __main__  ] < 12 > -- Using C:\Users\prisc\Code\sparks_project\data\sparks_dataset as dataset root path.
[17:56:26] [  INFO  ] [  __main__  ] < 23 > -- Samples in dataset (patches): 45


In [7]:
############################## Configure UNet ##############################

# Build the network from the training configuration
network = init_model(params=params)

# On any non-CPU device, wrap the model in DataParallel and move it there;
# on CPU the bare model is used as is.
running_on_cpu = params.device.type == "cpu"
if not running_on_cpu:
    network = nn.DataParallel(network)
    network = network.to(params.device, non_blocking=True)
    # cudnn.benchmark = True

**Note:** this can't be run on a laptop.

In [8]:
### Load UNet model ###

# Path to the saved model checkpoint
# (note: despite its name, `model_dir` points to the checkpoint *file*)
models_relative_path = os.path.join(
    "models", "saved_models", params.run_name, model_filename
)
model_dir = os.path.realpath(os.path.join(config.basedir, models_relative_path))

# Log the run name that is actually used to build the path above: the raw
# `run_name` variable may be empty, in which case the configured name is used.
logger.info(f"Loading trained model '{params.run_name}' at epoch {load_epoch}...")

# Read the checkpoint from disk once and reuse it if a retry is needed.
# NOTE: `torch.load` unpickles the file, so only load checkpoints from a
# trusted source.
state_dict = torch.load(model_dir, map_location=params.device)

try:
    network.load_state_dict(state_dict)
except RuntimeError as e:
    if "module" in str(e):
        # The error message contains "module," so handle the DataParallel loading
        logger.warning(
            "Failed to load the model, as it was trained with DataParallel. Wrapping it in DataParallel and retrying..."
        )
        # Remember the device the model currently lives on, so that it can be
        # placed back there once the wrapper is removed.
        temp_device = next(network.parameters()).device

        # Wrapping adds the "module." prefix expected by the checkpoint keys.
        network = nn.DataParallel(network)
        network.load_state_dict(state_dict)

        logger.info("Network should be on CPU, removing DataParallel wrapper...")
        network = network.module.to(temp_device)
    else:
        # Handle other exceptions or re-raise the exception if it's unrelated
        raise

In [20]:
########################### Run samples in UNet ############################

# Per-patch accumulators, filled in dataloader order
xs_list, ys_list, ys_instances_list = [], [], []
raw_preds_list = []

# Switch the network to evaluation mode and collect its raw predictions
network.eval()
start = time.time()
with torch.no_grad():
    for batch in dataset_loader:
        x, y, y_instances = batch["data"], batch["labels"], batch["instances"]

        # Insert a new axis at position 1 before feeding the batch to the
        # network; the output is exponentiated and brought back to the CPU.
        raw_pred = torch.exp(network(x.to(params.device).unsqueeze(1))).cpu()

        # Store every element of the batch as a separate numpy array
        for target_list, tensor in (
            (raw_preds_list, raw_pred),
            (xs_list, x),
            (ys_list, y),
            (ys_instances_list, y_instances),
        ):
            target_list.extend(tensor.numpy())

elapsed = time.time() - start
logger.debug(f"Time to run testing dataset in UNet: {elapsed:.2f} s")

[18:09:17] [ DEBUG  ] [  __main__  ] < 25 > -- Time to run testing dataset in UNet: 0.57 s


In [21]:
#################### Get processed output (if required) ####################

logger.debug("Getting processed output (segmentation and instances)")

preds_list, preds_instances_list = [], []
for i, (raw_pred_patch, input_patch) in enumerate(zip(raw_preds_list, xs_list)):
    logger.debug(f"Processing patch {i+1}/{len(dataset)}...")

    # Map each configured event type to the matching channel of the raw
    # prediction, keeping only the event types listed in the config
    raw_preds_dict = {}
    for event_type, event_label in config.classes_dict.items():
        if event_type in config.event_types:
            raw_preds_dict[event_type] = raw_pred_patch[event_label]

    # The third returned value is not needed here
    preds_instances, preds_segmentation, _ = process_raw_predictions(
        raw_preds_dict=raw_preds_dict, input_movie=input_patch, training_mode=False
    )

    # Only the "sparks" outputs are kept for visualisation
    preds_list.append(preds_segmentation["sparks"])
    preds_instances_list.append(preds_instances["sparks"])

[18:09:21] [ DEBUG  ] [  __main__  ] < 3  > -- Getting processed output (segmentation and instances)
[18:09:21] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 1/45...
[18:09:21] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 2/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 3/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 4/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 5/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 6/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 7/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 8/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 9/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 10/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 11/45...
[18:09:22] [ DEBUG  ] [  __main__  ] < 8  > -- Processing patch 12/45...
[18:09:22] [ DEBUG  ] [  __main_

### Visualise Preds with Napari

In [48]:
import napari
from utils.visualization_tools import (
    get_discrete_cmap,
    get_labels_cmap,
    get_annotations_contour,
)

# Discrete colormap requested from the "gray" colormap with lut=16.
# NOTE(review): `cmap` is not referenced by the cells visible in this
# notebook (the viewer below passes colormap="gray" directly) — confirm
# whether it is still needed.
cmap = get_discrete_cmap(name="gray", lut=16)

In [49]:
# Open a napari viewer; three layers are added to it for every patch
viewer = napari.Viewer()

# Display options shared by the two label overlays
overlay_style = {"blending": "additive", "opacity": 0.3}

for patch_idx in range(len(dataset)):
    # Contours of the annotations, used for visualisation
    y_contours = get_annotations_contour(
        annotations=ys_list[patch_idx], contour_val=2
    )

    # Layer 1: input patch
    viewer.add_image(
        xs_list[patch_idx], colormap="gray", name="Patches", blending="additive"
    )
    # Layer 2: annotation contours
    viewer.add_labels(
        y_contours, name="Ground Truth", color=get_labels_cmap(), **overlay_style
    )
    # Layer 3: processed predictions
    viewer.add_labels(
        preds_list[patch_idx],
        name="Predictions",
        color=get_labels_cmap(),
        **overlay_style,
    )

# Enable grid mode; the stride matches the number of layers added per patch
viewer.grid.enabled = True
viewer.grid.stride = 3
# Optionally, fix the grid layout explicitly:
# n_cols = 5
# n_rows = len(dataset) // n_cols
# viewer.grid.shape = (n_rows, n_cols)