# Predict general data

In [17]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import shutil
import csv
import re
from collections import defaultdict
from datetime import datetime
import pandas as pd
import numpy as np
import torch

import sys
sys.path.append("./missing_data")

from mpl_toolkits.basemap import Basemap
from build_dataset import get_files_from_folder, extract_dates_pattern_airmass_rgb_20200101_0000
from build_dataset import load_cyclones_track_noheader, compute_pixel_scale, inside_tile, calc_tile_offsets, save_single_tile
from build_dataset import split_into_tiles_subfolders_and_track_dates, create_and_save_tile_from_complete_df
from build_dataset import create_final_df_csv
from build_dataset import get_gruppi_date, create_tile_videos
from medicane_utils.geo_const import latcorners, loncorners, x_center, y_center, create_basemap_obj
from medicane_utils.load_files import load_all_images

from view_test_tiles import create_labeled_images_with_tiles

from arguments import prepare_finetuning_args, Args
from dataset.datasets import MedicanesClsDataset
from torch.utils.data import DataLoader
import models
from timm.models import create_model

# Basemap object built by the project helper; passed to the tile-visualisation call further down.
basemap_obj = create_basemap_obj()
# CSV holding the master per-tile table that is read into `df_data` below.
file_master_df = "all_data_full_tiles.csv"

# Destination folder handed to the tile-saving and final-CSV helpers.
output_dir = "../airmassRGB/supervised/" 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:

# Column dtypes for the master tile table. Coordinate-like columns are kept
# as generic objects, the rest as compact ints / pandas strings.
master_dtypes = {
    "path": 'string',
    "tile_offset_x": 'int16',
    "tile_offset_y": 'int16',
    "label": 'int16',
    "lat": 'object',
    "lon": 'object',
    "x_pix": 'object',
    "y_pix": 'object',
    "name": 'string',
    "source": 'string',
}

# Read the table, parse the timestamp column, and discard the index column
# that was serialised into the CSV as "Unnamed: 0".
df_data = (
    pd.read_csv(file_master_df, dtype=master_dtypes, parse_dates=['datetime'])
    .drop(columns="Unnamed: 0")
)

In [27]:
# Keep only rows strictly after 2023-09-01 00:00.
# The explicit .copy() makes df_2023 an independent frame: later cells drop
# and add columns on it, which on a plain boolean-mask slice of df_data
# triggers SettingWithCopyWarning.
df_2023 = df_data[df_data.datetime > datetime(2023, 9, 1)].copy()

In [28]:
# Drop the columns that are not needed for inference.
# Non-inplace drop avoids SettingWithCopyWarning on the filtered slice, and
# errors="ignore" keeps the cell idempotent: re-running it after the columns
# are already gone no longer raises KeyError.
df_2023 = df_2023.drop(columns=['lat', 'lon', 'x_pix', 'y_pix', 'name'], errors="ignore")
df_2023

Unnamed: 0,path,datetime,tile_offset_x,tile_offset_y,label,source
1382628,../fromgcloud/airmass_rgb_20230903_0000.png,2023-09-03,0,0,0,[]
1382629,../fromgcloud/airmass_rgb_20230903_0000.png,2023-09-03,213,0,0,[]
1382630,../fromgcloud/airmass_rgb_20230903_0000.png,2023-09-03,426,0,0,[]
1382631,../fromgcloud/airmass_rgb_20230903_0000.png,2023-09-03,639,0,0,[]
1382632,../fromgcloud/airmass_rgb_20230903_0000.png,2023-09-03,852,0,0,[]
...,...,...,...,...,...,...
1413163,../fromgcloud/airmass_rgb_20230912_0000.png,2023-09-12,213,196,0,[]
1413164,../fromgcloud/airmass_rgb_20230912_0000.png,2023-09-12,426,196,0,[]
1413165,../fromgcloud/airmass_rgb_20230912_0000.png,2023-09-12,639,196,0,[]
1413166,../fromgcloud/airmass_rgb_20230912_0000.png,2023-09-12,852,196,0,[]


In [12]:
def sub_select_frequency(df, freq='20min'):
    """Keep only the rows whose timestamp lies exactly on a `freq` grid.

    A row is kept when flooring its ``datetime`` to `freq` leaves it
    unchanged (with the default, minutes 00, 20 and 40 of each hour).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a datetime-typed ``datetime`` column. It is NOT modified.
    freq : str, default '20min'
        Pandas frequency string used for flooring.

    Returns
    -------
    pandas.DataFrame
        The matching rows, with their original index, plus a ``dt_floor``
        helper column (kept so the returned columns match what this
        function has always produced).
    """
    # Work on a new frame so the caller's DataFrame does not silently gain a
    # 'dt_floor' column (and so slices do not raise SettingWithCopyWarning).
    with_floor = df.assign(dt_floor=df['datetime'].dt.floor(freq))
    on_grid = with_floor['datetime'] == with_floor['dt_floor']
    return with_floor[on_grid]

In [16]:
# Restrict the 2023 rows to timestamps on the default 20-minute grid and
# show how many rows remain.
df_2023_20m = sub_select_frequency(df_2023)
len(df_2023_20m)

7644

In [21]:
# Group the sub-sampled tiles into per-tile videos (unsupervised mode).
df_videos = create_tile_videos(df_2023_20m, supervised=False)
# Placeholder label for every video. Bracket assignment is used because
# attribute-style assignment (`df.label = 0`) would silently create a plain
# Python attribute instead of a column if 'label' were ever missing.
df_videos["label"] = 0
df_videos

Unnamed: 0,video_id,tile_offset_x,tile_offset_y,path,label,start_time,end_time,orig_paths
0,0,0,0,03-09-2023_0500_0_0,0,2023-09-03 00:00:00,2023-09-03 05:00:00,"[../fromgcloud/airmass_rgb_20230903_0000.png, ..."
1,1,0,0,03-09-2023_1020_0_0,0,2023-09-03 05:20:00,2023-09-03 10:20:00,"[../fromgcloud/airmass_rgb_20230903_0520.png, ..."
2,2,0,0,03-09-2023_1540_0_0,0,2023-09-03 10:40:00,2023-09-03 15:40:00,"[../fromgcloud/airmass_rgb_20230903_1040.png, ..."
3,3,0,0,04-09-2023_0100_0_0,0,2023-09-03 16:00:00,2023-09-04 01:00:00,"[../fromgcloud/airmass_rgb_20230903_1600.png, ..."
4,4,0,0,04-09-2023_0620_0_0,0,2023-09-04 01:20:00,2023-09-04 06:20:00,"[../fromgcloud/airmass_rgb_20230904_0120.png, ..."
...,...,...,...,...,...,...,...,...
463,463,1065,196,10-09-2023_2220_1065_196,0,2023-09-10 17:20:00,2023-09-10 22:20:00,"[../fromgcloud/airmass_rgb_20230910_1720.png, ..."
464,464,1065,196,11-09-2023_0340_1065_196,0,2023-09-10 22:40:00,2023-09-11 03:40:00,"[../fromgcloud/airmass_rgb_20230910_2240.png, ..."
465,465,1065,196,11-09-2023_0900_1065_196,0,2023-09-11 04:00:00,2023-09-11 09:00:00,"[../fromgcloud/airmass_rgb_20230911_0400.png, ..."
466,466,1065,196,11-09-2023_1420_1065_196,0,2023-09-11 09:20:00,2023-09-11 14:20:00,"[../fromgcloud/airmass_rgb_20230911_0920.png, ..."


In [24]:
# Materialise the tiles of every video under output_dir; judging by the logged
# output, files that already exist are counted rather than written again.
create_and_save_tile_from_complete_df(df_videos, output_dir)

Creazione delle folder per i 468 video...
Salvati 6537 file - Erano già presenti 951 file - File totali 7488


In [26]:
# Build the dataset table for the saved videos and write it (without the
# index) to the annotation CSV that the inference dataset reads further down.
df_dataset_csv = create_final_df_csv(df_videos, output_dir)
df_dataset_csv.to_csv("./general_inference_set.csv", index=False)

In [None]:
# Bonus: render an animation of the whole period to have a look at it.
# One group per source image path (rows with a missing path are kept as a group).
grouped = df_2023.groupby("path", dropna=False)
print(len(grouped))
# 213 and 196 equal the step between consecutive tile_offset_x / tile_offset_y
# values in the df_2023 preview; presumably the tile strides — TODO confirm
# against create_labeled_images_with_tiles.
create_labeled_images_with_tiles(grouped, 'daniel_complete_tiles.gif', basemap_obj, 213, 196)

2545

# Inference on this set

In [None]:
# Load the run configuration and build the inference dataset / loader.
args = prepare_finetuning_args()

dataset_val = MedicanesClsDataset(
    anno_path="./general_inference_set.csv",
    data_root=args.data_root,
    mode='validation',  # or 'test'
    clip_len=args.num_frames,
    transform=None
)
# NOTE(review): the model below is built with args.nb_classes, not this value;
# confirm the two agree.
nb_classes = 2

data_loader_val = DataLoader(
    dataset_val,
    batch_size=args.batch_size,
    # Every sample is consumed during inference, so shuffling only made the
    # prediction order differ from run to run. Keep the dataset order.
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False  # keep the final partial batch so no video is skipped
)

# Build the model from the run configuration.
get_prediction = True
model = create_model(
    args.model,
    num_classes=args.nb_classes,
    drop_rate=0.0,
    drop_path_rate=args.drop_path,
    #attn_drop_rate=0.0,
    drop_block_rate=None,
    **args.__dict__
)

# Load the checkpoint weights into the model.
checkpoint_path = "output/checkpoint-best.pth"
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # The checkpoint either wraps the weights under "model" or is itself the
    # state dict, depending on how it was saved.
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    # strict=False tolerates key mismatches, so always report them: previously
    # the bare-state-dict path discarded this result and a partially loaded
    # model went unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    print(f"Checkpoint loaded. Missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Unexpected keys: {load_result.unexpected_keys}")
    print("Checkpoint caricato correttamente.")
else:
    print("ATTENZIONE: file checkpoint non trovato. Userai i pesi random del modello.")


model.to(args.device)
# Trailing ';' keeps the cell from dumping the full module repr as output.
model.eval();


In [29]:
def predict_label(model, videos):
    """Return the index of the highest-scoring class for each batch item.

    `model` is called on `videos` with gradient tracking disabled, and the
    arg-max is taken along dimension 1 of its output.
    """
    with torch.no_grad():
        scores = model(videos)  # one row of class scores per sample
        return scores.argmax(dim=1)

def get_path_pred_label(model, data_loader, device=None):
    """Run `model` over every batch of `data_loader` and collect the results.

    Parameters
    ----------
    model : callable
        Model passed to `predict_label` for each batch.
    data_loader : iterable
        Yields ``(videos, labels, folder_path)`` triples.
    device : optional
        Device the video batch is moved to. When omitted, the device of the
        model's first parameter is used, so the function no longer depends on
        the notebook-global `args`. If the model exposes no parameters, the
        batch is left where it is.

    Returns
    -------
    tuple of lists
        ``(all_paths, all_preds, all_labels)``, each with one entry per sample.
    """
    if device is None:
        parameters = getattr(model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        if first_param is not None:
            device = first_param.device

    all_paths = []
    all_labels = []
    all_preds = []
    for videos, labels, folder_path in data_loader:
        if device is not None:
            videos = videos.to(device)
        predicted_classes = predict_label(model, videos)  # one class index per sample

        all_labels.extend(labels.detach().cpu().numpy())
        all_preds.extend(predicted_classes.detach().cpu().numpy())
        all_paths.extend(folder_path)

    return all_paths, all_preds, all_labels

def create_df_predictions(all_paths, all_preds, all_labels):
    """Assemble per-video predictions into a three-column DataFrame.

    The 'path' column holds only the last '/'-separated component of each
    entry in `all_paths`; 'predictions' and 'labels' hold the given values.
    """
    folder_names = pd.Series(all_paths).str.split('/').str[-1]
    return pd.concat(
        [folder_names, pd.Series(all_preds), pd.Series(all_labels)],
        axis=1,
        keys=['path', 'predictions', 'labels'],
    )

In [None]:
# Run the model over the whole inference loader, collecting folder paths,
# predicted class indices and labels (one entry per sample).
all_paths, all_preds, all_labels = get_path_pred_label(model, data_loader_val)

In [None]:
# Tabulate the results as columns 'path' (video folder name), 'predictions', 'labels'.
df_predictions = create_df_predictions(all_paths, all_preds, all_labels)

In [None]:
# Attach the predictions to the per-video metadata by joining on 'path'
# (pandas' default inner join: only videos present in both frames are kept).
df_filtrato_on_video_path = df_videos.merge(df_predictions, on='path')