In [None]:
import base64
import glob
import io
import pickle
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Markdown, HTML

# Show every column when a DataFrame is displayed (no truncation with "...").
pd.set_option('display.max_columns', None)

In [None]:
# How repeated iterations of one configuration are collapsed ('min' or 'median').
aggregation = 'median'
# Error metric; selects the `sw_<metric>_test` / `lw_<metric>_test` columns.
metric = 'mae'

# Results live in `<parent of the current working directory>/results`.
root_dir = Path.cwd().parent
results_dir = root_dir.joinpath('results')

def get_job_stats_dir(job_name, variant):
    """Return the directory `results_dir/job_stats_<job_name>_<variant>`."""
    dir_name = f'job_stats_{job_name}_{variant}'
    return results_dir.joinpath(dir_name)

def get_notebooks_dir(job_name, variant):
    """Return the directory `results_dir/notebooks/<job_name>_<variant>`."""
    sub_dir = f'{job_name}_{variant}'
    return results_dir.joinpath('notebooks', sub_dir)

def load_stats(job_name, variant):
    """Load and concatenate all pickled stats frames of one job variant.

    Every `*.pkl` file in `get_job_stats_dir(job_name, variant)` is unpickled
    (in sorted filename order) and the results are concatenated with
    `pd.concat`. Three derived columns are then added:

    - `var_synth_count`: number of comma-separated entries in `var_synth`
    - `var_ml_count`: number of comma-separated entries in `var_ml`
    - `synth_enabled`: 'Yes' if `synth_mul_factor` is truthy, else 'No'

    Raises:
        AssertionError: if the directory contains no `*.pkl` file.
    """
    pkl_dir = get_job_stats_dir(job_name, variant)
    pkls = []
    for pkl_path in sorted(pkl_dir.glob('*.pkl')):
        # Close each file deterministically instead of leaking the handle.
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as fh:
            pkls.append(pickle.load(fh))
    assert len(pkls) > 0, f'No stats found in {pkl_dir}!'
    stats = pd.concat(pkls)

    def count_entries(csv_text):
        # "a,b,c" -> 3 (number of commas plus one)
        return csv_text.count(',') + 1

    # Element-wise maps on a single column; no need for row-wise DataFrame.apply.
    stats['var_synth_count'] = stats['var_synth'].map(count_entries)
    stats['var_ml_count'] = stats['var_ml'].map(count_entries)
    stats['synth_enabled'] = stats['synth_mul_factor'].map(lambda factor: 'Yes' if factor else 'No')
    return stats

def select_best_across_iterations(df_stats, column: str, method: str):
    """Pick one representative row per `name` group.

    Each group of rows sharing the same `name` is sorted ascending by
    `column`; then a single row is kept:

    - `method='min'`: the row with the smallest `column` value
    - `method='median'`: the row at position `len(group) // 2` of the sorted
      group (the upper middle row for even-sized groups)

    Returns:
        A DataFrame with one row per `name` (groups in sorted `name` order),
        a fresh 0..n-1 index, and `name` as the first column followed by the
        remaining columns in their original order.

    Raises:
        ValueError: if `method` is neither 'min' nor 'median'.
    """
    if method not in ('min', 'median'):
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(f"Unsupported method {method!r}; expected 'min' or 'median'")

    ordered_columns = ['name'] + [c for c in df_stats.columns if c != 'name']

    # Iterate over the groups explicitly instead of groupby.apply, so the
    # result does not depend on how apply treats the grouping column or on
    # the auto-generated 'level_1' index column name.
    picked = []
    for _, group in df_stats.groupby('name'):
        group_sorted = group.sort_values(by=[column])
        position = 0 if method == 'min' else len(group_sorted) // 2
        picked.append(group_sorted.iloc[[position]])

    if not picked:
        # No groups at all (empty input): return an empty frame of the same layout.
        return df_stats.iloc[0:0][ordered_columns].reset_index(drop=True)

    df_best = pd.concat(picked).reset_index(drop=True)
    return df_best[ordered_columns]
           
def print_tables(stats):
    """Display four tables of `stats`, each preceded by a markdown heading.

    First the unaggregated rows sorted by the longwave and then the shortwave
    test metric; afterwards one row per `name` (chosen by
    `select_best_across_iterations` using the global `aggregation`), again
    sorted by the longwave and then the shortwave test metric. The metric
    columns are `lw_<metric>_test` and `sw_<metric>_test` for the global
    `metric`.
    """
    bands = ('lw', 'sw')
    sort_keys = {band: f'{band}_{metric}_test' for band in bands}

    for band in bands:
        key = sort_keys[band]
        display(Markdown(f'## Table: unaggregated sorted by `{key}`'))
        display(stats.sort_values(by=[key]))

    # Summary statistics aggregated over all iterations per configuration.
    # Both aggregates are computed before anything is displayed.
    aggregated = {
        band: select_best_across_iterations(stats, sort_keys[band], aggregation)
        for band in bands
    }

    for band in bands:
        key = sort_keys[band]
        display(Markdown(f'## Table: aggregated by `name` and sorted by `{key}`'))
        display(aggregated[band].sort_values(by=[key]))
    
def draw_heatmap(stats, x, y, x_label=None, y_label=None):
    """Draw side-by-side shortwave / longwave heatmaps of the test metric.

    `stats` is grouped by the columns `x` and `y`; the `sw_<metric>_test` and
    `lw_<metric>_test` columns are reduced with the global `aggregation` and
    each is pivoted into a grid with `x` on the horizontal and `y` on the
    vertical axis. The figure is rendered through `plt_show_svg`.

    Args:
        stats: DataFrame containing `x`, `y` and both metric columns.
        x: column name used for the heatmap columns.
        y: column name used for the heatmap rows.
        x_label: x-axis label; defaults to `x`.
        y_label: y-axis label (left panel only); defaults to `y`.
    """
    x_label = x_label or x
    y_label = y_label or y
    cbar_label = 'Flux'

    display(Markdown(f'## Heatmap: aggregated by `{x}` & `{y}`'))

    panels = (
        (f'sw_{metric}_test', f'Shortwave {metric.upper()}'),
        (f'lw_{metric}_test', f'Longwave {metric.upper()}'),
    )
    metric_columns = [column for column, _ in panels]

    # Aggregate only the two plotted columns: reducing the whole frame would
    # also hit string columns, which newer pandas refuses to take a median of.
    agg_stats = stats.groupby([x, y])[metric_columns].agg(aggregation)
    agg_stats = agg_stats.reset_index()

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

    for j, (column, title) in enumerate(panels):
        ax = axs[j]
        # pivot() arguments are keyword-only in pandas >= 2.0.
        grid = agg_stats.pivot(index=y, columns=x, values=column)
        sns.heatmap(grid, ax=ax)
        ax.set_title(title)
        ax.invert_yaxis()  # smallest `y` value at the bottom
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label if j == 0 else '')
        if j == 1:
            # Only the right-hand colorbar gets a label.
            ax.collections[0].colorbar.set_label(cbar_label)
    plt_show_svg(fig)

# Human-readable x-axis labels for the boxplots, keyed by stats column name.
# Columns without an entry are labelled with their raw column name.
x_labels = dict(
    hidden_size='Hidden size',
    synth_enabled='Data augmentation',
)
    
def draw_boxplots(stats, ylim=None):
    """Draw shortwave / longwave boxplots of the test metric per hyperparameter.

    For every known hyperparameter column that is present in `stats` and has
    at least two distinct (non-null) values, one figure with two panels is
    rendered through `plt_show_svg`: `sw_<metric>_test` on the left and
    `lw_<metric>_test` on the right, grouped by that hyperparameter.

    Args:
        stats: DataFrame with the metric columns and hyperparameter columns.
        ylim: forwarded unchanged as `ylim` to `Axes.set` for both panels.
    """
    display(Markdown(f'## Boxplots'))
    var_names = ['hidden_size', 'n_hidden_layers', 'synth_mul_factor', 'synth_enabled', 'unif_ratio', 'stretch_factor', 'loss', 'activation', 'l1_penalty', 'l2_penalty', 'var_regularizer_factor', 'dropout_ratio', 'var_synth_count', 'var_ml_count']
    # NOTE: the y-axis text is MAE-specific even though `metric` is configurable.
    panels = (('sw', 'Shortwave'), ('lw', 'Longwave'))
    for var_name in var_names:
        if var_name not in stats:
            continue
        # Nothing to compare unless the variable takes at least two values.
        # Counting distinct values directly avoids aggregating the whole
        # frame (which also fails on string columns in newer pandas).
        if stats[var_name].nunique() <= 1:
            continue
        fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
        for ax, (band, band_label) in zip(axs, panels):
            plot = sns.boxplot(x=var_name, y=f"{band}_{metric}_test", data=stats, ax=ax)
            plot.set(
                xlabel=x_labels.get(var_name, var_name),
                ylabel=f"{band_label} mean absolute error in W m⁻²",
                ylim=ylim
            )
        # Pass the figure explicitly rather than relying on the current figure.
        plt_show_svg(fig)

def plt_show_svg(fig=None):
    """Display a figure as an inline, base64-encoded SVG `<img>` element.

    Args:
        fig: a figure exposing `to_image` (plotly) is exported via
            `fig.to_image(format="svg")`; any other figure is saved with
            `savefig` and then closed. If None, the current matplotlib
            figure (`plt.gcf()`) is used.
    """
    from IPython.display import HTML

    is_plotly = fig is not None and hasattr(fig, 'to_image')
    if is_plotly:
        svg = fig.to_image(format="svg")
    else:
        target = plt.gcf() if fig is None else fig
        buffer = io.BytesIO()
        target.savefig(buffer, format='svg', bbox_inches='tight')
        # Close the figure after it has been written to the buffer.
        plt.close(target)
        svg = buffer.getvalue()

    encoded = base64.b64encode(svg).decode()
    svg_url = 'data:image/svg+xml;base64,' + encoded
    display(HTML(f'<img src="{svg_url}"></img>'))

# Job: main_*

In [None]:
# Stats of the 'mlp' job, 'default' variant; `stats` is reused by the next cell.
stats = load_stats('mlp', 'default')

# Tables first, then the hidden-size x depth heatmap, then per-variable boxplots.
print_tables(stats)

draw_heatmap(
    stats,
    x='hidden_size',
    y='n_hidden_layers',
    x_label='Hidden size',
    y_label='Hidden layers',
)

draw_boxplots(stats)

In [None]:
import plotly.express as px

def with_noise(df, scales=None, rng=None):
    """Return a copy of `df` with Gaussian jitter added to selected columns.

    The input frame is not modified. Zero-mean normal noise is added to each
    column listed in `scales`, using that column's value as the standard
    deviation.

    Args:
        df: DataFrame containing every column named in `scales`.
        scales: mapping of column name -> noise standard deviation. Defaults
            to the five hyperparameter columns and scales used by this
            notebook.
        rng: optional `numpy.random.Generator` for reproducible jitter. When
            None, the global `np.random` state is used.

    Raises:
        KeyError: if a column named in `scales` is missing from `df`.
    """
    if scales is None:
        scales = {
            'var_ml_count': 0.1,
            'n_hidden_layers': 0.1,
            'hidden_size': 0.05,
            'l1_penalty': 0.000005,
            'l2_penalty': 0.000005,
        }
    draw_normal = np.random.normal if rng is None else rng.normal

    stats_with_noise = df.copy()
    n_rows = len(stats_with_noise)
    for column, sigma in scales.items():
        stats_with_noise[column] = stats_with_noise[column] + draw_normal(0, sigma, n_rows)
    return stats_with_noise

# Display labels for the parallel-coordinates plots, keyed by stats column.
columns = dict(
    var_ml_count='Input quantities',
    n_hidden_layers='Hidden layers',
    hidden_size='Hidden size',
    l1_penalty='L1',
    l2_penalty='L2',
    sw_mae_test='Shortwave mean absolute error in W m⁻²',
    lw_mae_test='Longwave mean absolute error in W m⁻²',
)
# Plot dimensions: every label except the two error metrics.
column_names = [
    label
    for key, label in columns.items()
    if key not in ('sw_mae_test', 'lw_mae_test')
]

def rename_cols(df):
    """Return a copy of `df` whose columns are renamed via the `columns` mapping."""
    renamed = df.rename(mapper=columns, axis='columns')
    return renamed

# Keyword arguments shared by every px.parallel_coordinates call below.
px_kw = {
    'width': 1000,
    'height': 500,
    'color_continuous_scale': px.colors.diverging.Tealrose,
    'dimensions': column_names,
}

def _show_parallel_coordinates(stats_frame, band, heading, aggregate):
    """Show one jittered parallel-coordinates plot and return the jittered frame.

    Args:
        stats_frame: the full stats DataFrame.
        band: 'sw' or 'lw'; selects the `<band>_<metric>_test` colour column.
        heading: markdown heading displayed above the plot.
        aggregate: if True, reduce `stats_frame` to one row per `name` with
            `select_best_across_iterations` (global `aggregation`) first.
    """
    display(Markdown(heading))
    metric_column = f'{band}_{metric}_test'
    if aggregate:
        frame = select_best_across_iterations(stats_frame, metric_column, aggregation)
    else:
        frame = stats_frame
    # Jitter the plotted columns with noise before plotting.
    jittered = with_noise(frame)
    px.parallel_coordinates(rename_cols(jittered), color=columns[metric_column],
                            **px_kw).show()
    return jittered


sw_stats_with_noise = _show_parallel_coordinates(stats, 'sw', '## SW: all data', aggregate=False)
sw_stats_with_noise = _show_parallel_coordinates(stats, 'sw', '## SW: aggregated over iterations', aggregate=True)

lw_stats_with_noise = _show_parallel_coordinates(stats, 'lw', '## LW: all data', aggregate=False)
lw_stats_with_noise = _show_parallel_coordinates(stats, 'lw', '## LW: aggregated over iterations', aggregate=True)

# Job: synth_*

In [None]:
# Stats of the 'mlp' job, 'synthia' variant; `stats` is reused by the next cell.
stats = load_stats('mlp', 'synthia')

print_tables(stats)

# ylim=0 is forwarded to the boxplot axes as their y-limit setting.
draw_boxplots(stats, ylim=0)

In [None]:
# For each band, take the median run (by test metric) among the runs with
# synth_mul_factor > 0 and print the path of the HTML file that sits next to
# the first .txt file in nb_dir mentioning that run's iteration.
nb_dir = get_notebooks_dir(job_name='mlp', variant='synthia')

for xw in ['sw', 'lw']:
    stats_sorted = stats[stats['synth_mul_factor'] > 0].sort_values(by=[f'{xw}_{metric}_test'])
    if stats_sorted.empty:
        print(f'{xw}: no runs with synth_mul_factor > 0')
        continue
    # Lower-median position: the exact middle row for odd counts, the lower of
    # the two middle rows for even counts. (`len // 2 - 1` picked the row
    # below the median for odd counts and an empty slice for a single row.)
    median_idx = (len(stats_sorted) - 1) // 2
    row = stats_sorted.iloc[[median_idx]]
    iteration = int(row['iteration'].iloc[0])
    # Require a non-digit (or end of text) after the number so that
    # 'iteration=1' does not also match 'iteration=10'.
    iteration_pattern = re.compile(rf'iteration={iteration}(?!\d)')
    # Sorted so the first matching file is deterministic.
    for p in sorted(nb_dir.glob('*.txt')):
        content = p.read_text()
        if iteration_pattern.search(content):
            display(Markdown(f'## {xw}'))
            display(row)
            # Swap only the file extension, not every "txt" in the path.
            print(p.with_suffix('.html'))
            break