06/10/2023

Uso questo script per ricreare i dataset, cercando di strutturarli meglio:
- un dataset che prende come input i movie e le relative label;
- un dataset che prende come input il percorso del dataset (dataset_path) e gli ID dei movie;
- un dataset che gestisce l'inference, con o senza ground truth.

In [1]:
# Reload modules automatically: with the IPython autoreload extension in
# mode 2, imported modules are reloaded before each cell runs, so edits to
# the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import logging
import math
import ntpath
import os

from typing import List, Dict, Union, Tuple, Any

import imageio
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from PIL.ExifTags import TAGS
from scipy.interpolate import interp1d
from scipy.ndimage import convolve
from torch import nn
from torch.utils.data import Dataset
from torchvision.transforms import GaussianBlur
from torch.utils.data import DataLoader

from config import config, TrainingConfig
from data.data_processing_tools import detect_spark_peaks
from utils.in_out_tools import load_annotations_ids, load_movies_ids
from utils.training_script_utils import init_model, init_dataset
from utils.training_inference_tools import do_inference

In [3]:
# Create a TrainingConfig object holding the default configuration; a few
# fields are overridden below for a quick debugging run.
params = TrainingConfig()

# Adapt parameters for debugging: small inference dataset and batch,
# 64-frame chunks (presumably the temporal chunk length — the batch inspected
# below has 64 on dim 1), and CPU only.
params.inference_dataset_size = "minimal"
params.inference_batch_size = 2
params.data_duration = 64
params.set_device(device="cpu")

# Create a spark dataset from samples "05" and "34" (testing_dataset=True).
dataset = init_dataset(params=params, sample_ids=["05", "34"], testing_dataset=True)

# Create a dataloader; shuffle=False keeps batches in dataset order.
dataset_loader = DataLoader(
    dataset,
    batch_size=params.inference_batch_size,
    shuffle=False,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

# Create the network described by params (a U-Net, per the original note).
network = init_model(params=params)

[21:20:19] [  INFO  ] [utils.training_script_utils] <141 > -- Samples in training dataset: 43


In [4]:
# Get the first batch from the dataloader to inspect its content.
batch = next(iter(dataset_loader))

In [5]:
# Inspect the batch: available keys, movie ids, and data/labels shapes.
batch.keys(), batch["movie_id"], batch["data"].shape, batch["labels"].shape

(dict_keys(['movie_id', 'original_duration', 'data', 'labels', 'sample_id']),
 tensor([0, 0]),
 torch.Size([2, 64, 64, 512]),
 torch.Size([2, 64, 64, 512]))

### TODO: riorganizzare queste funzioni

In [None]:
# function to run a test sample (i.e., a test dataset) in the UNet
from typing import Optional


_SUPPORTED_INFERENCE_TYPES = ("overlap", "average", "gaussian", "max")


def _padding_bounds(pad: int) -> Tuple[int, int]:
    """Return ``(start, end)`` slice bounds that drop ``pad`` frames in total.

    Frames are split evenly between both ends; for an odd ``pad`` the extra
    frame is removed from the end. ``pad`` must be positive, so ``end`` is
    always a non-zero negative index.
    """
    return pad // 2, -(pad // 2 + pad % 2)


def _time_slice(tensor: Any, axis: int, start: int, end: int) -> Any:
    """Return ``tensor`` sliced as ``[start:end]`` along ``axis`` (0 or 1)."""
    return tensor[(slice(None),) * axis + (slice(start, end),)]


def _map_preds(preds: Any, fn, multiple_types: bool, return_dict: bool) -> Any:
    """Apply ``fn(tensor, time_axis)`` to every prediction tensor in ``preds``.

    Nesting of ``preds``: an optional outer dict keyed by inference type
    (when ``multiple_types``), then either a single tensor sliced on axis 1,
    or (when ``return_dict``) a dict keyed by event type of tensors sliced on
    axis 0.
    """

    def map_leaf(leaf):
        if return_dict:
            return {event_type: fn(pred, 0) for event_type, pred in leaf.items()}
        return fn(leaf, 1)

    if multiple_types:
        return {inference_type: map_leaf(p) for inference_type, p in preds.items()}
    return map_leaf(preds)


def _move_criterion_weights_to_cpu(criterion: torch.nn.Module) -> None:
    """Move ``criterion.weight`` / ``criterion.NLLLoss.weight`` to the CPU.

    Missing attributes and ``None`` weights (e.g. an unweighted loss) are
    skipped instead of raising AttributeError.
    """
    weight = getattr(criterion, "weight", None)
    if weight is not None and weight.is_cuda:
        criterion.weight = weight.cpu()

    nll_loss = getattr(criterion, "NLLLoss", None)
    nll_weight = getattr(nll_loss, "weight", None)
    if nll_weight is not None and nll_weight.is_cuda:
        nll_loss.weight = nll_weight.cpu()


@torch.no_grad()
def get_raw_preds_dict(
    model: torch.nn.Module,
    params: TrainingConfig,
    test_dataset: torch.utils.data.Dataset,
    criterion: Optional[torch.nn.Module] = None,
    inference_types: Optional[List[str]] = None,
    return_dict: bool = False,
) -> Tuple[Any, ...]:
    """
    Given a trained model and a test sample (i.e., a test dataset), run the
    sample in the model and return the predictions, cropped back to the
    original movie duration.

    Args:
        model (torch.nn.Module): The trained neural network model.
        params (TrainingConfig): A TrainingConfig; this function reads
            ``inference``, ``batch_size``, ``device``, ``nn_architecture``
            and ``criterion``.
        test_dataset (torch.utils.data.Dataset): The test dataset containing
            the sample. Reads ``data``, ``annotations``, ``gt_available``,
            ``pad``, ``temporal_reduction``, ``num_channels``,
            ``movie_duration`` and ``ignore_frames``.
        criterion (torch.nn.Module, optional): If provided, the loss is
            computed against the annotations.
        inference_types (list of str, optional): Inference types to use, or
            None to use ``params.inference``.
        return_dict (bool, optional): Whether each prediction produced by
            ``do_inference`` is a dict keyed by event type (True) or a single
            tensor (False). It does not change what is passed to
            ``do_inference``. Defaults to False.

    Returns:
        - with ``criterion``: ``(xs, ys, preds, loss)``;
        - otherwise, if annotations are available: ``(xs, ys, preds)``;
        - otherwise: ``(xs, preds)``.
        ``xs`` and ``ys`` are numpy arrays. ``preds`` keeps the nesting of the
        ``do_inference`` output, with every tensor converted to numpy.

    Raises:
        NotImplementedError: if the input was padded and the architecture is
            "unet_lstm", or if a loss is requested with several inference
            types, ``return_dict=True`` or annotations that are not 3D.
    """
    if inference_types is None:
        assert (
            params.inference in _SUPPORTED_INFERENCE_TYPES
        ), f"inference type '{params.inference}' not implemented yet"
        inference_types = [params.inference]
    else:
        assert all(
            i in _SUPPORTED_INFERENCE_TYPES for i in inference_types
        ), "Unsupported inference type."
    # Anything but exactly one inference type yields a dict keyed by type.
    multiple_types = len(inference_types) != 1

    # Create a dataloader
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=params.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    # Run movie in the network and perform inference.
    # NOTE(review): in the notebook cells, do_inference returns predictions
    # indexed by movie and then by inference type; confirm that its output
    # matches the nesting assumed here (see `_map_preds`).
    preds = do_inference(
        network=model,
        params=params,
        dataloader=test_dataloader,
        device=params.device,
        compute_loss=criterion is not None,
        inference_types=inference_types,
    )

    # Get original movie xs and, if available, annotations ys.
    xs = test_dataset.data[0]
    ys = test_dataset.annotations[0] if test_dataset.gt_available else None

    def remove_frames(movie, labels, predictions, n_frames):
        """Drop ``n_frames`` padded frames from movie, labels and predictions."""
        start, end = _padding_bounds(n_frames)
        movie = movie[start:end]
        if test_dataset.temporal_reduction:
            # Labels and predictions use bounds divided by num_channels.
            # NOTE(review): floor division rounds the negative `end` towards
            # -inf; confirm this is the intended rounding.
            start //= test_dataset.num_channels
            end //= test_dataset.num_channels
        if labels is not None:
            labels = labels[start:end]
        predictions = _map_preds(
            predictions,
            lambda pred, axis: _time_slice(pred, axis, start, end),
            multiple_types,
            return_dict,
        )
        return movie, labels, predictions

    # Remove the frames padded by the dataset.
    pad = test_dataset.pad
    if pad > 0:
        if params.nn_architecture == "unet_lstm":
            raise NotImplementedError
        xs, ys, preds = remove_frames(xs, ys, preds, pad)

    # If the original sample was shorter than the current movie duration,
    # remove the additional padded frames.
    extra_frames = xs.shape[0] - test_dataset.movie_duration
    if extra_frames > 0:
        xs, ys, preds = remove_frames(xs, ys, preds, extra_frames)

    if criterion is not None:
        assert ys is not None, "Cannot compute loss if annotations are not available."
        if ys.ndim != 3 or multiple_types or return_dict:
            # Loss is usually only computed during training, with a single
            # inference type and tensor predictions; other cases still need
            # to be adapted.
            raise NotImplementedError

        # A `-0` end bound would select nothing, so use None when no frames
        # are ignored.
        ignore = test_dataset.ignore_frames
        frames = slice(ignore, -ignore or None)
        preds_loss = preds[:, frames]
        ys_loss = ys[frames]

        if params.criterion == "dice_loss":
            # Set regions where the label is ignored (label 4) to 0.
            keep = ys_loss != 4
            preds_loss = preds_loss * keep
            ys_loss = ys_loss * keep
        else:
            ys_loss = ys_loss.long()[None, :]
            preds_loss = preds_loss[None, :]

        _move_criterion_weights_to_cpu(criterion)

        loss = criterion(preds_loss, ys_loss).item()
        return xs.numpy(), ys.numpy(), preds.numpy(), loss

    preds = _map_preds(
        preds, lambda pred, _axis: pred.numpy(), multiple_types, return_dict
    )
    if ys is None:
        return xs.numpy(), preds
    return xs.numpy(), ys.numpy(), preds

In [None]:
@torch.no_grad()
def get_preds_from_path(  # TODO: check whether this can be removed, keeping only get_preds
    model: nn.Module,
    params: TrainingConfig,
    movie_path: str,
    output_dir: Optional[str] = None,
    return_dict: bool = False,
) -> Tuple[Any, Any]:
    """
    Function to get segmentation and instance predictions from a movie path.

    Args:
    - model (torch.nn.Module): The trained neural network model.
    - params (TrainingConfig): A TrainingConfig containing various parameters.
    - movie_path: Path to the movie.
    - output_dir: If not None, also save the raw predictions in this
        directory (created if it does not exist).
    - return_dict (bool, optional): Whether to return the per-event-type
        dictionaries, or merge them into single arrays. Defaults to False.

    Returns:
    - If return_dict is True: a tuple ``(preds_segmentation, preds_instances)``
        of dictionaries with keys 'sparks', 'puffs', 'waves'.
    - Otherwise: a tuple ``(segmentation, instances)`` where the segmentation
        is built by ``preds_dict_to_mask`` and the instances are the sum of
        the per-event-type instance arrays.
    """
    # NOTE(review): SparkDatasetInference, process_raw_predictions,
    # write_videos_on_disk and preds_dict_to_mask are not imported in the
    # visible cells; they must be in scope before calling this function.

    ### Get sample as dataset ###
    # TODO: resampling (e.g. resampling_rate=150) could be supported later.
    sample_dataset = SparkDatasetInference(
        sample_path=movie_path,
        params=params,
    )

    ### Run sample in UNet ###
    raw_outputs = get_raw_preds_dict(
        model=model,
        test_dataset=sample_dataset,
        params=params,
        inference_types=None,
        return_dict=True,
    )
    # The movie is the first returned item and the predictions the last one,
    # whether or not annotations are returned in between.
    input_movie, preds_dict = raw_outputs[0], raw_outputs[-1]

    ### Get processed output ###
    # preds_instances and preds_segmentation are dictionaries with keys
    # 'sparks', 'puffs', 'waves'.
    preds_instances, preds_segmentation, _ = process_raw_predictions(
        raw_preds_dict=preds_dict,
        input_movie=input_movie,
        training_mode=False,
        debug=False,
    )

    # Save raw preds on disk (NOTE: the original author was unsure whether
    # this is necessary).
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        write_videos_on_disk(
            training_name=None,
            video_name=sample_dataset.video_name,
            path=output_dir,
            preds=[
                None,
                preds_dict["sparks"],
                preds_dict["waves"],
                preds_dict["puffs"],
            ],
            ys=None,
        )

    if return_dict:
        return preds_segmentation, preds_instances

    # Get integral values for classes and instances; instances already have
    # different IDs across event types, so they are combined by summing.
    preds_segmentation = preds_dict_to_mask(preds_segmentation)
    preds_instances = sum(preds_instances.values())
    return preds_segmentation, preds_instances

In [12]:
# Run all four inference types over the whole dataloader.
# Judging from the shapes printed in the following cells, `pred` is indexed
# first by movie position and then by inference type.
# NOTE(review): confirm this structure against do_inference.
pred = do_inference(
    network=network,
    params=params,
    dataloader=dataset_loader,
    device=params.device,
    inference_types=["overlap", "average", "gaussian", "max"],
)

In [13]:
# Shapes of entry 0 of the predictions, one per inference type.
tuple(
    pred[0][inference_type].shape
    for inference_type in ("overlap", "average", "gaussian", "max")
)

(torch.Size([4, 500, 64, 512]),
 torch.Size([4, 500, 64, 512]),
 torch.Size([4, 500, 64, 512]),
 torch.Size([4, 500, 64, 512]))

In [14]:
# Shapes of entry 1 of the predictions, one per inference type.
tuple(
    pred[1][inference_type].shape
    for inference_type in ("overlap", "average", "gaussian", "max")
)

(torch.Size([4, 904, 64, 512]),
 torch.Size([4, 904, 64, 512]),
 torch.Size([4, 904, 64, 512]),
 torch.Size([4, 904, 64, 512]))

In [15]:
# Visualize entry 0 of the predictions with napari, one layer per inference
# type. Each layer is named after its inference type so the layers can be
# told apart (unnamed layers get generic "Image [n]" names).
import napari

viewer = napari.Viewer()
for inference_type in ("overlap", "average", "gaussian", "max"):
    viewer.add_image(pred[0][inference_type].numpy(), name=inference_type)

<Image layer 'Image [3]' at 0x298e8a50520>