In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import mlflow
import os
import pathlib
import pandas as pd
import plotly.express as px

In [None]:
# Root directory of the data files, taken from the DATA_DIR environment variable.
# Reading through os.environ (instead of os.getenv) avoids handing None to
# pathlib.Path, which would fail with a TypeError that never names the variable.
try:
    DATA_DIR = pathlib.Path(os.environ['DATA_DIR'])
except KeyError:
    raise RuntimeError(
        "The DATA_DIR environment variable must be set before running this notebook."
    ) from None

In [None]:
# Compose the 'train' configuration from the packaged Hydra config module.
with hydra.initialize_config_module(config_module='smc01.postprocessing.conf'):
    cfg = hydra.compose(config_name='train')

In [None]:
# Display the logging section of the composed config; the MLflow tracking URI
# used to build the client is read from it.
cfg.logging

In [None]:
# MLflow client bound to the tracking server named in the composed config.
client = mlflow.tracking.MlflowClient(
    tracking_uri=cfg.logging.mlflow.tracking_uri,
)

In [None]:
def runs_to_df(runs, mlflow_client=None):
    """Summarise MLflow runs as a DataFrame with one row per run.

    Parameters
    ----------
    runs : iterable
        Run objects exposing ``info.run_id``, ``info.start_time`` and
        ``data.params``. The ``model/_target_`` and ``train_begin`` params
        are required; ``dataset/station_set_file`` is optional.
    mlflow_client : optional
        Object providing ``get_metric_history(run_id, key)``. When omitted,
        the notebook-level ``client`` is used.

    Returns
    -------
    pandas.DataFrame
        Columns ``model`` (last dotted component of ``model/_target_``),
        ``train_begin``, ``min_rmse`` (smallest logged ``Val/RMSE`` value, or
        missing when the run logged none), ``start_time`` (converted from
        milliseconds to datetimes) and ``station_subset`` (empty string when
        the param is absent). The columns are present even when ``runs`` is
        empty.

    Raises
    ------
    KeyError
        If a run lacks the ``model/_target_`` or ``train_begin`` param.
    """
    if mlflow_client is None:
        mlflow_client = client

    # Fixing the schema up front keeps the 'start_time' conversion below from
    # raising KeyError when no run matched the search.
    columns = ['model', 'train_begin', 'min_rmse', 'start_time', 'station_subset']

    rows = []
    for r in runs:
        rmse_history = mlflow_client.get_metric_history(r.info.run_id, 'Val/RMSE')
        params = r.data.params

        rows.append({
            'model': params['model/_target_'].split('.')[-1],
            'train_begin': params['train_begin'],
            # default=None covers runs that never logged Val/RMSE.
            'min_rmse': min((m.value for m in rmse_history), default=None),
            'start_time': r.info.start_time,
            'station_subset': params.get('dataset/station_set_file', ''),
        })

    df = pd.DataFrame(rows, columns=columns)
    df['start_time'] = pd.to_datetime(df['start_time'], unit='ms')

    return df

In [None]:
# Active runs in experiment "2" whose run name is 'raw_model_gdps_metar'.
# NOTE(review): raw_model_runs is assigned again by a later, narrower search
# before anything reads this result.
raw_model_runs = client.search_runs(
    experiment_ids="2",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
    filter_string="tags.`mlflow.runName`='raw_model_gdps_metar'",
)

In [None]:
# Active runs in experiment "2" whose run name is 'attention_gdps_metar_finetune'.
attention_runs = client.search_runs(
    experiment_ids="2",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
    filter_string="tags.`mlflow.runName`='attention_gdps_metar_finetune'",
)

In [None]:
# Active runs in experiment "2" whose run name is 'attention_gdps_metar'.
attention_runs_2 = client.search_runs(
    experiment_ids="2",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
    filter_string="tags.`mlflow.runName`='attention_gdps_metar'",
)

In [None]:
# Active runs in experiment "2" named 'emos_gdps_metar' whose n_features param is '1'.
# The two conditions are joined with an explicit AND, the only connective the
# MLflow search syntax defines between clauses.
mos_runs = client.search_runs(
    experiment_ids="2",
    filter_string="tags.`mlflow.runName`='emos_gdps_metar' AND params.n_features='1'",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
)

In [None]:
# One summary frame per group of post-processing runs, in a fixed order:
# fine-tuned attention, attention, then EMOS.
dfs = [
    runs_to_df(run_group)
    for run_group in (attention_runs, attention_runs_2, mos_runs)
]

In [None]:
# Active 'raw_model_gdps_metar' runs in experiment "2" whose station set file
# param points at the bootstrap set under DATA_DIR. The two conditions are
# joined with an explicit AND, the only connective the MLflow search syntax
# defines between clauses.
# NOTE(review): this filters on the param key `dataset.station_set_file`, while
# runs_to_df reads `dataset/station_set_file` — confirm which key these runs log.
raw_model_runs = client.search_runs(
    experiment_ids="2",
    filter_string=f"tags.`mlflow.runName`='raw_model_gdps_metar' AND params.`dataset.station_set_file`='{DATA_DIR!s}/bootstrap_set.csv'",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
)

In [None]:
# The baseline is the 'Val/RMSE' metric of the first matching raw-model run.
# Check for an empty search result first so the failure explains itself instead
# of surfacing as a bare "list index out of range".
if not raw_model_runs:
    raise IndexError(
        "No 'raw_model_gdps_metar' run matched the search filter; "
        "cannot read the baseline Val/RMSE."
    )
raw_model_metric = raw_model_runs[0].data.metrics['Val/RMSE']

In [None]:
# Display the baseline Val/RMSE value read from the raw-model run.
raw_model_metric

In [None]:
# Stack the per-group summaries row-wise into a single frame; each frame's own
# row labels are carried over unchanged.
df = pd.concat(dfs, axis=0)

In [None]:
# Display the combined run summary before any filtering is applied.
df

In [None]:
# Keep only runs on the bootstrap station set that started after 2022-05-01.
df = df.loc[
    (df['station_subset'] == str(DATA_DIR / 'bootstrap_set.csv'))
    & (df['start_time'] > '2022-05-01')
]

In [None]:
# Scatter of each run's minimum validation RMSE against its train_begin value,
# coloured by model.
fig = px.scatter(
    data_frame=df,
    x='train_begin',
    y='min_rmse',
    color='model',
    title='Validation loss for post-processing models on GDPS data',
    labels={'min_rmse': 'RMSE on validation set (°C)'},
)
# Horizontal reference line at the raw-model RMSE. As the last expression in
# the cell, this call's return value is what gets displayed.
fig.add_hline(y=raw_model_metric, annotation_text='Raw GDPS')