## Re-compute metrics of a given training run at its saved checkpoints

**Author**: Prisca Dotti

**Last Edit**: 18.06.2024

In [2]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [3]:
import logging
import os
import numpy as np
import wandb

import torch
from torch import nn
from torch.utils.data import DataLoader

from config import TrainingConfig, config
from data.datasets import PatchSparksDataset
from utils.training_script_utils import (
    init_model,
    init_criterion,
)
from utils.training_inference_tools import test_function_patches

# Notebook-level logger; its threshold is DEBUG so no message is filtered here.
logger = logging.getLogger(name=__name__)
logger.setLevel(level="DEBUG")

Get training-specific parameters

In [4]:
run_name = "sparks_patches_64x64x64_no_ignore_frames"
config_filename = os.path.join("config_files", "patches", "config_sparks_64x64x64.ini")

# NOTE(review): `use_train_data` is not referenced anywhere else in this
# notebook; the dataset below is always built from `test_ids`.
use_train_data = False
# Sample IDs whose patches are used to re-compute the metrics.
test_ids = [
    "05",
    "10",
    "15",
    "20",
    "25",
    "32",
    "34",
    "40",
    "45",
]

params = TrainingConfig(training_config_file=config_filename)
params.run_name = run_name

# Number of epochs the run was trained for. Set the override to None to use the
# value from the configuration file (`params.train_epochs`) instead.
TRAINED_EPOCHS_OVERRIDE = 45000
trained_epochs = (
    params.train_epochs if TRAINED_EPOCHS_OVERRIDE is None else TRAINED_EPOCHS_OVERRIDE
)
# Interval (in epochs) between saved checkpoints; 1000 if [training] save_every
# is missing from the config file.
saved_every = params.c.getint("training", "save_every", fallback=1000)
assert saved_every > 0, f"save_every must be positive, got {saved_every}."

models_dir = os.path.realpath(
    os.path.join(
        config.basedir,
        "models",
        "saved_models",
        params.run_name,
    )
)

logger.info(f"Processing training '{params.run_name}'...")
logger.info(f"Predicting outputs for samples {test_ids}.")
logger.info(f"Using {params.dataset_dir} as dataset root path.")
logger.info(
    f"Checkpoints saved every {saved_every} epochs, trained for {trained_epochs} epochs."
)

params.set_device(device="auto")  # can also be set to "cpu" or "cuda"
params.display_device_info()

[16:45:17] [  INFO  ] [   config   ] <318 > -- Loading C:\Users\prisc\Code\sparks_project\config_files\patches\config_sparks_64x64x64.ini
[16:45:17] [ ERROR  ] [wandb.jupyter] <224 > -- Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mdottip[0m. Use [1m`wandb login --relogin`[0m to force relogin


[16:45:30] [  INFO  ] [  __main__  ] < 33 > -- Processing training 'sparks_patches_64x64x64_no_ignore_frames'...
[16:45:30] [  INFO  ] [  __main__  ] < 34 > -- Predicting outputs for samples ['05', '10', '15', '20', '25', '32', '34', '40', '45'].
[16:45:30] [  INFO  ] [  __main__  ] < 35 > -- Using C:\Users\prisc\Code\sparks_project\data\sparks_dataset as dataset root path.
[16:45:30] [  INFO  ] [   config   ] <566 > -- Using cuda


Configure datasets

In [5]:
# Initialize the evaluation dataset from the selected test samples.
dataset = PatchSparksDataset(
    params=params,
    base_path=params.dataset_dir,
    sample_ids=test_ids,
    load_instances=True,  # this is needed to detect patches wrt spark peaks
    inference=None,
)
logger.info(f"Samples in dataset (patches): {len(dataset)}")
# Fail early, before any checkpoint is evaluated, if no patches were extracted.
assert len(dataset) > 0, f"No patches extracted from samples {test_ids}."

# Create a dataloader with a deterministic (unshuffled) order.
# NOTE(review): `dataset_loader` is not used by the evaluation loop in this
# notebook, which passes `dataset` directly to `test_function_patches`.
dataset_loader = DataLoader(
    dataset,
    batch_size=params.inference_batch_size,
    shuffle=False,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

[16:46:32] [  INFO  ] [  __main__  ] < 9  > -- Samples in dataset (patches): 126


Configure UNet

In [6]:
# Build the UNet described by the training configuration.
network = init_model(params=params)

# On any non-CPU device (e.g. CUDA), wrap the model in nn.DataParallel and move
# it to that device; on CPU the model is kept exactly as returned by init_model.
runs_on_accelerator = params.device.type != "cpu"
if runs_on_accelerator:
    network = nn.DataParallel(network).to(params.device, non_blocking=True)

Load saved UNet checkpoints and compute metrics

In [7]:
# Scratch directory handed to the metric computation as its output location;
# the name marks it as disposable. Created only if it does not exist yet.
output_dir = "DELETEME"
os.makedirs(name=output_dir, exist_ok=True)

In [8]:
def load_network_checkpoint(
    network: nn.Module, checkpoint_path: str, device: torch.device
) -> None:
    """Load a saved state dict into `network`, wrapped in DataParallel or not.

    Checkpoints saved from an ``nn.DataParallel`` model prefix every key with
    ``"module."``. The prefix is stripped (if present) and the weights are loaded
    into the bare module, so any combination of wrapped/unwrapped checkpoint and
    wrapped/unwrapped `network` works. Loading is strict, so any other key or
    shape mismatch still raises ``RuntimeError``.

    Args:
        network: model to load into; may or may not be an ``nn.DataParallel``.
        checkpoint_path: path of the ``.pth`` file holding the state dict.
        device: passed to ``torch.load`` as ``map_location``.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    state_dict = torch.load(checkpoint_path, map_location=device)
    # In place; also keeps the state dict's `_metadata` consistent.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    target = network.module if isinstance(network, nn.DataParallel) else network
    target.load_state_dict(state_dict)


# Epochs 0, saved_every, 2*saved_every, ... strictly below `trained_epochs`
# (np.arange excludes the end value). All of them must be saved on disk.
load_epochs = np.arange(0, trained_epochs, saved_every)
# The loss criterion does not depend on the checkpoint: build it once.
criterion = init_criterion(params=params, dataset=dataset)

for load_epoch in load_epochs:
    load_epoch = int(load_epoch)  # plain int for the filename and the wandb step
    if load_epoch != 0:
        logger.info(f"Loading trained model '{run_name}' at epoch {load_epoch}...")
        checkpoint_path = os.path.join(models_dir, f"network_{load_epoch:06d}.pth")
        load_network_checkpoint(network, checkpoint_path, params.device)
    # NOTE(review): epoch 0 is not loaded from disk; it evaluates `network` as it
    # currently is. That is the fresh initialization only if the "Configure UNet"
    # cell ran right before this one — re-running this cell alone re-evaluates
    # the last loaded weights as "epoch 0".

    logger.info(f"Computing metrics for epoch {load_epoch}...")
    network.eval()
    res = test_function_patches(
        network=network,
        device=params.device,
        criterion=criterion,
        params=params,
        testing_dataset=dataset,
        training_name=params.run_name,
        output_dir=output_dir,
        training_mode=True,
    )

    for m, val in res.items():
        logger.info(f"{m}: {val}")

    if params.wandb_log:
        # Log all scalar metrics in a single call at this checkpoint's epoch;
        # separate committing calls at one explicit step would drop all but the first.
        scalar_metrics = {m: val for m, val in res.items() if "confusion_matrix" not in m}
        wandb.log(scalar_metrics, step=load_epoch, commit=True)

[16:46:33] [  INFO  ] [  __main__  ] < 28 > -- Computing metrics for epoch 0...


KeyboardInterrupt: 

In [None]:
# Mark the wandb run as finished and upload any remaining data.
# (`wandb.join` is a deprecated alias of `wandb.finish`.)
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
average/correctly_classified,▁
average/detected,▁
average/f1-score,▁
average/labeled,▁
average/precision,▁
average/recall,▁
segmentation/average_IoU,▁
segmentation/sparks_IoU,▁
sparks/correctly_classified,▁
sparks/detected,▁

0,1
average/correctly_classified,1.0
average/detected,0.32685
average/f1-score,0.35462
average/labeled,0.38756
average/precision,0.38756
average/recall,0.32685
segmentation/average_IoU,0.17023
segmentation/sparks_IoU,0.17023
sparks/correctly_classified,1.0
sparks/detected,0.32685


In [None]:
# Close the wandb run, but only if wandb logging is enabled in the config.
wandb_enabled = params.wandb_log
if wandb_enabled:
    wandb.finish()

NameError: name 'params' is not defined