## Libraries

In [None]:
import torch

def available_gpus():
    """Return the name of every CUDA device PyTorch can see (an empty list when there is none)."""
    names = []
    for gpu_index in range(torch.cuda.device_count()):
        names.append(torch.cuda.get_device_name(gpu_index))
    return names

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("GPUs disponibles:", available_gpus())

In [2]:
from datetime import date
from typing import List, Literal, Tuple

import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import polars as pl
from dash import Dash, Input, Output, dcc, html
from plotly.subplots import make_subplots
from plotly_resampler import FigureResampler
from plotly_resampler.aggregation import MinMaxLTTB
from libraries.utils import read_csv
from jupyter_dash import JupyterDash

## Parameters

In [3]:
# --- Experiment selection ------------------------------------------------------------
MISSION = 2
# NOTE(review): mission2_phases_dates only defines phases 1-4 and the checkpoint below is a
# Phase4 model; PHASE is not read by the cells shown here -- confirm the intended value.
PHASE = 5

# --- Inference settings --------------------------------------------------------------
WINDOW_SIZE = 50
BATCH_SIZE = 256

# Channel selection; the options handled below are "allchannels", "subset" and "target".
CHANNELS = "target"

# --- Paths ---------------------------------------------------------------------------
# Checkpoint to evaluate (alternative checkpoints kept commented out for quick switching).
MODEL_SAVE_PATH = f"../models/Mission{MISSION}-AutoEnconderFullWindow/Phase4_target_window50_percentile99_epochs25_lr0.0001__2025-01-07_22-14-40.pth"
# MODEL_SAVE_PATH = f"../models/Mission{MISSION}-AutoEnconderLastEvent/Phase3_Channels18-28_window50_percentile99_epochs25_lr0.0001__2025-01-07_14-06-25.pth"
# MODEL_SAVE_PATH = f"../models/Mission{MISSION}-VariationalAutoencoderFullWindow/Phase1_Channels18-28_window50_percentile99_epochs25_lr0.0001__2025-01-07_16-24-43.pth"
# MODEL_SAVE_PATH = f"../models/Mission{MISSION}-VariationalAutoencoderLastEvent/Phase1_Channels18-28_window50_percentile99_epochs25_lr0.0001__2025-01-07_17-24-26.pth"

CHANNELS_INFO_PATH = f"../data/Mission{MISSION}-ESA/channels.csv"
ESA_ANOMALIES_PATH = f"../esa-anomalies/anomalies_mission{MISSION}.csv"
METRICS_SAVE_PATH = "../metrics/metrics.csv"

In [4]:
# Channel range, only meaningful when CHANNELS == "subset": 41-46 for mission 1, 18-28 otherwise.
if MISSION == 1:
    first_channel_number, last_channel_number = 41, 46
else:
    first_channel_number, last_channel_number = 18, 28

# The preprocessed file name encodes the channel selection and the last year of data.
last_data_year = 2013 if MISSION == 1 else 2003
if CHANNELS == "subset":
    channels_tag = f"channels{first_channel_number}_{last_channel_number}"
else:
    channels_tag = CHANNELS
input_data_path = f'../data/Mission{MISSION}-Preprocessed/data_preprocessed_{channels_tag}_frequency-previous_2000_{last_data_year}.csv'

In [5]:
def _build_phase_dates(test_start, test_end, phase_bounds):
    """Build one mission's date dictionary from its test period and per-phase boundaries.

    Each item of ``phase_bounds`` is ``(train_start, train_end, val_end)`` for phases 1, 2, ...;
    the validation period starts exactly where training ends. Keys are inserted as: the two
    test dates, then for each phase ``start_date_train``, ``end_date_train``,
    ``start_date_val``, ``end_date_val``.
    """
    dates = {"test_start_date": test_start, "test_end_date": test_end}
    for phase, (train_start, train_end, val_end) in enumerate(phase_bounds, start=1):
        dates[f"phase{phase}_start_date_train"] = train_start
        dates[f"phase{phase}_end_date_train"] = train_end
        dates[f"phase{phase}_start_date_val"] = train_end
        dates[f"phase{phase}_end_date_val"] = val_end
    return dates


mission1_phases_dates = _build_phase_dates(
    "2007-01-01T00:00:00",
    "2014-01-01T00:00:00",
    [
        ("2000-01-01T00:00:00", "2000-03-11T00:00:00", "2000-04-01T00:00:00"),
        ("2000-01-01T00:00:00", "2000-09-01T00:00:00", "2000-11-01T00:00:00"),
        ("2000-01-01T00:00:00", "2001-07-01T00:00:00", "2001-11-01T00:00:00"),
        ("2000-01-01T00:00:00", "2003-04-01T00:00:00", "2003-07-01T00:00:00"),
        ("2000-01-01T00:00:00", "2006-10-01T00:00:00", "2007-01-01T00:00:00"),
    ],
)

mission2_phases_dates = _build_phase_dates(
    "2001-10-01T00:00:00",
    "2003-07-01T00:00:00",
    [
        ("2000-01-01T00:00:00", "2000-01-24T00:00:00", "2000-02-01T00:00:00"),
        ("2000-01-01T00:00:00", "2000-05-01T00:00:00", "2000-06-01T00:00:00"),
        ("2000-01-01T00:00:00", "2000-09-01T00:00:00", "2000-11-01T00:00:00"),
        ("2000-01-01T00:00:00", "2001-07-01T00:00:00", "2001-10-01T00:00:00"),
    ],
)

missions_phases_dates = {1: mission1_phases_dates, 2: mission2_phases_dates}

In [6]:
# Evaluate from the first training day of phase 1 up to the end of the test period.
mission_dates = missions_phases_dates[MISSION]
start_date = pd.to_datetime(mission_dates["phase1_start_date_train"])
end_date = pd.to_datetime(mission_dates["test_end_date"])
# TODO: remove -- debugging override that cuts the evaluated range at 2000-03-01.
end_date = pd.to_datetime("2000-03-01T00:00:00")

# Resolve which channel columns to keep; None means "use every channel".
if CHANNELS == "target":
    channels_info = pd.read_csv(CHANNELS_INFO_PATH)
    channels_list = list(channels_info.loc[channels_info['Target'] == "YES", 'Channel'])
elif CHANNELS == "allchannels":
    channels_list = None
else:
    channels_list = [f"channel_{i}" for i in range(first_channel_number, last_channel_number + 1)]

## Predict

### Load model

In [7]:
# The checkpoint stores the full model object and the fitted scaler, not just tensors, so it
# must be unpickled with weights_only=False (PyTorch >= 2.6 defaults to weights_only=True and
# would refuse it). SECURITY: unpickling can execute arbitrary code -- only load trusted files.
# map_location=device lets a checkpoint saved on a GPU be loaded on the CPU fallback above.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=False)
model = checkpoint['model']  # full model object
threshold_list = checkpoint['threshold']  # anomaly thresholds saved alongside the model
scaler = checkpoint['scaler']  # fitted scaler used to normalize the input data

### Load data

In [8]:
# Load the preprocessed signals (project CSV helper, ';'-separated).
data = read_csv(input_data_path, sep=";")

# Restrict to the selected channels; channels_list is None when every channel is used.
if channels_list is not None:
    data = data[channels_list]

# Keep rows whose index lies in [start_date, end_date) -- end_date is exclusive here.
date_mask = (data.index >= start_date) & (data.index < end_date)
data = data.loc[date_mask]

In [9]:
# Apply the checkpoint's scaler, keeping the original index and column labels.
normalized_values = scaler.transform(data)
df_normalized = pd.DataFrame(normalized_values, index=data.index, columns=data.columns)

### Predict

In [None]:
# Run the loaded model on the normalized data using the thresholds stored in the checkpoint.
# The result is presumably a time-indexed DataFrame of 0/1 flags per channel (the format that
# format_anomalies consumes) -- TODO confirm against the model's predict implementation.
anomalies = model.predict(threshold_list, df_normalized, WINDOW_SIZE, BATCH_SIZE, device)
anomalies.head()

### Format anomalies

In [None]:
import pandas as pd

def format_anomalies(anomalies: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-timestamp anomaly flags into one row per contiguous anomalous interval.

    Args:
        anomalies (pd.DataFrame): one column per channel, indexed by time. A value of 1 marks
            an anomalous timestamp and a value of 0 a normal one.

    Returns:
        pd.DataFrame: columns ``Channel``, ``StartTime`` and ``EndTime``, sorted by
            ``StartTime``. ``StartTime`` is the first flagged timestamp of the interval;
            ``EndTime`` is the first following timestamp whose value is back to 0, or the last
            index value when the interval runs to the end of the data. The columns are present
            even when no anomaly is found, so downstream column filters keep working.
    """
    formatted_data = []

    for channel in anomalies.columns:
        channel_data = anomalies[channel]
        is_active = False  # True while we are inside a run of 1s
        start_time = None  # timestamp where the current run started

        for time, value in channel_data.items():
            if value == 1 and not is_active:
                # A new anomalous interval starts here.
                is_active = True
                start_time = time
            elif value == 0 and is_active:
                # The current interval ends at the first timestamp back to normal.
                is_active = False
                formatted_data.append({"Channel": channel, "StartTime": start_time, "EndTime": time})

        # An interval still open at the end of the data is closed at the last timestamp.
        if is_active:
            formatted_data.append(
                {"Channel": channel, "StartTime": start_time, "EndTime": channel_data.index[-1]}
            )

    # Explicit columns: without them an empty result has no "StartTime" column and
    # sort_values (and any later ["Channel"] filter) raises KeyError.
    anomalies_formatted = pd.DataFrame(formatted_data, columns=["Channel", "StartTime", "EndTime"])
    return anomalies_formatted.sort_values(by="StartTime").reset_index(drop=True)

# Convert the per-timestamp flags into one row per (channel, anomalous interval) and display it.
df_anomalies_predicted = format_anomalies(anomalies)
df_anomalies_predicted

## Load data, ESA anomalies and predicted anomalies for the dashboard

In [15]:
# ---------------------------------------------------
# Auxiliary functions
# ---------------------------------------------------
def import_data(
    input_data_path: str,
    channels: List[str] = None,
    start_date: date = None,
    end_date: date = None,
) -> pd.DataFrame:
    """Import data from a ';'-separated csv file and filter it by channels and date.

    The ``time`` column is parsed as a datetime and becomes the index; every other column
    is read as Float32.

    Args:
        input_data_path (str): path to the data file
        channels (List[str], optional): channels to keep. Defaults to None (keep all).
        start_date (date, optional): keep rows with ``time >= start_date``.
            Defaults to None (no lower bound).
        end_date (date, optional): keep rows with ``time <= end_date`` (inclusive).
            Defaults to None (no upper bound).

    Returns:
        pd.DataFrame: dataframe with the data, indexed by ``time``
    """
    # Scan only a few rows to learn the column names in file order.
    header_columns = (
        pl.scan_csv(input_data_path, separator=";", n_rows=10).collect_schema().names()
    )

    # A complete schema given to scan_csv must follow the file's column order, so build it
    # from the header instead of forcing "time" to be first.
    schema = {
        column: pl.Datetime if column == "time" else pl.Float32 for column in header_columns
    }

    result = pl.scan_csv(input_data_path, separator=";", schema=schema)

    # Each bound is applied independently, so passing only one of them works.
    if start_date is not None:
        result = result.filter(pl.col("time") >= start_date)
    if end_date is not None:
        result = result.filter(pl.col("time") <= end_date)

    if channels is not None:
        result = result.select(["time"] + list(channels))

    return result.collect().to_pandas().set_index("time")

In [16]:
# ---------------------------------------------------
# Code: Read data
# ---------------------------------------------------

## Read preprocessed data
df_data = import_data(
    input_data_path=input_data_path,
    channels=channels_list,
    start_date=start_date,
    end_date=end_date,
)

## Import ESA ground-truth anomalies, keeping only anomalies and rare events
df_anomalies_esa = pd.read_csv(ESA_ANOMALIES_PATH)
df_anomalies_esa = df_anomalies_esa[df_anomalies_esa["Category"].isin(["Anomaly", "Rare Event"])]

## Restrict ESA and predicted anomalies to the selected channels (None keeps every channel)
if channels_list is not None:
    df_anomalies_esa = df_anomalies_esa[df_anomalies_esa["Channel"].isin(channels_list)]
    df_anomalies_predicted = df_anomalies_predicted[df_anomalies_predicted["Channel"].isin(channels_list)]

## Dashboard

In [None]:
# ---------------------------------------------------
# Code: Dashboard
# ---------------------------------------------------

# If no ESA or predicted anomalies are found in the selected date range, the callback shows an error banner and an empty figure.

# anomaly_category_colors
# Outline colour of the ESA anomaly rectangles, keyed by ESA category.
anomaly_category_colors = {
    "Rare Event": "blue",
    "Anomaly": "red",
    "Communication Gap": "green",
}

# JupyterDash serves the app from inside the notebook (a standalone script would use
# Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) instead).
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Hidden banner that the callback opens when there is nothing to plot.
error_banner = dbc.Alert(
    "No anomalies found in the filtered data",
    color="danger",
    id="error-banner",
    is_open=False,
)
# Date range shown when the dashboard first loads.
date_picker = dcc.DatePickerRange(
    id="date-picker-range",
    start_date=date(2000, 2, 1),
    end_date=date(2000, 2, 29),
    display_format="DD-MM-YYYY",
)
app.layout = html.Div([error_banner, date_picker, dcc.Graph(id="time-series-graph")])


def _overlapping_intervals(intervals: pd.DataFrame, start_date: str, end_date: str) -> pd.DataFrame:
    """Return the rows of ``intervals`` whose [StartTime, EndTime] overlaps [start_date, end_date]."""
    return intervals[(intervals["StartTime"] <= end_date) & (intervals["EndTime"] >= start_date)]


@app.callback(
    [Output("error-banner", "is_open"), Output("time-series-graph", "figure")],
    [Input("date-picker-range", "start_date"), Input("date-picker-range", "end_date")],
)
def update_graph(start_date: str, end_date: str) -> Tuple[bool, go.Figure]:
    """Redraw the per-channel plots for the date range selected in the picker.

    One subplot row is drawn per channel that has an ESA or predicted anomaly overlapping the
    range. ESA anomalies are yellow rectangles outlined in their category colour and labelled
    "<Category> <ID>"; predicted anomalies are blue rectangles outlined in red. Rectangles are
    clipped to the selected range.

    Args:
        start_date (str): start of the range as sent by the date picker.
        end_date (str): end of the range as sent by the date picker.

    Returns:
        Tuple[bool, go.Figure]: whether to open the "no anomalies" banner, and the figure.
    """
    # The picker sends None while a bound is cleared; there is nothing to filter on yet.
    if start_date is None or end_date is None:
        return False, go.Figure()

    filtered_df_data = df_data[(df_data.index >= start_date) & (df_data.index <= end_date)]

    # Keep every anomaly that overlaps the range, not only those fully inside it; the
    # rectangles below are clipped to the range with max/min.
    filtered_anomalies_esa = _overlapping_intervals(df_anomalies_esa, start_date, end_date)
    filtered_anomalies_predicted = _overlapping_intervals(df_anomalies_predicted, start_date, end_date)

    # Channels with at least one anomaly of either kind, ordered by their numeric suffix
    # (names are expected to look like "channel_<n>").
    channel_set = set(filtered_anomalies_esa["Channel"]) | set(filtered_anomalies_predicted["Channel"])
    channels = np.array(sorted(channel_set, key=lambda x: int(x.split('_')[1])))
    if len(channels) == 0:
        return True, go.Figure()

    # One subplot row per channel, all sharing the x axis.
    fig = FigureResampler(
        make_subplots(
            rows=len(channels),
            cols=1,
            shared_xaxes="columns",
            subplot_titles=channels,
            vertical_spacing=0.01,
        ),
        default_downsampler=MinMaxLTTB(parallel=True),
        create_overview=True,
    )

    for i, channel in enumerate(channels):
        row = i + 1
        channel_values = filtered_df_data[channel]
        y_min = channel_values.min()
        y_max = channel_values.max()

        fig.add_trace(
            go.Scattergl(
                name=channel,
                showlegend=True,
                x=filtered_df_data.index,
                y=channel_values,
            ),
            hf_x=filtered_df_data.index,
            hf_y=channel_values,
            row=row,
            col=1,
        )

        # ESA ground truth: rectangle outlined in the category colour plus a text label.
        esa_anomalies = filtered_anomalies_esa[filtered_anomalies_esa["Channel"] == channel]
        for _, anomaly_ in esa_anomalies.iterrows():
            fig.add_shape(
                type="rect",
                x0=max([start_date, anomaly_["StartTime"]]),
                y0=y_min,
                x1=min([end_date, anomaly_["EndTime"]]),
                y1=y_max,
                line=dict(color=anomaly_category_colors[anomaly_["Category"]], width=1),
                fillcolor="yellow",
                opacity=0.25,
                row=row,
                col=1,
            )
            fig.add_trace(
                go.Scattergl(
                    x=[anomaly_["StartTime"]],
                    # NOTE(review): min + min / 10 moves the label below the minimum when the
                    # minimum is negative; kept as in the original layout.
                    y=[y_min + y_min / 10],
                    mode="text",
                    marker=dict(size=8),
                    # f-string so non-string IDs do not raise on concatenation.
                    text=f"{anomaly_['Category']} {anomaly_['ID']}",
                    line=dict(color="black", width=1),
                    showlegend=False,
                ),
                row=row,
                col=1,
            )

        # Model predictions: blue rectangle outlined in red.
        predicted_anomalies = filtered_anomalies_predicted[filtered_anomalies_predicted["Channel"] == channel]
        for _, anomaly_ in predicted_anomalies.iterrows():
            fig.add_shape(
                type="rect",
                x0=max([start_date, str(anomaly_["StartTime"])]),
                y0=y_min,
                x1=min([end_date, str(anomaly_["EndTime"])]),
                y1=y_max,
                line=dict(color="red", width=1),
                fillcolor="blue",
                opacity=0.25,
                row=row,
                col=1,
            )

    # Left-align the subplot titles.
    for annotation in fig["layout"]["annotations"]:
        annotation["xanchor"] = "left"
        annotation["x"] = 0
    # 250 px per channel row, capped at 2000 px from 10 channels on.
    HEIGHT = 250 * len(channels) if len(channels) < 10 else 2000
    fig.update_layout(height=HEIGHT)
    return False, fig


# Launch the dashboard in debug mode; the trailing ';' suppresses the cell's output repr.
app.run_server(debug=True);