In [None]:
import torch

def available_gpus():
    """Return the names of the CUDA devices PyTorch reports.

    Returns:
        list: One entry per device index in ``range(torch.cuda.device_count())``,
        each obtained from ``torch.cuda.get_device_name``; empty when the
        device count is zero.
    """
    names = []
    for index in range(torch.cuda.device_count()):
        names.append(torch.cuda.get_device_name(index))
    return names

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("GPUs disponibles:", available_gpus())

## Libraries

In [2]:
from libraries.utils import read_csv
import pandas as pd
import torch

from metrics_libraries.metrics_calculator import MetricsCalculator

## Parameters

In [3]:
# Mission and phase used to look up the train/validation/test date ranges.
MISSION = 2
PHASE = 1

# Window size and batch size forwarded to model.predict.
WINDOW_SIZE = 50
BATCH_SIZE = 1024

# Channel selection mode; change the index to switch mode.
CHANNELS = ("allchannels", "subset", "target")[2]
# Inclusive channel-number bounds; only read when CHANNELS == "subset".
FIRST_CHANNEL_NUMBER = 18
LAST_CHANNEL_NUMBER = 28

# NOTE(review): the checkpoint file name mentions "Phase1" and "Channels18-28"
# while CHANNELS is "target" -- confirm this checkpoint matches the settings above.
MODEL_SAVE_PATH = "../models/AutoEnconderFullWindow/Phase1_Channels18-28_window50_epochs25_lr0.0001.pth"
CHANNELS_INFO_PATH = "../data/Mission2-ESA/channels.csv"
ESA_ANOMALIES_PATH = "../esa-anomalies/anomalies_mission2.csv"
INPUT_DATA_PATH = "../data/Mission2-Preprocessed/data_preprocessed_target_frequency-previous_2000_2003.csv"

In [4]:
def _build_phase_dates(test_start, test_end, train_start, phase_boundaries):
    """Assemble the flat date lookup for one mission.

    Args:
        test_start: ISO timestamp string stored under "test_start_date".
        test_end: ISO timestamp string stored under "test_end_date".
        train_start: ISO timestamp string used as the training start of every phase.
        phase_boundaries: Sequence of (train_end, val_end) string pairs, one per
            phase; phases are numbered from 1 in the given order. Each phase's
            validation range starts where its training range ends.

    Returns:
        dict: Keys "test_start_date", "test_end_date" and, per phase N,
        "phaseN_start_date_train", "phaseN_end_date_train",
        "phaseN_start_date_val", "phaseN_end_date_val", in that order.
    """
    dates = {"test_start_date": test_start, "test_end_date": test_end}
    for phase, (train_end, val_end) in enumerate(phase_boundaries, start=1):
        dates[f"phase{phase}_start_date_train"] = train_start
        dates[f"phase{phase}_end_date_train"] = train_end
        dates[f"phase{phase}_start_date_val"] = train_end
        dates[f"phase{phase}_end_date_val"] = val_end
    return dates


mission1_phases_dates = _build_phase_dates(
    test_start="2007-01-01T00:00:00",
    test_end="2014-01-01T00:00:00",
    train_start="2000-01-01T00:00:00",
    phase_boundaries=[
        ("2000-03-11T00:00:00", "2000-04-01T00:00:00"),  # phase 1
        ("2000-09-01T00:00:00", "2000-11-01T00:00:00"),  # phase 2
        ("2001-07-01T00:00:00", "2001-11-01T00:00:00"),  # phase 3
        ("2003-04-01T00:00:00", "2003-07-01T00:00:00"),  # phase 4
        ("2006-10-01T00:00:00", "2007-01-01T00:00:00"),  # phase 5
    ],
)

mission2_phases_dates = _build_phase_dates(
    test_start="2001-10-01T00:00:00",
    test_end="2003-07-01T00:00:00",
    train_start="2000-01-01T00:00:00",
    phase_boundaries=[
        ("2000-01-24T00:00:00", "2000-02-01T00:00:00"),  # phase 1
        ("2000-05-01T00:00:00", "2000-06-01T00:00:00"),  # phase 2
        ("2000-09-01T00:00:00", "2000-11-01T00:00:00"),  # phase 3
        ("2001-07-01T00:00:00", "2001-10-01T00:00:00"),  # phase 4
    ],
)

# Lookup from mission number to that mission's date table.
missions_phases_dates = {
    1: mission1_phases_dates,
    2: mission2_phases_dates,
}

In [5]:
# Resolve the date boundaries for the configured mission and phase.
phase_dates = missions_phases_dates[MISSION]
phase_prefix = f"phase{PHASE}"

start_date_train = pd.to_datetime(phase_dates[f"{phase_prefix}_start_date_train"])
end_date_train = pd.to_datetime(phase_dates[f"{phase_prefix}_end_date_train"])
start_date_val = pd.to_datetime(phase_dates[f"{phase_prefix}_start_date_val"])
end_date_val = pd.to_datetime(phase_dates[f"{phase_prefix}_end_date_val"])
# The test range is defined per mission, not per phase.
start_date_test = pd.to_datetime(phase_dates["test_start_date"])
end_date_test = pd.to_datetime(phase_dates["test_end_date"])

# Build the list of channel columns to keep; None means "keep every column".
if CHANNELS == "allchannels":
    channels_list = None
elif CHANNELS == "target":
    # Keep only the channels whose 'Target' column is "YES" in the channels file.
    channels_info = pd.read_csv(CHANNELS_INFO_PATH)
    is_target = channels_info['Target'] == "YES"
    channels_list = list(channels_info[is_target]['Channel'])
else:
    # Any other mode is treated as "subset": an inclusive range of channel numbers.
    channels_list = [
        f"channel_{number}"
        for number in range(FIRST_CHANNEL_NUMBER, LAST_CHANNEL_NUMBER + 1)
    ]

## Load model

In [6]:
# SECURITY: the checkpoint holds a fully pickled model object (not just a
# state_dict), so loading it unpickles arbitrary Python objects. Only load
# checkpoint files from a trusted source.
#
# map_location=device remaps the stored tensors onto the device selected at the top
# of the notebook, so a checkpoint saved on a GPU still loads on a CPU-only machine.
# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle a full model object
# (the argument is accepted by PyTorch >= 1.13).
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=False)
model = checkpoint['model']  # the full model object
threshold_list = checkpoint['threshold']  # thresholds stored alongside the model
scaler = checkpoint['scaler']  # scaler object stored alongside the model

## Load data

In [7]:
data = read_csv(INPUT_DATA_PATH, sep=";")
if channels_list is not None:
    # Restrict the frame to the selected channel columns.
    data = data[channels_list]


def _rows_between(frame, start, end):
    """Return the rows of ``frame`` whose index is in the half-open range [start, end)."""
    in_range = (frame.index >= start) & (frame.index < end)
    return frame.loc[in_range]


# Split the data by date range (assumes the frame returned by read_csv is indexed
# by timestamp -- TODO confirm against libraries.utils.read_csv).
data_train = _rows_between(data, start_date_train, end_date_train)
data_val = _rows_between(data, start_date_val, end_date_val)
data_test = _rows_between(data, start_date_test, end_date_test)

In [8]:
def _apply_scaler(frame):
    """Transform ``frame`` with the checkpoint's scaler, keeping its index and columns."""
    scaled_values = scaler.transform(frame)
    return pd.DataFrame(scaled_values, index=frame.index, columns=frame.columns)


# Every split goes through the same scaler loaded from the checkpoint.
df_train_normalized = _apply_scaler(data_train)
df_val_normalized = _apply_scaler(data_val)
df_test_normalized = _apply_scaler(data_test)

## Predict

In [None]:
# Run the model's predict step on the normalized training split, using the
# thresholds loaded from the checkpoint.
anomalies_train = model.predict(
    threshold_list,
    df_train_normalized,
    WINDOW_SIZE,
    BATCH_SIZE,
    device,
)
anomalies_train.head()

In [None]:
# Run the model's predict step on the normalized validation split, using the
# thresholds loaded from the checkpoint.
anomalies_val = model.predict(
    threshold_list,
    df_val_normalized,
    WINDOW_SIZE,
    BATCH_SIZE,
    device,
)
anomalies_val.head()

In [None]:
# Run the model's predict step on the normalized test split, using the
# thresholds loaded from the checkpoint.
anomalies_test = model.predict(
    threshold_list,
    df_test_normalized,
    WINDOW_SIZE,
    BATCH_SIZE,
    device,
)
anomalies_test.head()

## Calculate metrics

In [12]:
# One calculator is shared by the train, validation and test metric cells below.
metrics_calculator = MetricsCalculator(
    ESA_ANOMALIES_PATH,
    CHANNELS_INFO_PATH,
    channels_list,
)

In [None]:
# Compute and print the metrics for the training date range.
metrics_train = metrics_calculator.get_metrics(
    anomalies_train,
    start_date_train,
    end_date_train,
)
metrics_calculator.print_metrics_table(metrics_train)

In [None]:
# Compute and print the metrics for the validation date range.
# The result is stored under its own name (matching metrics_train) so the
# predictions in anomalies_val are not overwritten; previously, re-running this
# cell fed the metrics back into get_metrics instead of the predictions.
metrics_val = metrics_calculator.get_metrics(anomalies_val, start_date_val, end_date_val)
metrics_calculator.print_metrics_table(metrics_val)

In [None]:
# Compute and print the metrics for the test date range.
# The result is stored under its own name (matching metrics_train) so the
# predictions in anomalies_test are not overwritten; previously, re-running this
# cell fed the metrics back into get_metrics instead of the predictions.
metrics_test = metrics_calculator.get_metrics(anomalies_test, start_date_test, end_date_test)
metrics_calculator.print_metrics_table(metrics_test)