In [104]:
from glob import glob 
from fastcore.helpers import load_pickle
import pandas as pd 
import matplotlib.pyplot as plt 
import numpy as np 
from gpt3forchem.output import get_polymer_prompt_data
plt.style.use(['science', 'nature'])

  warn(
  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
all_results = glob('results/20220919_polymer_inverse/*.pkl')

In [3]:
all_results

['results/20220919_polymer_inverse/2022-09-20-16-33-10_results.pkl',
 'results/20220919_polymer_inverse/2022-09-19-17-29-13_results.pkl',
 'results/20220919_polymer_inverse/2022-09-19-17-01-14_results.pkl',
 'results/20220919_polymer_inverse/2022-09-20-16-57-08_results.pkl']

In [84]:
d = load_pickle(all_results[0])

In [109]:
pd.read_json(d['valid_file'], lines=True, orient='records')['prompt'].iloc[0]

'what is a polymer with large adsorption energy and 4 A, 4 B, 12 W, and 12 R?###'

In [106]:
pd.read_json(d['valid_file'], lines=True, orient='records')['prompt'].apply(lambda x: get_polymer_prompt_data(x)[0])

0      {'R': 2, 'W': 2, 'A': 4, 'B': 4}
1      {'R': 8, 'W': 8, 'A': 6, 'B': 8}
2      {'R': 8, 'W': 2, 'A': 2, 'B': 8}
3      {'R': 2, 'W': 6, 'A': 8, 'B': 2}
4      {'R': 0, 'W': 8, 'A': 0, 'B': 0}
                     ...               
308    {'R': 6, 'W': 0, 'A': 0, 'B': 2}
309    {'R': 4, 'W': 2, 'A': 0, 'B': 8}
310    {'R': 0, 'W': 4, 'A': 4, 'B': 8}
311    {'R': 0, 'W': 0, 'A': 4, 'B': 8}
312    {'R': 0, 'W': 4, 'A': 8, 'B': 6}
Name: prompt, Length: 313, dtype: object

In [49]:
def summarize_at_temperature(subd, t, d): 
    val_data = pd.read_json(d['valid_file'], lines=True, orient='records')
    
    mean_loss = np.mean(subd[0])
    mean_composition_mismatch = subd[1]['mean'].mean()
    mean_min_norm_lev = subd[1]['NormalizedLevenshtein_min'].min()
    mean_mean_norm_lev = subd[1]['NormalizedLevenshtein_mean'].mean()
    longest_common_subs_max_mean = (subd[1]['LongestCommonSubsequence_max']/subd[1]['smiles'].apply(lambda x: len(x))).mean()
    longest_common_subs_mean_mean = (subd[1]['LongestCommonSubsequence_mean']/subd[1]['smiles'].apply(lambda x: len(x))).mean()
    frac_valid = len(subd[1])/len(d['predictions'][t])
    return {
        'mean_loss': mean_loss, 
        'mean_composition_mismatch': mean_composition_mismatch, 
        'mean_min_norm_lev': mean_min_norm_lev,
        'mean_mean_norm_lev': mean_mean_norm_lev, 
        'longest_common_subs_max_mean': longest_common_subs_max_mean,
        'frac_valid': frac_valid
    }

In [68]:
combined_results = []
optimal_results = []
for res in all_results: 
    try:
        d = load_pickle(res) 
        for t, td in d['metrics'].items():
            td_sum = summarize_at_temperature(td, t, d)
            td_sum['temperature'] = t

            combined_results.append(td_sum)


        td_sum = summarize_at_temperature(d['optimal_metrics'], t, d)

        optimal_results.append(td_sum)
    except Exception:
        pass

In [69]:
optimal_results

[{'mean_loss': 3.093990835092083,
  'mean_composition_mismatch': 3.9536741214057507,
  'mean_min_norm_lev': 0.14545454545454545,
  'mean_mean_norm_lev': 0.39870773944352805,
  'longest_common_subs_max_mean': 0.550439674267297,
  'frac_valid': 1.0},
 {'mean_loss': 3.0630722756851894,
  'mean_composition_mismatch': 4.0095846645367414,
  'mean_min_norm_lev': 0.19607843137254902,
  'mean_mean_norm_lev': 0.3964039276245443,
  'longest_common_subs_max_mean': 0.5483179366169209,
  'frac_valid': 1.0},
 {'mean_loss': 3.1088144975065295,
  'mean_composition_mismatch': 4.017571884984026,
  'mean_min_norm_lev': 0.2033898305084746,
  'mean_mean_norm_lev': 0.3999965845012095,
  'longest_common_subs_max_mean': 0.5558546254085541,
  'frac_valid': 1.0}]

In [51]:
combined_results = pd.DataFrame(combined_results)

In [70]:
optimal_results = pd.DataFrame(optimal_results)

In [72]:
optimal_results.agg(['mean', 'std'])

Unnamed: 0,mean_loss,mean_composition_mismatch,mean_min_norm_lev,mean_mean_norm_lev,longest_common_subs_max_mean,frac_valid
mean,3.088626,3.99361,0.181641,0.398369,0.551537,1.0
std,0.023338,0.034815,0.031551,0.00182,0.003886,0.0


In [57]:
agg_frame = combined_results[['temperature', 'longest_common_subs_max_mean', 'mean_composition_mismatch', 'mean_loss', 'frac_valid']].round(2).groupby('temperature').agg(['mean', 'std', 'count'])
agg_frame

Unnamed: 0_level_0,longest_common_subs_max_mean,longest_common_subs_max_mean,longest_common_subs_max_mean,mean_composition_mismatch,mean_composition_mismatch,mean_composition_mismatch,mean_loss,mean_loss,mean_loss,frac_valid,frac_valid,frac_valid
Unnamed: 0_level_1,mean,std,count,mean,std,count,mean,std,count,mean,std,count
temperature,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
0.0,0.523333,0.005774,3,4.24,0.043589,3,3.096667,0.030551,3,1.0,0.0,3
0.25,0.533333,0.005774,3,4.24,0.043589,3,3.086667,0.011547,3,1.0,0.0,3
0.5,0.54,0.0,3,4.24,0.043589,3,3.06,0.02,3,1.0,0.0,3
0.75,0.546667,0.005774,3,4.243333,0.037859,3,3.073333,0.015275,3,1.0,0.0,3
1.0,0.55,0.0,3,4.266667,0.032146,3,3.066667,0.025166,3,1.0,0.0,3
1.25,0.553333,0.005774,3,4.2,0.079373,3,3.096667,0.015275,3,0.996667,0.005774,3
1.5,0.58,0.0,3,3.976667,0.120554,3,3.083333,0.063509,3,0.963333,0.005774,3


In [None]:
temperature & fraction valid & maximum common subsequence & composition mismatch & MAE \\

In [56]:
def create_classification_performance_table(df):
    rows = []
    row_template = "{temperature} & \\num⁍ {frac_valid} \\pm {frac_valid_std} ⁌ &  \\num⁍ {longest_common_subs_max_mean} \\pm {longest_common_subs_max_mean_std}  ⁌ & \\num⁍ {mean_composition_mismatch} \\pm {mean_composition_mismatch_std} ⁌  & \\num⁍ {mean_loss} \\pm {mean_loss_std} ⁌ \\\\"
    for i, row in df.round(2).iterrows():
        row_dict = {
            "temperature": row.name,
            "frac_valid": row["frac_valid"]["mean"],
            "frac_valid_std": row["frac_valid"]["std"],
            "longest_common_subs_max_mean": row["longest_common_subs_max_mean"]["mean"],
            "longest_common_subs_max_mean_std": row["longest_common_subs_max_mean"]["std"],
            "mean_composition_mismatch": row["mean_composition_mismatch"]["mean"],
            "mean_composition_mismatch_std": row["mean_composition_mismatch"]["std"],
            "mean_loss": row["mean_loss"]["mean"],
            "mean_loss_std": row["mean_loss"]["std"],
        }
        rows.append(row_template.format(**row_dict).replace("⁍", "{").replace("⁌", "}"))
    return rows

In [60]:
print('\n'.join(create_classification_performance_table(agg_frame)))

0.0 & \num{ 1.0 \pm 0.0 } &  \num{ 0.52 \pm 0.01  } & \num{ 4.24 \pm 0.04 }  & \num{ 3.1 \pm 0.03 } \\
0.25 & \num{ 1.0 \pm 0.0 } &  \num{ 0.53 \pm 0.01  } & \num{ 4.24 \pm 0.04 }  & \num{ 3.09 \pm 0.01 } \\
0.5 & \num{ 1.0 \pm 0.0 } &  \num{ 0.54 \pm 0.0  } & \num{ 4.24 \pm 0.04 }  & \num{ 3.06 \pm 0.02 } \\
0.75 & \num{ 1.0 \pm 0.0 } &  \num{ 0.55 \pm 0.01  } & \num{ 4.24 \pm 0.04 }  & \num{ 3.07 \pm 0.02 } \\
1.0 & \num{ 1.0 \pm 0.0 } &  \num{ 0.55 \pm 0.0  } & \num{ 4.27 \pm 0.03 }  & \num{ 3.07 \pm 0.03 } \\
1.25 & \num{ 1.0 \pm 0.01 } &  \num{ 0.55 \pm 0.01  } & \num{ 4.2 \pm 0.08 }  & \num{ 3.1 \pm 0.02 } \\
1.5 & \num{ 0.96 \pm 0.01 } &  \num{ 0.58 \pm 0.0  } & \num{ 3.98 \pm 0.12 }  & \num{ 3.08 \pm 0.06 } \\
