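# Compute metrics for different window lengths

Evaluate the benchmark methods on each graph over a range of input window lengths and cache the summary statistics (RMSE, MAE, NSE) as CSV files in `outputs/`.

**See file [visual_win_lens.ipynb](visual_win_lens.ipynb) for interactive plots.**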

In [22]:
from pathlib import Path

# Absolute path of the kernel's current working directory. Every other path below
# is derived from it, so this assumes the notebook is run from its own folder.
CUR_ABS_DIR = Path.cwd().resolve()
# Project root, three directory levels above the working directory
# (assumes the notebook lives three levels below the repository root — TODO confirm).
PROJ_DIR = (CUR_ABS_DIR / '../../../').resolve()
# Passed as `output_dir` to `process_method`; presumably holds the Ray Tune run results.
OUTPUT_DIR = (PROJ_DIR / 'swissrivernetwork/benchmark/outputs/ray_results/').resolve()
# Where the per-(graph, method) window-length result CSVs are read from and written to.
DUMP_DIR = (CUR_ABS_DIR / 'outputs/').resolve()

# NOTE(review): not referenced in the visible cells (results are always reloaded from
# DUMP_DIR when present) — confirm whether this flag is still needed.
LOAD_LATEST_RESULTS = True

# ISSUE_TAG = "\033[91m[issue]\033[0m "  # Red
# INFO_TAG = "[info] "  # Blue
# SUCCESS_TAG = "\033[92m[success]\033[0m "  # Green
#
# from rich.console import Console as rConsole
# from rich.text import Text as rText
#
# console = rConsole()
#
#
# def rprint(text=None):
#     if text is None:
#         console.print()
#         return
#     text = str(text)
#     if text.startswith(ISSUE_TAG):
#         t1, t2 = rText("[issue]", style="red bold"), text[len(ISSUE_TAG):]
#     elif text.startswith(INFO_TAG):
#         t1, t2 = rText("[info]", style="blue bold"), rText(text[len(INFO_TAG):], style="none")
#     elif text.startswith(SUCCESS_TAG):
#         t1, t2 = rText("[success]", style="green bold"), rText(text[len(SUCCESS_TAG):], style="none")
#     else:
#         t1, t2 = rText("", style="none"), rText(text, style="none")
#     console.print(t1, t2, end='')
#
#
# rprint(f'{INFO_TAG} Project directory: {PROJ_DIR}')
# rprint(f'{INFO_TAG} Output directory: {OUTPUT_DIR}')
# rprint(f'{INFO_TAG} Current directory: {CUR_ABS_DIR}')

# HTML-formatted message prefixes for the notebook's HTML-rendering `print` override
# defined below; each renders as a coloured label and ends with a separating space.
ISSUE_TAG = '<span style="color:red;">[issue]</span> '
INFO_TAG = '<span style="color:blue;">[info]</span> '
SUCCESS_TAG = '<span style="color:green;">[success]</span> '


def print(text=None):
    """Display ``text`` as raw HTML in the notebook output.

    NOTE: this deliberately shadows the builtin ``print`` for the rest of the
    notebook so that the HTML tag prefixes render as coloured labels. Only a
    single positional value is supported.

    Args:
        text: Value to show. It is converted with ``str`` and rendered unescaped
            as HTML. When omitted (``None``), an HTML line break is shown instead,
            mirroring the empty line produced by the builtin ``print()``.
    """
    from IPython.display import HTML, display

    markup = "<br>" if text is None else str(text)
    display(HTML(markup))


# Report the resolved directories so that a wrong working directory is easy to spot.
print(f'{INFO_TAG} Project directory: {PROJ_DIR}')
print(f'{INFO_TAG} Current directory: {CUR_ABS_DIR}')
print(f'{INFO_TAG} Output directory: {OUTPUT_DIR}')
print(f'{INFO_TAG} Dump directory: {DUMP_DIR}')

In [23]:
import numpy as np


def curate_window_lens(window_lens: np.ndarray, data_name: str, mode: str = 'subsequences', max_days: float = np.inf) -> np.ndarray:
    """
    Curate window lengths based on dataset characteristics.

    The effective cap is the minimum of ``max_days`` and a dataset-specific limit.
    Candidate lengths at or above the cap are dropped and the cap itself is always
    added, so the result spans the full admissible range.

    Args:
        window_lens (array-like of int): Candidate window lengths
            (presumably in days, given ``max_days``).
        data_name (str): Name of the dataset ('swiss-1990', 'swiss-2010', 'zurich').
        mode (str): Either 'full' or 'subsequences'; selects which dataset-specific
            limit applies (only 'swiss-1990' uses different limits per mode).
        max_days (float): Additional upper bound on the window length.
            Defaults to no extra bound.

    Returns:
        np.ndarray: Sorted, de-duplicated window lengths strictly below the cap,
        followed by the cap itself.

    Raises:
        ValueError: If ``data_name`` or ``mode`` is not recognised.
    """
    # Accept plain lists as well as arrays; boolean-mask indexing below needs an array.
    window_lens = np.asarray(window_lens)

    # Dataset-specific upper limits on the window length, per mode.
    limits = {
        'swiss-1990': {'full': 2188, 'subsequences': 853},
        'swiss-2010': {'full': 1096, 'subsequences': 1096},
        'zurich': {'full': 1035, 'subsequences': 1035},
    }
    if data_name not in limits:
        raise ValueError(f'Unknown data name: {data_name}')
    if mode not in limits[data_name]:
        # Reject explicitly: otherwise max_days would stay uncapped (np.inf by
        # default) and inf would be appended to the returned window lengths.
        raise ValueError(f'Unknown mode: {mode}')

    max_days = min(limits[data_name][mode], max_days)
    kept = window_lens[window_lens < max_days]
    # np.unique already returns its result sorted, so no separate sort is needed.
    return np.unique(np.concatenate((kept, [max_days])))

In [24]:
import pandas as pd
import os


def _results_csv_path(output_dir, graph_name: str, method: str) -> Path:
    """Return the CSV path holding window-length results for one (graph, method) pair."""
    return Path(output_dir) / f'{graph_name}_{method}_window_lens_results.csv'


def read_results_from_csv(output_dir: Path, graph_name: str, method: str) -> pd.DataFrame:
    """
    Load previously saved window-length results.

    Args:
        output_dir: Directory containing the results CSV (``Path`` or ``str``).
        graph_name: Graph/dataset name used in the file name.
        method: Method name used in the file name.

    Returns:
        The stored results, or an empty DataFrame when the file does not exist
        or has no columns to parse.
    """
    file_path = _results_csv_path(output_dir, graph_name, method)
    try:
        return pd.read_csv(file_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # A missing or empty file means no window length has been processed yet.
        return pd.DataFrame()


def save_results_to_csv(df: pd.DataFrame, output_dir: Path, graph_name: str, method: str) -> str:
    """
    Write window-length results to CSV (without the index), creating the directory if needed.

    Args:
        df: Results to write.
        output_dir: Target directory (``Path`` or ``str``); created with parents if missing.
        graph_name: Graph/dataset name used in the file name.
        method: Method name used in the file name.

    Returns:
        The path of the written file, as a string.
    """
    file_path = _results_csv_path(output_dir, graph_name, method)
    file_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(file_path, index=False)
    return str(file_path)

In [25]:
# Candidate window lengths, from short windows up to ten years. The order is not
# monotonic (365 sits between 360 and 390); consumers sort as needed.
WINDOW_LENS = np.concatenate([
    np.array([1, 3, 5, 7, 15]),  # short windows
    np.arange(1, 13) * 30,       # 30-step increments: 30 .. 360
    np.array([365]),             # one year
    np.arange(13, 25) * 30,      # 30-step increments: 390 .. 720
    np.arange(2, 11) * 365,      # whole multiples of 365: 730 .. 3650
])

# Graphs (datasets) to evaluate.
GRAPH_NAMES = [
    'swiss-1990',
    'swiss-2010',
    'zurich',
]
# Method identifiers passed to `process_method`.
METHODS = [
    'lstm_embedding',
    'transformer_embedding',
    'lstm',
    'graphlet',
    'stgnn',
]

In [30]:
from swissrivernetwork.benchmark.ray_evaluation import process_method

def _summary_row(df_data: pd.DataFrame, window_len) -> dict:
    """Flatten the summary statistics of one evaluation run into a single result row.

    For each metric column of ``df_data`` ('RMSE', 'MAE', 'NSE') and each statistic
    ('Mean', 'Std', 'Median', 'Min', 'Max'), the value is taken from the first row
    whose 'Station' equals the statistic name. It is stored under
    '<metric>_<statistic>'. Missing metrics or statistics become NaN.
    """
    row = {'window_len': window_len}
    for metric in ['RMSE', 'MAE', 'NSE']:
        for stat_measure in ['Mean', 'Std', 'Median', 'Min', 'Max']:
            col_name = f'{metric}_{stat_measure}'
            if metric not in df_data.columns:
                row[col_name] = np.nan
                continue
            value = df_data[df_data['Station'] == stat_measure][metric]
            row[col_name] = value.values[0] if not value.empty else np.nan
    return row


def run_for_graph(graph_name: str, methods=None):
    """
    Evaluate methods on one graph for every curated window length, caching results.

    For each method, window lengths already present in the per-(graph, method) CSV
    under DUMP_DIR are skipped. Each newly processed window length is appended and
    the CSV is rewritten immediately, so an interrupted run can be resumed.

    Args:
        graph_name: Graph/dataset name, e.g. one of GRAPH_NAMES.
        methods: Method identifiers to evaluate. Defaults to ``METHODS[0:2]``
            ('lstm_embedding' and 'transformer_embedding').
    """
    if methods is None:
        methods = METHODS[0:2]

    for m in methods:
        print()
        print(f'{INFO_TAG} #### Starting processing for method="{m}" on graph="{graph_name}"...')

        # NOTE(review): transformer runs are capped at 500-day windows, presumably
        # for memory/runtime reasons — confirm.
        max_days = 500 if m == 'transformer_embedding' else np.inf
        window_lens = curate_window_lens(WINDOW_LENS, graph_name, mode='subsequences', max_days=max_days)
        print(f'{INFO_TAG} Curated window lengths for graph="{graph_name}", method="{m}": {window_lens}')

        # Previously computed results, if any (empty DataFrame otherwise).
        df_res = read_results_from_csv(DUMP_DIR, graph_name, m)

        for i_len, window_len in enumerate(window_lens):
            print(f"{INFO_TAG} ==== Processing window_len={window_len} [{i_len + 1}/{len(window_lens)}] ====")

            if not df_res.empty and window_len in df_res['window_len'].values:
                print(f"{SUCCESS_TAG} Results for window_len={window_len} already exist. Skipping...")
                continue

            df_data = process_method(
                graph_name, m, output_dir=OUTPUT_DIR, settings={'window_len': window_len, 'verbose': 1, 'env': 'notebook'}
            )
            new_row = _summary_row(df_data, window_len)
            df_res = pd.concat([df_res, pd.DataFrame([new_row])], ignore_index=True)

            # Checkpoint after every window length so completed work is never lost.
            file_path = save_results_to_csv(df_res, DUMP_DIR, graph_name, m)

            print(f'{SUCCESS_TAG} Completed processing for window_len={window_len}. Results saved to {file_path}.')

        print(f'{SUCCESS_TAG} #### Finished processing for method="{m}" on graph="{graph_name}"!')

In [31]:
# Compute results for every graph and save them to CSV (one file per graph/method).
for graph_name in GRAPH_NAMES:
    print(f'{INFO_TAG} ================= Starting processing for graph="{graph_name}" =================')
    run_for_graph(graph_name)