In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import hydra
import mlflow
import pandas as pd
import plotly.express as px

In [3]:
# Compose the `train` configuration from the `smc01.postprocessing.conf`
# config module; `cfg` is read below to locate the MLflow tracking server.
with hydra.initialize_config_module(config_module='smc01.postprocessing.conf'):
    cfg = hydra.compose(config_name='train')

In [4]:
# MLflow client bound to the tracking URI named in the composed config.
client = mlflow.tracking.client.MlflowClient(
    tracking_uri=cfg.logging.mlflow.tracking_uri,
)

In [5]:
def _first_param(params, *keys, default=''):
    """Return the value of the first of `keys` present in `params`, else `default`."""
    for key in keys:
        if key in params:
            return params[key]
    return default


def runs_to_df(runs, mlflow_client=None):
    """Summarise MLflow runs as a DataFrame with one row per run.

    Parameters
    ----------
    runs : iterable
        Run objects exposing `info.run_id`, `info.start_time`, `info.end_time`,
        `data.params`, `data.tags` and `data.metrics` (e.g. the result of
        `MlflowClient.search_runs`).
    mlflow_client : optional
        Object providing `get_metric_history(run_id, key)`. When omitted, the
        notebook-level `client` is used.

    Returns
    -------
    pandas.DataFrame
        Columns: model, train_begin, val_begin, min_rmse, start_time, end_time,
        station_subset, run_name, freeze_upper, split_name, test_rmse.
        `start_time` / `end_time` are converted from epoch milliseconds to
        datetimes. An empty `runs` yields an empty frame with these columns.

    Raises
    ------
    KeyError
        If a run has no `mlflow.runName` tag.
    """
    tracker = client if mlflow_client is None else mlflow_client

    columns = [
        'model', 'train_begin', 'val_begin', 'min_rmse', 'start_time',
        'end_time', 'station_subset', 'run_name', 'freeze_upper',
        'split_name', 'test_rmse',
    ]

    rows = []
    for r in runs:
        params = r.data.params

        # Lowest validation RMSE logged over the run; None when none was logged.
        rmse_history = tracker.get_metric_history(r.info.run_id, 'Val/RMSE')
        min_rmse = min((metric.value for metric in rmse_history), default=None)

        # Parameter names appear with either '.' or '/' as the separator,
        # so both spellings are tried for each nested parameter.
        rows.append({
            # Keep only the class name of the dotted target path.
            'model': _first_param(params, 'model._target_', 'model/_target_').split('.')[-1],
            'train_begin': _first_param(params, 'split.train_begin', 'split/train_begin'),
            'val_begin': _first_param(params, 'split.val_begin', 'split/val_begin'),
            'min_rmse': min_rmse,
            'start_time': r.info.start_time,
            'end_time': r.info.end_time,
            'station_subset': _first_param(params, 'dataset/station_set_file', 'dataset.station_set_file'),
            'run_name': r.data.tags['mlflow.runName'],
            # Parameters are compared as strings, so only the literal 'True' counts.
            # NOTE(review): this column is False even for the runs labelled as
            # "freeze" in this notebook; the flag may be logged under another
            # key -- confirm.
            'freeze_upper': params.get('freeze_upper') == 'True',
            'split_name': _first_param(params, 'split.name', 'split/name'),
            'test_rmse': r.data.metrics.get('Test/RMSE'),
        })

    # Passing `columns` keeps the schema stable even when `runs` is empty, so
    # the datetime conversion below never hits a missing column.
    df = pd.DataFrame(rows, columns=columns)
    for time_col in ('start_time', 'end_time'):
        df[time_col] = pd.to_datetime(df[time_col], unit='ms')

    return df

In [6]:
# Active runs of experiment 12 whose run name matches the EMOS "progressive"
# single-feature pattern. `experiment_ids` is passed as a list, which is the
# documented form, rather than relying on a bare string being accepted.
mos_runs = client.search_runs(
    experiment_ids=["12"],
    filter_string="tags.mlflow.runName LIKE 'emos_gdps_metar_progressive_wval_%_1feature'",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
)

In [7]:
# Active runs of experiment 12 whose run name matches the attention-model
# "finetune progressive" pattern. `experiment_ids` is passed as a list, which
# is the documented form, rather than relying on a bare string being accepted.
attention_runs = client.search_runs(
    experiment_ids=["12"],
    filter_string="tags.mlflow.runName LIKE 'attention_gdps_metar_finetune_progressive_%'",
    run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
)

In [8]:
# Tabulate the MOS runs. The number of training days is the second-to-last
# '_'-separated token of the run name; every row is labelled 'MOS'.
mos_df = runs_to_df(mos_runs).assign(
    n_days=lambda frame: frame['run_name'].str.split('_').str[-2].astype(int),
    model='MOS',
)

In [9]:
# Tabulate the attention runs. The number of training days is the sixth
# '_'-separated token of the run name; runs whose name contains 'freeze'
# get their own model label.
attention_df = runs_to_df(attention_runs).assign(
    n_days=lambda frame: frame['run_name'].str.split('_').str[5].astype(int),
    model=lambda frame: [
        'Attention (Freeze)' if has_freeze else 'Attention'
        for has_freeze in frame['run_name'].str.contains('freeze')
    ],
)

In [10]:
# Inspect the attention runs with their model label and parsed n_days.
attention_df

Unnamed: 0,model,train_begin,val_begin,min_rmse,start_time,end_time,station_subset,run_name,freeze_upper,split_name,test_rmse,n_days
0,Attention,2019-01-01,2020-07-01,,2022-05-29 17:58:32.642,2022-05-29 18:06:47.244,,attention_gdps_metar_finetune_progressive_700,False,train_6_7_val_7,3.719376,700
1,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 16:15:38.889,2022-05-29 16:23:36.237,,attention_gdps_metar_finetune_progressive_30_f...,False,train_6_7_val_7,3.974317,30
2,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 16:07:18.658,2022-05-29 16:15:17.889,,attention_gdps_metar_finetune_progressive_30_f...,False,train_6_7_val_7,3.973757,30
3,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 15:58:59.648,2022-05-29 16:06:56.970,,attention_gdps_metar_finetune_progressive_30_f...,False,train_6_7_val_7,3.973626,30
4,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 15:27:48.009,2022-05-29 15:36:11.866,,attention_gdps_metar_finetune_progressive_180_...,False,train_6_7_val_7,3.888072,180
5,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 15:14:42.177,2022-05-29 15:22:48.811,,attention_gdps_metar_finetune_progressive_90_f...,False,train_6_7_val_7,3.836269,90
6,Attention,2019-01-01,2020-07-01,,2022-05-29 15:01:47.975,2022-05-29 15:09:42.191,,attention_gdps_metar_finetune_progressive_700,False,train_6_7_val_7,3.796904,700
7,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 14:58:44.849,2022-05-29 15:06:32.586,,attention_gdps_metar_finetune_progressive_700_...,False,train_6_7_val_7,3.635069,700
8,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 14:42:55.311,2022-05-29 14:50:57.099,,attention_gdps_metar_finetune_progressive_365_...,False,train_6_7_val_7,3.661539,365
9,Attention (Freeze),2019-01-01,2020-07-01,,2022-05-29 14:42:51.113,2022-05-29 14:51:29.943,,attention_gdps_metar_finetune_progressive_270_...,False,train_6_7_val_7,3.739127,270


In [11]:
# Stack the attention and MOS run tables row-wise into a single frame.
df = pd.concat(objs=[attention_df, mos_df], axis=0)

In [22]:
# Mean of the numeric columns per (model, n_days), excluding rows labelled
# exactly 'Attention' (runs whose name does not contain 'freeze').
# numeric_only=True is required: the frame also holds string and datetime
# columns, and pandas >= 2.0 no longer drops those silently when averaging.
fig_data = (
    df[df['model'] != 'Attention']
    .groupby(by=['model', 'n_days'])
    .mean(numeric_only=True)
    .reset_index()
    .sort_values('n_days')
)


fig = px.line(
    data_frame=fig_data,
    x='n_days',
    y='test_rmse', 
    color='model', 
    markers=True, 
    labels={
        'n_days': 'Length of training+validation set (days)',
        'test_rmse': 'RMSE on test set (°C)', 'model': 'Model'
    }, 
    width=600)
# Horizontal reference line labelled 'Raw GDPS'.
# NOTE(review): 4.25 is hard-coded; presumably the RMSE of the uncorrected
# forecast on the same test set -- confirm where this value comes from.
fig.add_hline(y=4.25, annotation_text='Raw GDPS', line_dash='dash')


# Put x-axis ticks at fixed training-set lengths instead of automatic ticks.
fig.update_layout(
    xaxis = dict(
        tickmode = 'array',
        tickvals = [30, 60, 90, 180, 270, 365, 700],
    )
)

# Export a static PNG; the size given here applies to the exported image.
fig.write_image('smc01_transfer.png', width=800, height=400)

In [23]:
# Show the aggregated table that was passed to the line plot.
fig_data

Unnamed: 0,model,n_days,freeze_upper,test_rmse
0,Attention (Freeze),30,False,3.9739
7,MOS,30,False,4.188961
1,Attention (Freeze),60,False,4.027578
8,MOS,60,False,4.179546
2,Attention (Freeze),90,False,3.857738
9,MOS,90,False,4.111689
3,Attention (Freeze),180,False,3.883092
10,MOS,180,False,3.996546
4,Attention (Freeze),270,False,3.739127
11,MOS,270,False,3.844854
