In [18]:
from glob import glob 
from fastcore.helpers import load_pickle
from gpt3forchem.output import get_polymer_prompt_data, composition_mismatch

import pandas as pd 
import numpy as np 

**ToDo: handle the excluded category in the metrics**

In [62]:
polymer_inverse_results = glob('results/20221117_polymer_inverse/*.pkl')

In [63]:
polymer_inverse_results

['results/20221117_polymer_inverse/2022-11-18-00-43-49_results.pkl',
 'results/20221117_polymer_inverse/2022-11-18-09-37-13_results.pkl',
 'results/20221117_polymer_inverse/2022-11-18-01-59-56_results.pkl',
 'results/20221117_polymer_inverse/2022-11-18-01-21-23_results.pkl',
 'results/20221117_polymer_inverse/2022-11-17-23-12-09_results.pkl',
 'results/20221117_polymer_inverse/2022-11-18-02-42-27_results.pkl',
 'results/20221117_polymer_inverse/2022-11-17-23-53-14_results.pkl',
 'results/20221117_polymer_inverse/2022-11-17-22-32-49_results.pkl',
 'results/20221117_polymer_inverse/2022-11-18-10-27-30_results.pkl',
 'results/20221117_polymer_inverse/2022-11-17-22-02-42_results.pkl',
 'results/20221117_polymer_inverse/2022-11-17-13-51-46_results.pkl']

In [64]:
res = load_pickle(polymer_inverse_results[-3])

In [65]:
res.keys()

dict_keys(['train_file', 'valid_file', 'modelname', 'predictions', 'completions', 'metrics', 'optimal_metrics', 'num_train_points', 'num_test_points', 'exclude_category'])

In [39]:
res['metrics'][0].keys()

dict_keys(['composition_mismatches', 'losses', 'kldiv_score', 'valid_smiles_fraction', 'unique_smiles_fraction', 'novel_smiles_fraction', 'generated_sequences'])

In [88]:
def summarize_at_temperature(subd, t, d):
    val_data = pd.read_json(d['valid_file'], lines=True, orient='records')
    desired_compositions = val_data['prompt'].apply(lambda x: get_polymer_prompt_data(x)[0])
    found_compositions = subd['composition_mismatches']['composition']
    composition_losses =[composition_mismatch(x, y) for x, y in zip(desired_compositions, found_compositions)]
    composition_losses = pd.DataFrame(composition_losses)
    mean_loss = np.mean(subd['losses'])
    mean_composition_mismatch = np.mean(composition_losses['mean'])
    
    mean_min_norm_lev = subd['composition_mismatches']['NormalizedLevenshtein_min'].min()
    mean_mean_norm_lev = subd['composition_mismatches']['NormalizedLevenshtein_mean'].mean()
    longest_common_subs_max_mean = (subd['composition_mismatches']['LongestCommonSubsequence_max']/subd['composition_mismatches']['smiles'].apply(lambda x: len(x))).mean()
    longest_common_subs_mean_mean = (subd['composition_mismatches']['LongestCommonSubsequence_mean']/subd['composition_mismatches']['smiles'].apply(lambda x: len(x))).mean()

    return {
        'mean_loss': mean_loss, 
        'mean_composition_mismatch': mean_composition_mismatch, 
        'mean_min_norm_lev': mean_min_norm_lev,
        'mean_mean_norm_lev': mean_mean_norm_lev, 
        'longest_common_subs_max_mean': longest_common_subs_max_mean,
        'frac_valid': subd['valid_smiles_fraction'],
        'frac_novel': subd['novel_smiles_fraction'],
        'frac_unique': subd['unique_smiles_fraction'],
        'kldiv': subd['kldiv_score'],
        'exclude_category': d['exclude_category'] if 'exclude_category' in d and d['exclude_category'] else "None",
    }

In [89]:
combined_results = []
optimal_results = []
for res in polymer_inverse_results: 
    try:
        d = load_pickle(res) 
        for t, td in d['metrics'].items():
            td_sum = summarize_at_temperature(td, t, d)
            td_sum['temperature'] = t

            combined_results.append(td_sum)


        td_sum = summarize_at_temperature(d['optimal_metrics'], t, d)

        optimal_results.append(td_sum)
    except Exception as e:
        print(e)
        pass

tuple indices must be integers or slices, not str


In [90]:
combined_results

[{'mean_loss': 3.1831838672622568,
  'mean_composition_mismatch': 0.2364217252396166,
  'mean_min_norm_lev': 0.23076923076923078,
  'mean_mean_norm_lev': 0.39776472262916895,
  'longest_common_subs_max_mean': 0.46317247249908217,
  'frac_valid': 1.0,
  'frac_novel': 0.9041533546325878,
  'frac_unique': 0.9041533546325878,
  'kldiv': 0.7320543462034269,
  'exclude_category': 'None',
  'temperature': 0},
 {'mean_loss': 3.1161642055510743,
  'mean_composition_mismatch': 0.2356230031948882,
  'mean_min_norm_lev': 0.20930232558139536,
  'mean_mean_norm_lev': 0.40535658238788075,
  'longest_common_subs_max_mean': 0.4727222487058808,
  'frac_valid': 1.0,
  'frac_novel': 1.0,
  'frac_unique': 1.0,
  'kldiv': 0.9737706489468884,
  'exclude_category': 'None',
  'temperature': 0.25},
 {'mean_loss': 3.0562803322483587,
  'mean_composition_mismatch': 0.23722044728434505,
  'mean_min_norm_lev': 0.23529411764705882,
  'mean_mean_norm_lev': 0.409555279597167,
  'longest_common_subs_max_mean': 0.478827

In [91]:
combined_results = pd.DataFrame(combined_results)
optimal_results = pd.DataFrame(optimal_results)

In [92]:
combined_results.groupby('exclude_category').mean()

Unnamed: 0_level_0,mean_loss,mean_composition_mismatch,mean_min_norm_lev,mean_mean_norm_lev,longest_common_subs_max_mean,frac_valid,frac_novel,frac_unique,kldiv,temperature
exclude_category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
,3.097136,0.948711,0.231964,0.411237,0.468905,0.993163,0.98345,0.98345,0.895397,0.735
very large,1.240223,0.657578,0.216544,0.429968,0.418076,0.993257,0.898743,0.898743,0.643746,0.75


In [93]:
agg_frame = combined_results[['temperature', 'longest_common_subs_max_mean', 'mean_composition_mismatch', 'mean_loss', 'frac_valid', 'frac_novel', 'frac_unique', 'kldiv', 'exclude_category']].round(2).groupby(by=['temperature', 'exclude_category']).agg(['mean', 'std', 'count'])
agg_frame

Unnamed: 0_level_0,Unnamed: 1_level_0,longest_common_subs_max_mean,longest_common_subs_max_mean,longest_common_subs_max_mean,mean_composition_mismatch,mean_composition_mismatch,mean_composition_mismatch,mean_loss,mean_loss,mean_loss,frac_valid,frac_valid,frac_valid,frac_novel,frac_novel,frac_novel,frac_unique,frac_unique,frac_unique,kldiv,kldiv,kldiv
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,count,mean,std,count,mean,std,count,mean,...,count,mean,std,count,mean,std,count,mean,std,count
temperature,exclude_category,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
0.0,,0.44875,0.016421,8,0.2375,0.010351,8,3.11375,0.052082,8,1.0,...,8,0.89625,0.019955,8,0.89625,0.019955,8,0.72625,0.011877,8
0.0,very large,0.4,0.0,2,0.245,0.007071,2,0.88,0.070711,2,1.0,...,2,0.295,0.007071,2,0.295,0.007071,2,0.44,0.014142,2
0.25,,0.457143,0.01496,7,0.235714,0.011339,7,3.095714,0.025071,7,1.0,...,7,1.0,0.0,7,1.0,0.0,7,0.971429,0.006901,7
0.25,very large,0.41,0.0,2,0.245,0.007071,2,1.14,0.028284,2,1.0,...,2,1.0,0.0,2,1.0,0.0,2,0.7,0.014142,2
0.5,,0.465714,0.013973,7,0.238571,0.01069,7,3.081429,0.015736,7,1.0,...,7,1.0,0.0,7,1.0,0.0,7,0.952857,0.007559,7
0.5,very large,0.42,0.0,2,0.245,0.007071,2,1.255,0.021213,2,1.0,...,2,1.0,0.0,2,1.0,0.0,2,0.7,0.0,2
0.75,,0.465714,0.013973,7,0.318571,0.21248,7,3.065714,0.025071,7,1.0,...,7,1.0,0.0,7,1.0,0.0,7,0.932857,0.011127,7
0.75,very large,0.42,0.0,2,0.245,0.007071,2,1.33,0.042426,2,1.0,...,2,1.0,0.0,2,1.0,0.0,2,0.67,0.014142,2
1.0,,0.468571,0.015736,7,0.534286,0.598856,7,3.084286,0.022991,7,0.998571,...,7,1.0,0.0,7,1.0,0.0,7,0.91,0.014142,7
1.0,very large,0.42,0.0,2,0.44,0.070711,2,1.345,0.035355,2,1.0,...,2,1.0,0.0,2,1.0,0.0,2,0.655,0.007071,2


In [61]:
optimal_results.agg(['mean', 'std'])

Unnamed: 0,mean_loss,mean_composition_mismatch,mean_min_norm_lev,mean_mean_norm_lev,longest_common_subs_max_mean,frac_valid,frac_novel,frac_unique,kldiv
mean,3.102406,0.0,0.241295,0.398214,0.471575,1.0,1.0,1.0,0.986653
std,0.02682,0.0,0.014233,0.004026,0.013839,0.0,0.0,0.0,0.002894


In [54]:
def create_classification_performance_table(df):
    rows = []
    row_template = "{temperature} & \\num⁍ {frac_valid} \\pm {frac_valid_std} ⁌ &   \\num⁍ {frac_novel} \\pm {frac_novel_std} ⁌   &  \\num⁍ {kl_div} \\pm {kl_div_std} ⁌   & \\num⁍ {longest_common_subs_max_mean} \\pm {longest_common_subs_max_mean_std}  ⁌ & \\num⁍ {mean_composition_mismatch} \\pm {mean_composition_mismatch_std} ⁌  & \\num⁍ {mean_loss} \\pm {mean_loss_std} ⁌    \\\\"
    for i, row in df.round(2).iterrows():
        row_dict = {
            "temperature": row.name,
            "frac_valid": row["frac_valid"]["mean"],
            "frac_valid_std": row["frac_valid"]["std"],
            "frac_novel": row["frac_novel"]["mean"],
            "frac_novel_std": row["frac_novel"]["std"],
            "kl_div": row["kldiv"]["mean"],
            "kl_div_std": row["kldiv"]["std"],
            "longest_common_subs_max_mean": row["longest_common_subs_max_mean"]["mean"],
            "longest_common_subs_max_mean_std": row["longest_common_subs_max_mean"]["std"],
            "mean_composition_mismatch": row["mean_composition_mismatch"]["mean"],
            "mean_composition_mismatch_std": row["mean_composition_mismatch"]["std"],
            "mean_loss": row["mean_loss"]["mean"],
            "mean_loss_std": row["mean_loss"]["std"],
        }
        rows.append(row_template.format(**row_dict).replace("⁍", "{").replace("⁌", "}"))
    return rows

In [58]:
print('\n'.join(create_classification_performance_table(agg_frame)))

0.0 & \num{ 1.0 \pm 0.0 } &   \num{ 0.9 \pm 0.02 }   &  \num{ 0.73 \pm 0.01 }   & \num{ 0.45 \pm 0.02  } & \num{ 0.24 \pm 0.01 }  & \num{ 3.11 \pm 0.05 }    \\
0.25 & \num{ 1.0 \pm 0.0 } &   \num{ 1.0 \pm 0.0 }   &  \num{ 0.97 \pm 0.01 }   & \num{ 0.46 \pm 0.01  } & \num{ 0.24 \pm 0.01 }  & \num{ 3.1 \pm 0.03 }    \\
0.5 & \num{ 1.0 \pm 0.0 } &   \num{ 1.0 \pm 0.0 }   &  \num{ 0.95 \pm 0.01 }   & \num{ 0.47 \pm 0.01  } & \num{ 0.24 \pm 0.01 }  & \num{ 3.08 \pm 0.02 }    \\
0.75 & \num{ 1.0 \pm 0.0 } &   \num{ 1.0 \pm 0.0 }   &  \num{ 0.93 \pm 0.01 }   & \num{ 0.47 \pm 0.01  } & \num{ 0.32 \pm 0.21 }  & \num{ 3.07 \pm 0.03 }    \\
1.0 & \num{ 1.0 \pm 0.0 } &   \num{ 1.0 \pm 0.0 }   &  \num{ 0.91 \pm 0.01 }   & \num{ 0.47 \pm 0.02  } & \num{ 0.53 \pm 0.6 }  & \num{ 3.08 \pm 0.02 }    \\
1.25 & \num{ 0.99 \pm 0.01 } &   \num{ 1.0 \pm 0.0 }   &  \num{ 0.9 \pm 0.02 }   & \num{ 0.48 \pm 0.01  } & \num{ 1.66 \pm 1.2 }  & \num{ 3.1 \pm 0.02 }    \\
1.5 & \num{ 0.96 \pm 0.02 } &   \num{ 1.0 \pm