In [63]:
# Automatically reload edited project modules (e.g. gpt3forchem) before executing
# each cell, so library changes are picked up without restarting the kernel.
# Re-running this cell only prints an "already loaded" notice, which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [64]:
import pandas as pd
from glob import glob 
from fastcore.xtras import load_pickle

from gpt3forchem.output import get_regression_metrics
from gpt3forchem.api_wrappers import extract_prediction

from pycm import ConfusionMatrix

## Classification

In [65]:
# Result pickles of the FreeSolv classification runs. glob() returns paths in an
# arbitrary, file-system-dependent order; sort them so the row order of the
# metric tables below is reproducible across machines and re-runs.
all_res = sorted(glob('results/20221130_freesolv/*'))

In [66]:
# Collect per-run classification metrics for the model and for its baseline.
# From the keys read below, each pickle holds 'cm' and 'baseline' -> 'cm'
# (pycm ConfusionMatrix objects) plus 'train_size' and 'representation'.
# NOTE(review): load_pickle unpickles arbitrary objects — only use on trusted result files.
# NOTE(review): 'accuracy' is pycm's ACC_Macro (mean of per-class one-vs-rest
# accuracies), which is higher than overall accuracy (pycm's Overall_ACC) for
# multi-class problems; keep that in mind when comparing to other reports.


def _stripped_cm(cm):
    """Rebuild a ConfusionMatrix whose labels are ``str(label).strip()``-normalised.

    This makes labels that differ only by surrounding whitespace (e.g. ``' 1'``
    vs ``'1'``) count as the same class.
    """
    actual = [str(label).strip() for label in cm.actual_vector]
    predicted = [str(label).strip() for label in cm.predict_vector]
    return ConfusionMatrix(actual, predicted)


def _cm_summary(res, cm):
    """Return the run's train_size/representation plus summary scores of ``cm``."""
    return {
        'train_size': res['train_size'],
        'representation': res['representation'],
        'accuracy': cm.ACC_Macro,
        'f1_macro': cm.F1_Macro,
        'f1_micro': cm.F1_Micro,
    }


metrics = []
baselines = []

for path in all_res:
    res = load_pickle(path)
    cm = _stripped_cm(res['cm'])
    baseline_cm = _stripped_cm(res['baseline']['cm'])
    metrics.append(_cm_summary(res, cm))
    baselines.append(_cm_summary(res, baseline_cm))

In [67]:
# Convert the collected record lists into DataFrames (one row per results pickle).
metrics, baselines = pd.DataFrame(metrics), pd.DataFrame(baselines)

In [76]:
# Mean, spread and number of runs of every classification metric,
# per representation and training-set size.
(
    metrics
    .groupby(by=['representation', 'train_size'])
    .agg(func=['mean', 'std', 'count'])
)

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,accuracy,accuracy,f1_macro,f1_macro,f1_macro,f1_micro,f1_micro,f1_micro
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,count,mean,std,count,mean,std,count
representation,train_size,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
inchi,10,0.787,0.032527,2,0.173926,0.003525,2,0.4675,0.081317,2
inchi,50,0.833,0.007071,2,0.222916,0.052498,2,0.5825,0.017678,2
inchi,500,0.933333,0.002857,2,0.833803,0.014607,2,0.833333,0.007142,2
iupac_name,10,0.802,0.03677,2,0.18678,0.007602,2,0.505,0.091924,2
iupac_name,50,0.861,0.029698,2,0.368536,0.147473,2,0.6525,0.074246,2
iupac_name,500,0.933333,0.014285,2,0.7596,0.123082,2,0.833333,0.035712,2
selfies,10,0.772,0.039598,2,0.160942,0.068168,2,0.43,0.098995,2
selfies,50,0.873,0.012728,2,0.284819,0.014746,2,0.6825,0.03182,2
selfies,500,0.921212,0.014285,2,0.83144,0.047642,2,0.80303,0.035712,2
smiles,10,0.8415,0.008699,4,0.249858,0.015695,4,0.60375,0.021747,4


In [70]:
# Baselines summarised per train_size only. 'representation' is a string column:
# drop it explicitly instead of relying on pandas to discard it as a nuisance
# column (older pandas dropped it silently; pandas >= 2.0 raises a TypeError
# when asked for the mean/std of an object column).
(
    baselines
    .drop(columns='representation')
    .groupby(['train_size'])
    .agg(['mean', 'std'])
)

Unnamed: 0_level_0,accuracy,accuracy,f1_macro,f1_macro,f1_micro,f1_micro
Unnamed: 0_level_1,mean,std,mean,std,mean,std
train_size,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
10,0.8316,0.000843,0.233844,0.026455,0.579,0.002108
50,0.894,0.0,0.43784,0.0,0.735,0.0
500,0.8,0.0,0.0,0.0,0.0,0.0


## Regression

In [71]:
# Result pickles of the FreeSolv regression runs, sorted so that the row order
# of the metric tables does not depend on the file system's listing order.
all_res_regression = sorted(glob('results/20221129_freesolv_regression/*'))

In [72]:
# Collect per-run regression metrics for the model and for its baseline.
# From the keys read below, each pickle holds 'train_size', 'representation',
# and two dicts, 'metrics' and 'baseline', each keyed by the names listed in
# REGRESSION_METRICS.
# NOTE(review): load_pickle unpickles arbitrary objects — only use on trusted result files.

# Single source of truth for the reported metric keys (also fixes the column order).
REGRESSION_METRICS = ('r2', 'max_error', 'mean_absolute_error', 'mean_squared_error', 'rmse')

metrics_regression = []
baselines_regression = []

for path in all_res_regression:
    res = load_pickle(path)
    run_info = {'train_size': res['train_size'], 'representation': res['representation']}
    metrics_regression.append(
        {**run_info, **{name: res['metrics'][name] for name in REGRESSION_METRICS}}
    )
    baselines_regression.append(
        {**run_info, **{name: res['baseline'][name] for name in REGRESSION_METRICS}}
    )

In [73]:
# Convert both regression record lists into DataFrames (one row per results pickle).
metrics_regression, baselines_regression = map(
    pd.DataFrame, (metrics_regression, baselines_regression)
)

In [74]:
# Mean and spread of every regression metric per representation and training-set size.
(
    metrics_regression
    .groupby(by=['representation', 'train_size'])
    .agg(func=['mean', 'std'])
)

Unnamed: 0_level_0,Unnamed: 1_level_0,r2,r2,max_error,max_error,mean_absolute_error,mean_absolute_error,mean_squared_error,mean_squared_error,rmse,rmse
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
representation,train_size,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
inchi,10,-0.245881,0.053283,18.935,2.241528,3.27485,0.088106,19.134028,0.818318,4.373746,0.093549
inchi,50,0.010655,0.062585,17.86,0.169706,2.8587,0.046457,14.674768,0.928305,3.829807,0.121195
inchi,500,0.651623,0.035518,13.955,0.106066,1.378434,0.065925,5.380518,0.548563,2.318085,0.118322
iupac_name,10,-0.418369,0.108768,17.99,0.579828,3.686525,0.141103,21.783073,1.670438,4.665517,0.17902
iupac_name,50,0.082566,0.170367,17.74,0.0,2.76535,0.300167,13.60813,2.527027,3.680926,0.34326
iupac_name,500,0.773709,0.004334,7.695,1.180868,1.122172,0.011214,3.494948,0.066934,1.869435,0.017902
selfies,10,-0.030135,0.118407,17.78,0.3995,2.9027,0.14388,15.62531,1.512037,3.949803,0.191173
selfies,50,0.164391,0.086608,16.196667,2.875036,2.647633,0.178014,12.223212,1.12759,3.49362,0.163538
selfies,500,0.764539,0.012558,10.065,1.491995,1.137525,0.023213,3.636576,0.193952,1.906642,0.050862
smiles,10,-0.178439,0.169525,19.046667,2.051747,3.189067,0.337366,17.906469,2.677896,4.222078,0.310857


In [75]:
# Regression baselines summarised per train_size only. 'representation' is a
# string column: drop it explicitly instead of relying on pandas to discard it
# (pandas >= 2.0 raises a TypeError for mean/std of object columns).
(
    baselines_regression
    .drop(columns='representation')
    .groupby(['train_size'])
    .agg(['mean', 'std'])
)

Unnamed: 0_level_0,r2,r2,max_error,max_error,mean_absolute_error,mean_absolute_error,mean_squared_error,mean_squared_error,rmse,rmse
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
train_size,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
10,0.139089,0.027464,20.501197,0.07145,2.730168,0.041038,13.914439,0.149266,3.730157,0.019943
50,0.709372,0.008226,9.432677,0.606,1.582631,0.014952,4.577601,0.001755,2.139533,0.00041
500,0.91615,0.006269,3.954292,0.296657,0.868814,0.036917,1.379811,0.13444,1.173245,0.060993
