In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import configparser
import logging
import os
import sys
import time
import math

import numpy as np
import torch
from architectures import TempRedUNet
from datasets import SparkDataset
from in_out_tools import write_videos_on_disk
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from training_inference_tools import get_preds
from data_processing_tools import get_event_instances_class,get_processed_result
from metrics_tools import get_score_matrix, get_matches_summary, get_metrics_from_summary
import unet

In [3]:
################################ Set parameters ################################

training_name = "dice_loss"
config_file = "config_dice_loss.ini"
use_train_data = True
get_final_pred = False # set to False to only compute raw predictions
testing = False  # set to False to only generate unet raw predictions
                # set to True to also compute processed outputs and metrics

if testing:
    get_final_pred = True

In [4]:
if get_final_pred:
    # set physiological and validation parameters

    # physiological params (for spark peaks results)
    pixel_size = 0.2  # 1 pixel = 0.2 um x 0.2 um
    # min distance in space
    min_dist_xy = round(1.8 / pixel_size) # = 9 pixels
    time_frame = 6.8  # 1 frame = 6.8 ms
    # min distance in time
    min_dist_t = round(20 / time_frame)  # = 3 frames

    # spark instances detection parameters
    radius = math.ceil(min_dist_xy / 2)
    y, x = np.ogrid[-radius: radius + 1, -radius: radius + 1]
    disk = x**2 + y**2 <= radius**2
    conn_mask = np.stack([disk] * (min_dist_t), axis=0)

    debug = True
    ca_release_events = ['sparks', 'puffs', 'waves']

    # TODO: use better parameters !!!
    pixel_size = 0.2
    spark_min_width = 3
    spark_min_t = 3
    puff_min_t = 5
    wave_min_width = round(15 / pixel_size)

    # connectivity for event instances detection
    connectivity = 26

    # maximal gap between two predicted puffs or waves that belong together
    max_gap = 2  # i.e., 2 empty frames

    sigma = 3

    # parameters for correspondence computation
    # threshold for considering annotated and pred ROIs a match
    iomin_t = 0.5

In [5]:
########################### Configure output folder ############################

output_folder = "trainings_validation"  # same folder for train and test preds
os.makedirs(output_folder, exist_ok=True)

# subdirectory of output_folder where predictions are saved
# change this to save results for same model with different inference approaches
#output_name = training_name + "_step=2"
output_name = training_name

save_folder = os.path.join(output_folder, output_name)
os.makedirs(save_folder, exist_ok=True)

In [6]:
############################### Configure logger ###############################

# set verbosity
logger = logging.getLogger(__name__)

log_level = logging.DEBUG
log_handlers = (logging.StreamHandler(sys.stdout),)

logging.basicConfig(
    level=log_level,
    format="[{asctime}] [{levelname:^8s}] [{name:^12s}] <{lineno:^4d}> -- {message:s}",
    style="{",
    datefmt="%H:%M:%S",
    handlers=log_handlers,
)

In [7]:
########################### Detect GPU, if available ###########################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()
logger.info(f"Using device '{device}' with {n_gpus} GPUs")


[10:01:41] [  INFO  ] [  __main__  ] < 5  > -- Using device 'cuda' with 1 GPUs


In [8]:
############################### Load config file ###############################

config_folder = "config_files"
CONFIG_FILE = os.path.join(config_folder, config_file)
c = configparser.ConfigParser()
if os.path.isfile(CONFIG_FILE):
    logger.info(f"Loading {CONFIG_FILE}")
    c.read(CONFIG_FILE)
else:
    logger.info(
        f"No config file found at {CONFIG_FILE}, trying to use fallback values."
    )


[10:01:41] [  INFO  ] [  __main__  ] < 7  > -- Loading config_files\config_dice_loss.ini


In [9]:
######################## Config dataset and UNet model #########################

logger.info(f"Processing training '{training_name}'...")

### Params ###
load_epoch = c.getint("testing", "load_epoch")

batch_size = c.getint("testing", "batch_size", fallback="1")
ignore_frames = c.getint("training", "ignore_frames_loss")

temporal_reduction = c.getboolean(
    "network", "temporal_reduction", fallback=False)
num_channels = (
    c.getint("network", "num_channels",
             fallback=1) if temporal_reduction else 1
)


[10:01:41] [  INFO  ] [  __main__  ] < 3  > -- Processing training 'dice_loss'...


In [10]:
### Configure dataset/inference method ###
dataset_size = c.get("testing", "dataset_size")
data_step = c.getint("testing", "data_step")
data_duration = c.getint("testing", "data_duration")
inference = c.get("testing", "inference")

if use_train_data:
    logger.info("Predict outputs for training data")
    if dataset_size == "full":
        sample_ids = [
            "01",
            "02",
            "03",
            "04",
            "06",
            "07",
            "08",
            "09",
            "11",
            "12",
            "13",
            "14",
            "16",
            "17",
            "18",
            "19",
            "21",
            "22",
            "23",
            "24",
            "27",
            "28",
            "29",
            "30",
            "33",
            "35",
            "36",
            "38",
            "39",
            "41",
            "42",
            "43",
            "44",
            "46",
        ]
    elif dataset_size == "minimal":
        sample_ids = ["01"]
else:
    logger.info("Predict outputs for testing data")
    if dataset_size == "full":
        sample_ids = ["05", "10", "15", "20", "25", "32", "34", "40", "45"]
    elif dataset_size == "minimal":
        sample_ids = ["34"]

relative_path = c.get("dataset", "relative_path")
dataset_path = os.path.realpath(f"{relative_path}")
assert os.path.isdir(dataset_path), f'"{dataset_path}" is not a directory'
logger.info(f"Using {dataset_path} as dataset root path")
logger.info(f"Annotations and predictions will be saved on '{save_folder}'")


[10:01:41] [  INFO  ] [  __main__  ] < 8  > -- Predict outputs for training data
[10:01:41] [  INFO  ] [  __main__  ] < 58 > -- Using C:\Users\dotti\sparks_project\data\sparks_dataset as dataset root path
[10:01:41] [  INFO  ] [  __main__  ] < 59 > -- Annotations and predictions will be saved on 'trainings_validation\dice_loss'


In [11]:
### Configure UNet ###

batch_norm = {"batch": True, "none": False}

unet_config = unet.UNetConfig(
    steps=c.getint("network", "unet_steps"),
    first_layer_channels=c.getint("network", "first_layer_channels"),
    num_classes=4,
    ndims=3,
    dilation=c.getint("network", "dilation", fallback=1),
    border_mode=c.get("network", "border_mode"),
    batch_normalization=batch_norm[c.get("network", "batch_normalization")],
    num_input_channels=num_channels,
)
if not temporal_reduction:
    network = unet.UNetClassifier(unet_config)
else:
    assert (
        c.getint("dataset", "data_duration") % num_channels == 0
    ), "using temporal reduction chunks_duration must be a multiple of num_channels"
    network = TempRedUNet(unet_config)

network = nn.DataParallel(network).to(device)

### Load UNet model ###
models_relative_path = "runs/"
model_path = os.path.join(models_relative_path, training_name)
# logger.info(f"Saved model path: {model_path}")
summary_writer = SummaryWriter(
    os.path.join(model_path, "summary"), purge_step=0)

trainer = unet.TrainingManager(
    # training items
    training_step=None,
    save_path=model_path,
    managed_objects=unet.managed_objects({"network": network}),
    summary_writer=summary_writer,
)

logger.info(
    f"Loading trained model '{training_name}' at epoch {load_epoch}...")
trainer.load(load_epoch)
# logger.info(f"Loaded trained model located in '{model_path}'")


[10:01:42] [  INFO  ] [  __main__  ] < 40 > -- Loading trained model 'dice_loss' at epoch 100000...
[10:01:42] [  INFO  ] [unet.trainer] <131 > -- Loading 'runs/dice_loss\network_100000.pth'...


In [12]:
############################# Run samples in UNet ##############################

xs = {}
ys = {}
ys_instances = {}
preds = {}
if get_final_pred:
    preds_instances = {}
    preds_segmentation = {}

for sample_id in sample_ids:
    logger.debug(f"Processing sample {sample_id}...")
    start = time.time()
    ### Create dataset ###
    testing_dataset = SparkDataset(
        base_path=dataset_path,
        sample_ids=[sample_id],
        testing=testing,
        smoothing=c.get("dataset", "data_smoothing"),
        step=data_step,
        duration=data_duration,
        remove_background=c.get("dataset", "remove_background"),
        temporal_reduction=c.getboolean(
            "network", "temporal_reduction", fallback=False
        ),
        num_channels=num_channels,
        normalize_video=c.get("dataset", "norm_video"),
        only_sparks=c.getboolean("dataset", "only_sparks", fallback=False),
        sparks_type=c.get("dataset", "sparks_type"),
        ignore_frames=c.get("training", "ignore_frames_loss"),
        ignore_index=4,
        gt_available=True,
        inference=inference,
    )

    logger.info(
        f"\tTesting dataset of movie {testing_dataset.video_name} "
        f"\tcontains {len(testing_dataset)} samples."
    )

    logger.info(f"\tProcessing samples in UNet...")
    # ys and preds are numpy arrays
    xs[sample_id], ys[sample_id], preds[sample_id] = get_preds(
        network=network, test_dataset=testing_dataset, compute_loss=False, device=device
    )

    # preds are in logarithmic scale, compute exp
    preds[sample_id] = np.exp(preds[sample_id])

    if testing:
        # get labelled event instances, for validation
        # ys_instances is a dict with classified event instances, for each class
        ys_instances[sample_id] = get_event_instances_class(
            event_instances=testing_dataset.events, class_labels=ys[sample_id], shift_ids=True
        )
        # remove ignored events entry from ys_instances
        ys_instances[sample_id].pop("ignore", None)

    if get_final_pred:
        ######################### get processed output #########################

        logger.debug(
            "Testing function: getting processed output (segmentation and instances)")

        # get predicted segmentation and event instances
        preds_instances[sample_id], preds_segmentation[sample_id], _ = get_processed_result(
            sparks=preds[sample_id][1],
            puffs=preds[sample_id][3],
            waves=preds[sample_id][2],
            xs=xs[sample_id],
            conn_mask=conn_mask,
            connectivity=connectivity,
            max_gap=max_gap,
            sigma=sigma,
            wave_min_width=wave_min_width,
            puff_min_t=puff_min_t,
            spark_min_t=spark_min_t,
            spark_min_width=spark_min_width,
            training_mode=False,
            debug=debug
        )
        
    if not get_final_pred:
        logger.info(f"\tTime to process sample {sample_id} in UNet: {time.time() - start:.2f} seconds.")
    else:
        logger.info(f"\tTime to process sample {sample_id} in UNet + post-processing: {time.time() - start:.2f} seconds.")

    ### Save preds on disk ###
    logger.info(f"\tSaving annotations and predictions...")

    video_name = f"{str(load_epoch)}_{testing_dataset.video_name}"

    write_videos_on_disk(
        training_name=output_name,
        video_name=video_name,
        path=save_folder,
        preds=preds[sample_id],
        ys=ys[sample_id],
    )

logger.info(f"DONE")


[10:01:46] [ DEBUG  ] [  __main__  ] < 12 > -- Processing sample 01...
[10:01:46] [ DEBUG  ] [  datasets  ] <292 > -- Added padding of 12 frames to video with unsuitable duration
[10:01:46] [  INFO  ] [  __main__  ] < 36 > -- 	Testing dataset of movie 01 	contains 9 samples.
[10:01:46] [  INFO  ] [  __main__  ] < 41 > -- 	Processing samples in UNet...
[10:01:55] [  INFO  ] [  __main__  ] < 84 > -- 	Time to process sample 01 in UNet: 8.80 seconds.
[10:01:55] [  INFO  ] [  __main__  ] < 89 > -- 	Saving annotations and predictions...
[10:01:55] [ DEBUG  ] [in_out_tools] <278 > -- Writing videos on directory c:\Users\dotti\sparks_project\sparks\trainings_validation\dice_loss ..
[10:01:55] [  INFO  ] [  __main__  ] <101 > -- DONE


### Compute metrics

In [14]:
if not testing:
    print("!!!!!!!!! THE FOLLOWING CODE WON'T WORK !!!!!!!!!")

In [15]:
# initialize dicts that will contain the results
matched_preds_ids = {}
matched_ys_ids = {}
# sum events over all samples
preds_cat = ['tot', 'tp', 'ignored', 'unlabeled'] + ca_release_events
ys_cat = ['tot', 'tp', 'undetected'] + ca_release_events
matched_preds_ids['sum'] = {}
matched_ys_ids['sum'] = {}
for ca_event in ca_release_events:
    matched_preds_ids['sum'][ca_event] = {cat: 0 for cat in preds_cat if cat != ca_event}
    matched_ys_ids['sum'][ca_event] = {cat: 0 for cat in ys_cat if cat != ca_event}

for sample_id in sample_ids:
    logger.debug(f"Processing sample {sample_id}...")
    
    pred_instances = preds_instances[sample_id]
    y_instances = ys_instances[sample_id]
    ignore_mask = np.where(ys[sample_id] == 4, 1, 0)

############### compute pairwise scores (based on IoMin) ###############

    start = time.time()

    if debug:
        n_ys_events = max(
            [np.max(y_instances[event_type])
                for event_type in ca_release_events]
        )

        n_preds_events = max(
            [np.max(pred_instances[event_type])
                for event_type in ca_release_events]
        )
        logger.debug(
            f"Testing function: computing pairwise scores between {n_ys_events} annotated events and {n_preds_events} predicted events")

    iomin_scores = get_score_matrix(
        ys_instances=y_instances,
        preds_instances=pred_instances,
        ignore_mask=None,
        score="iomin",
    )

    logger.debug(
        f"Time to compute pairwise scores: {time.time() - start:.2f} s")

    ####################### get matches summary #######################

    start = time.time()

    logger.debug("Testing function: getting matches summary")

    matched_ys_ids[sample_id], matched_preds_ids[sample_id] = get_matches_summary(
        ys_instances=y_instances,
        preds_instances=pred_instances,
        scores=iomin_scores,
        t=iomin_t,
        ignore_mask=ignore_mask,
    )

    # count number of categorized events that are necessary for the metrics
    for ca_event in ca_release_events:
        for cat in matched_ys_ids[sample_id][ca_event].keys():
            matched_ys_ids['sum'][ca_event][cat] += len(matched_ys_ids[sample_id][ca_event][cat])
        
        for cat in matched_preds_ids[sample_id][ca_event].keys():
            matched_preds_ids['sum'][ca_event][cat] += len(matched_preds_ids[sample_id][ca_event][cat])

    logger.debug(
        f"Time to get matches summary: {time.time() - start:.2f} s")


[17:08:31] [ DEBUG  ] [  __main__  ] < 14 > -- Processing sample 05...
[17:08:31] [ DEBUG  ] [  __main__  ] < 34 > -- Testing function: computing pairwise scores between 97 annotated events and 156 predicted events
[17:08:49] [ DEBUG  ] [  __main__  ] < 44 > -- Time to compute pairwise scores: 17.90 s
[17:08:49] [ DEBUG  ] [  __main__  ] < 51 > -- Testing function: getting matches summary
[17:08:51] [ DEBUG  ] [  __main__  ] < 69 > -- Time to get matches summary: 2.43 s
[17:08:51] [ DEBUG  ] [  __main__  ] < 14 > -- Processing sample 10...
[17:08:51] [ DEBUG  ] [  __main__  ] < 34 > -- Testing function: computing pairwise scores between 29 annotated events and 51 predicted events
[17:08:56] [ DEBUG  ] [  __main__  ] < 44 > -- Time to compute pairwise scores: 5.00 s
[17:08:56] [ DEBUG  ] [  __main__  ] < 51 > -- Testing function: getting matches summary
[17:08:58] [ DEBUG  ] [  __main__  ] < 69 > -- Time to get matches summary: 1.37 s
[17:08:58] [ DEBUG  ] [  __main__  ] < 14 > -- Proce

In [16]:
##################### compute instances-based metrics ######################
# get dict that only contains sum over all samples
tot_preds = {ca_class: matched_preds_ids['sum'][ca_class]['tot']
             for ca_class in ca_release_events}
tp_preds = {ca_class: matched_preds_ids['sum'][ca_class]['tp']
            for ca_class in ca_release_events}
ignored_preds = {ca_class: matched_preds_ids['sum'][ca_class]['ignored']
                 for ca_class in ca_release_events}
unlabeled_preds = {ca_class: matched_preds_ids['sum'][ca_class]['unlabeled']
                   for ca_class in ca_release_events}
tot_ys = {ca_class: matched_ys_ids['sum'][ca_class]['tot']
          for ca_class in ca_release_events}
tp_ys = {ca_class: matched_ys_ids['sum'][ca_class]['tp']
         for ca_class in ca_release_events}
undetected_ys = {ca_class: matched_ys_ids['sum'][ca_class]['undetected']
                 for ca_class in ca_release_events}

# get other metrics (precision, recall, % correctly classified, % detected)
metrics_all = get_metrics_from_summary(tot_preds=tot_preds,
                                        tp_preds=tp_preds,
                                        ignored_preds=ignored_preds,
                                        unlabeled_preds=unlabeled_preds,
                                        tot_ys=tot_ys,
                                        tp_ys=tp_ys,
                                        undetected_ys=undetected_ys)

In [17]:
import pprint

In [18]:
print("Metrics for all samples:")
pprint.pprint(metrics_all)

Metrics for all samples:
{'average/correctly_classified': 0.7509044477548414,
 'average/detected': 0.7510720963551153,
 'average/precision': 0.5889053195315935,
 'average/recall': 0.5663558437143342,
 'puffs/correctly_classified': 0.6621621621621622,
 'puffs/detected': 0.7432432432432432,
 'puffs/precision': 0.37404580152671757,
 'puffs/recall': 0.5540540540540541,
 'sparks/correctly_classified': 0.5905511811023622,
 'sparks/detected': 0.6528301886792454,
 'sparks/precision': 0.39267015706806285,
 'sparks/recall': 0.5735849056603773,
 'waves/correctly_classified': 1.0,
 'waves/detected': 0.8571428571428572,
 'waves/precision': 1.0,
 'waves/recall': 0.5714285714285714}


In [19]:
print("Summary of annotated events:")
pprint.pprint(matched_ys_ids['sum'])

Summary of annotated events:
{'puffs': {'sparks': 39, 'tot': 74, 'tp': 41, 'undetected': 19, 'waves': 0},
 'sparks': {'puffs': 24, 'tot': 265, 'tp': 152, 'undetected': 92, 'waves': 0},
 'waves': {'puffs': 5, 'sparks': 6, 'tot': 7, 'tp': 4, 'undetected': 1}}


In [20]:
print("Summary of predicted events:")
pprint.pprint(matched_preds_ids['sum'])

Summary of predicted events:
{'puffs': {'ignored': 7,
           'sparks': 18,
           'tot': 138,
           'tp': 49,
           'unlabeled': 57,
           'waves': 12},
 'sparks': {'ignored': 16,
            'puffs': 69,
            'tot': 398,
            'tp': 150,
            'unlabeled': 128,
            'waves': 36},
 'waves': {'ignored': 0,
           'puffs': 0,
           'sparks': 0,
           'tot': 4,
           'tp': 4,
           'unlabeled': 0}}


In [21]:
# Compute same metrics but as % instad of absolute numbers
matched_preds_percent = {ca_event: {} for ca_event in ca_release_events}
matched_ys_percent = {ca_event: {} for ca_event in ca_release_events}

for ca_event in ca_release_events:
    for cat in matched_preds_ids['sum'][ca_event].keys():
        if cat != 'tot':
            matched_preds_percent[ca_event][cat] = (matched_preds_ids['sum'][ca_event][cat] 
                                                    / matched_preds_ids['sum'][ca_event]['tot'] * 100)
        
    for cat in matched_ys_ids['sum'][ca_event].keys():
        if cat != 'tot':
            matched_ys_percent[ca_event][cat] = (matched_ys_ids['sum'][ca_event][cat] 
                                                    / matched_ys_ids['sum'][ca_event]['tot'] * 100)

In [22]:
print("Summary of annotated events (as %):")
pprint.pprint(matched_ys_percent)

Summary of annotated events (as %):
{'puffs': {'sparks': 52.702702702702695,
           'tp': 55.4054054054054,
           'undetected': 25.675675675675674,
           'waves': 0.0},
 'sparks': {'puffs': 9.056603773584905,
            'tp': 57.35849056603774,
            'undetected': 34.71698113207547,
            'waves': 0.0},
 'waves': {'puffs': 71.42857142857143,
           'sparks': 85.71428571428571,
           'tp': 57.14285714285714,
           'undetected': 14.285714285714285}}


In [48]:
print("Summary of predicted events (as %):")
pprint.pprint(matched_preds_percent)

Summary of predicted events (as %):
{'puffs': {'ignored': 5.072463768115942,
           'sparks': 13.043478260869565,
           'tp': 35.507246376811594,
           'unlabeled': 41.30434782608695,
           'waves': 8.695652173913043},
 'sparks': {'ignored': 4.0201005025125625,
            'puffs': 17.33668341708543,
            'tp': 37.68844221105528,
            'unlabeled': 32.1608040201005,
            'waves': 9.045226130653267},
 'waves': {'ignored': 0.0,
           'puffs': 0.0,
           'sparks': 0.0,
           'tp': 100.0,
           'unlabeled': 0.0}}


In [23]:
# compute IoU scores
ys

{'05': array([[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        ...,
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 

In [24]:
preds

{'05': array([[[[1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          ...,
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00]],
 
         [[1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000000e+00, 1.0000000e+00, ...,
           1.0000000e+00, 1.0000000e+00, 1.0000000e+00],
          [1.0000000e+00, 1.0000