In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, Any
import pickle
def eval_results_to_df(eval_results: Dict[str, Dict[int, Dict]], metric='mse'):
    """Flatten nested evaluation results into one wide DataFrame for a single metric.

    Args:
        eval_results: Nested mapping of the form
            ``{dataset_name: {support_size: {metric: {model_type: value}}}}``.
        metric: Key selecting which metric's per-model values are extracted
            from every support-size entry.

    Returns:
        A DataFrame with one row per dataset and a two-level column index named
        ``('support', 'type')``. The columns cover every (support size, model type)
        pair found in *any* dataset, in first-seen order; pairs a dataset does not
        provide are NaN. An empty ``eval_results`` yields an empty DataFrame.

    Raises:
        KeyError: If a support-size entry does not contain ``metric``.
    """
    results = {}
    for ds_name, ds_res in eval_results.items():
        results[ds_name] = {
            (support_size, mt): value
            for support_size, ds_ss_res in ds_res.items()
            for mt, value in ds_ss_res[metric].items()
        }

    # Build the column index from the union of keys over all datasets rather than
    # from whichever dataset happened to be iterated last: this keeps pairs that
    # only some datasets report and also works when there are no datasets at all.
    # dict.fromkeys de-duplicates while preserving first-seen order.
    column_keys = list(dict.fromkeys(key for ds_vals in results.values() for key in ds_vals))
    multi_ind = pd.MultiIndex.from_tuples(column_keys, names=['support', 'type'])
    return pd.DataFrame(results, index=multi_ind).transpose()

In [2]:
# Pickled evaluation results (eval_results.p) of a single run directory.
# NOTE(review): the path is absolute and machine-specific; adjust it when running elsewhere.
results_file = Path(
    "/system/user/beck/pwbeck/meta/rnn-adaptation-multitask/runs/jfr_multitask_220228_170251/eval_results.p"
)

# Unpickling can execute arbitrary code, so only load result files from a trusted source.
with results_file.open('rb') as f:
    results = pickle.load(f)

In [3]:
# Build one dataset-by-(support, type) table per metric; the generator is unpacked
# in the same order as the metric names listed here.
df_mse, df_rsquared = (
    eval_results_to_df(results, metric=metric_name)
    for metric_name in ('mse', 'rsquared')
)
df_mse

support,10,10,20,20,30,30,50,50,70,70,100,100,2000,2000
type,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune
test/R:2.0005479452697545_L:5.963974503396337e-05_C:4.4174291934218044e-07.npy,1.317416,0.122243,0.454707,0.123402,0.517704,0.124508,0.082714,0.121763,0.100761,0.124868,0.134050,0.124633,0.063158,0.125313
test/R:4.198512930894825_L:6.125137871802864e-05_C:7.699455426617554e-07.npy,3.268443,0.268411,1.320805,0.272736,0.184526,0.269293,0.173692,0.272428,0.097138,0.275362,0.088773,0.271261,0.083101,0.267959
test/R:1.3653511125173492_L:0.00011632633135961731_C:2.8486052695566855e-07.npy,0.958063,0.600272,0.691869,0.597996,4.652066,0.595758,0.867129,0.589073,1.258286,0.597971,0.808295,0.593235,37.999783,0.597119
test/R:5.181115923706953_L:3.770798162691948e-05_C:3.4315031531791823e-07.npy,1.738513,0.038726,3.012225,0.036999,0.376213,0.038674,0.017794,0.038519,0.017176,0.038879,0.016030,0.037437,0.013472,0.039338
test/R:1.7588075933627945_L:5.5252266184635496e-05_C:5.716758626022325e-07.npy,1.351758,0.174955,0.311251,0.172135,0.672056,0.172892,0.310419,0.169677,0.569198,0.176622,0.137558,0.170794,0.077815,0.170404
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
test/R:12.005887641787659_L:0.00013335306711790169_C:2.6305991197855294e-07.npy,1.707181,0.279959,3.468054,0.284848,0.211271,0.282265,0.187678,0.283341,0.130248,0.282587,0.125294,0.283293,0.093273,0.283088
test/R:4.057616546661245_L:3.485715198861148e-05_C:3.2534819958429176e-07.npy,2.609037,0.024521,0.555156,0.025989,0.122796,0.023507,0.074674,0.023844,0.036472,0.024700,0.017720,0.024824,0.012605,0.024999
test/R:4.285058750168105_L:5.875928798152061e-05_C:5.349069508631779e-07.npy,16.075912,0.174592,1.619135,0.180824,0.423494,0.177165,0.087530,0.178542,0.088639,0.178071,0.053014,0.179409,0.095223,0.182041
test/R:13.383517992395685_L:0.00013360139488719324_C:2.5508288077009096e-07.npy,9.605336,0.187084,4.204467,0.184101,0.710710,0.184602,0.539818,0.181988,0.153490,0.185422,0.063210,0.185622,0.066471,0.187568


In [4]:
# Work on a copy so the summary rows added below never end up in df_mse.
display_df = df_mse.copy()

# Column-wise aggregates over all datasets, computed before any summary row is
# appended so that the two rows do not influence each other.
median, mean = display_df.median(axis=0), display_df.mean(axis=0)

# The leading space in both labels makes them sort ahead of the "test/..." rows.
display_df.loc[' median'] = median
display_df.loc[' mean'] = mean

# Row-wise colour gradient (axis=1) with the reversed 'Greens' colormap; null cells
# are highlighted in white.
display(
    display_df
    .sort_index()
    .style
    .background_gradient('Greens_r', axis=1)
    .highlight_null('white')
)

support,10,10,20,20,30,30,50,50,70,70,100,100,2000,2000
type,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune
mean,4.410559,0.23745,4.176415,0.237388,0.766966,0.23758,0.283558,0.237388,0.209613,0.237537,0.159615,0.237378,0.260797,0.237373
median,2.672803,0.224893,2.333867,0.2261,0.397476,0.224755,0.157207,0.224749,0.12629,0.22613,0.102556,0.225905,0.078489,0.225352
test/R:1.0249437636881251_L:0.00010726452771948082_C:2.643912032291413e-07.npy,1.836744,0.646017,4.151481,0.639493,0.937441,0.643443,0.936413,0.6411,1.020072,0.646362,0.907114,0.650611,0.517948,0.647026
test/R:1.1586005688888497_L:8.693189183448564e-05_C:2.7407460137915236e-07.npy,10.214717,0.364601,1.742175,0.366567,0.535375,0.366005,0.474779,0.367682,0.517885,0.36583,0.47755,0.364772,0.573502,0.367256
test/R:1.1970406984962436_L:4.581432462508986e-05_C:3.8705381297963926e-07.npy,5.284896,0.065281,2.087538,0.062855,0.141056,0.063929,0.177723,0.064782,0.165126,0.065062,0.137618,0.064903,0.046164,0.06515
test/R:1.332428795631129_L:3.301060869307106e-05_C:4.700441470380736e-07.npy,1.396021,0.050309,0.831354,0.049008,0.276976,0.050754,0.044405,0.050351,0.038161,0.047739,0.04169,0.049901,0.027482,0.050766
test/R:1.3402772427181766_L:4.746310122094462e-05_C:4.956053200425364e-07.npy,19.066776,0.098507,0.707202,0.09823,2.246067,0.096683,0.446332,0.096822,0.602228,0.096901,0.09996,0.100426,0.055988,0.099559
test/R:1.344161221869895_L:0.00013678277752763392_C:4.937731935083017e-07.npy,4.344183,0.715721,8.938469,0.715042,0.903226,0.714321,0.981149,0.716755,0.885233,0.715569,0.630741,0.716534,0.572047,0.713543
test/R:1.3653511125173492_L:0.00011632633135961731_C:2.8486052695566855e-07.npy,0.958063,0.600272,0.691869,0.597996,4.652066,0.595758,0.867129,0.589073,1.258286,0.597971,0.808295,0.593235,37.999783,0.597119
test/R:1.407569823230979_L:3.945585691943242e-05_C:1.2746477136016913e-07.npy,3.842424,0.2096,4.291109,0.205056,0.192566,0.207087,0.137682,0.208384,0.163868,0.206062,0.10169,0.204555,0.102454,0.209645


support_size,20,20,30,30,100,100
model_type,jfr_model,no_finetune,jfr_model,no_finetune,jfr_model,no_finetune
test/R:2.0005479452697545_L:5.963974503396337e-05_C:4.4174291934218044e-07.npy,216.394974,1.043571,232.805511,1.043289,41.478764,1.048042
test/R:4.198512930894825_L:6.125137871802864e-05_C:7.699455426617554e-07.npy,160.537308,0.964205,475.004364,0.969389,19215.972656,0.969349
test/R:1.3653511125173492_L:0.00011632633135961731_C:2.8486052695566855e-07.npy,136.692108,1.791458,94.50428,1.77731,58.779438,1.785594
median,160.537308,1.043571,232.805511,1.043289,58.779438,1.048042
mean,171.20813,1.266412,267.438052,1.263329,6438.743619,1.267662
