06/10/2023

Uso questo script per ricreare i dataset, cercando di strutturarli meglio:
- un dataset che riceve in input i movie e le label;
- un dataset che riceve in input `dataset_path` e gli ID dei movie;
- un dataset che gestisce l'inferenza, con o senza ground truth.

In [1]:
# Reload modified modules automatically before executing code (IPython
# `autoreload` extension; per its documentation, mode 2 reloads all modules
# not explicitly excluded). Notebook-only syntax: not valid in a .py script.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import torch
from torch.backends import cudnn
from torch import nn, optim
from torch.utils.data import DataLoader
from data.datasets import SparkDataset
from sklearn.metrics import confusion_matrix

import logging
import time
from typing import Dict, Union, Tuple, List, Optional

from config import config, TrainingConfig
from data.data_processing_tools import (
    masks_to_instances_dict,
    preds_dict_to_mask,
    process_raw_predictions,
)
from evaluation.metrics_tools import compute_iou, get_matches_summary, get_metrics_from_summary, get_score_matrix
from utils.custom_losses import MySoftDiceLoss
from utils.in_out_tools import write_videos_on_disk
from utils.training_inference_tools import training_step, sampler
from utils.training_script_utils import init_model, init_dataset, init_criterion, get_sample_ids

# Module-level logger used by the cells below
logger = logging.getLogger(__name__)
# Verbosity level 3 enables debug messages
config.verbosity = 3

In [3]:
# Build the training configuration from the final-model config file.
# Alternative without a config file: params = TrainingConfig()
config_filename = os.path.join("config_files", "config_final_model.ini")
params = TrainingConfig(training_config_file=config_filename)

# Optional overrides for quick debugging runs -- uncomment as needed:
# params.inference_dataset_size = "minimal"
# params.inference_batch_size = 2
# params.data_duration = 64
# params.set_device(device="cpu")

# Training sample IDs matching the configured dataset size
train_sample_ids = get_sample_ids(train_data=True, dataset_size=params.dataset_size)

# Training SparkDataset: data augmentation enabled, instances not loaded
dataset = init_dataset(
    params=params,
    sample_ids=train_sample_ids,
    apply_data_augmentation=True,
    load_instances=False,
)

# Dataloader over the training dataset; shuffle=False keeps the dataset order
dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=params.batch_size,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
    shuffle=False,
)

# U-Net wrapped in DataParallel and moved to the configured device.
# Without DataParallel it would instead be:
#   network = network.to(params.device, non_blocking=True)
network = nn.DataParallel(init_model(params=params))
network = network.to(params.device, non_blocking=True)
# cudnn.benchmark = True

[12:29:34] [  INFO  ] [   config   ] <291 > -- Loading C:\Users\dotti\sparks_project\config_files\config_final_model.ini
[12:29:34] [  INFO  ] [utils.training_script_utils] <137 > -- Samples in training dataset: 9


In [4]:
# Fetch the first batch produced by the training dataloader.
# `next` raises StopIteration if the loader yields no batches.
batch = next(iter(dataset_loader))

In [5]:
# Inspect the batch: available keys, movie IDs, and data / label shapes
(
    batch.keys(),
    batch["movie_id"],
    batch["data"].shape,
    batch["labels"].shape,
)

(dict_keys(['movie_id', 'original_duration', 'data', 'labels', 'sample_id']),
 tensor([0, 0, 0, 0]),
 torch.Size([4, 256, 64, 512]),
 torch.Size([4, 256, 64, 512]))

### TODO:  RIORGANIZZARE QUESTE FUNZIONI

Domanda: ha davvero senso utilizzare `get_preds` ecc. invece di usare soltanto `do_inference`? Sembra molto lavoro inutile ricostruire le label e i video originali a partire dal dataset...

In [6]:
# Load the weights saved at `load_epoch` for the current run into `network`.
load_epoch = 100000
model_filename = f"network_{load_epoch:06d}.pth"

# Checkpoint location: <basedir>/models/saved_models/<run_name>/<model_filename>
models_relative_path = os.path.join(
    "models", "saved_models", params.run_name, model_filename
)
# NOTE: despite its name, `model_dir` is the full path of the checkpoint file.
model_dir = os.path.realpath(os.path.join(config.basedir, models_relative_path))

logger.info(
    "Loading trained model '%s' at epoch %s...", params.run_name, load_epoch
)
# Kept as the cell's last expression so Jupyter displays the load result.
network.load_state_dict(torch.load(model_dir, map_location=params.device))

[12:29:36] [  INFO  ] [  __main__  ] < 13 > -- Loading trained model 'final_model' at epoch 100000...


<All keys matched successfully>

In [7]:
# Adam over all network parameters, starting from the configured learning rate
optimizer = optim.Adam(network.parameters(), lr=params.lr_start)

# Negative log-likelihood loss; targets equal to `config.ignore_index` are ignored.
# NOTE(review): `init_criterion` is imported but not used here, so any criterion
# configured in `params` is bypassed -- confirm this is intended for debugging.
criterion = nn.NLLLoss(ignore_index=config.ignore_index)

In [8]:
# torch.set_float32_matmul_precision("high")

In [9]:
# Call `training_step` with the loaded network and no LR scheduler.
# In the recorded run it returned a dict with a 'loss' entry.
loss = training_step(
    sampler=sampler,
    dataset_loader=dataset_loader,
    network=network,
    optimizer=optimizer,
    criterion=criterion,
    params=params,
    scheduler=None,
)



In [10]:
# Display the value returned by `training_step` (a dict with a 'loss' entry in the recorded run)
loss

{'loss': 0.2515716254711151}