In [1]:
from fastcore.xtras import load_pickle, save_pickle

from rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
)
import numpy as np
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_regression_prediction

# Fetch the model/tokenizer pair returned by rxnfp's default loader.
model, tokenizer = get_default_model_and_tokenizer()

# Fingerprint generator built from that pair; a later cell calls its
# `.convert(...)` on every reaction in the test frame.
rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the stored 2-tuple: the completions object and the test-prompt frame.
# NOTE(review): load_pickle presumably unpickles the file, which can execute
# arbitrary code — only point this at run files from a trusted source.
completions, test_prompts_filtered_frame = load_pickle('run_files/ada:ft-lsmoepfl-2022-11-08-18-02-20_completions.pkl')

In [3]:
# One prediction per entry of completions["choices"]; the helper receives the
# whole completions object plus the position of the choice to extract.
predictions = np.array(
    [
        extract_regression_prediction(completions, choice_idx)
        for choice_idx, _ in enumerate(completions["choices"])
    ]
)

In [4]:
# Ground-truth values: the integer that precedes the '@@@' marker in each
# completion string (e.g. '98@@@' -> 98).
true = (
    test_prompts_filtered_frame['completion']
    .apply(lambda completion: int(completion.partition('@@@')[0]))
    .values
)

In [5]:
# Display the test frame; the rendered columns are prompt, completion and repr.
test_prompts_filtered_frame

Unnamed: 0,prompt,completion,repr
0,What is the yield of the reaction with the fol...,98@@@,[CH3:1][O:2][C:3]1[CH:4]=[C:5]2[CH2:14][CH:13]...
1,What is the yield of the reaction with the fol...,86@@@,[H-].[Na+].[CH2:3]([OH:7])[C:4]#[C:5][CH3:6].C...
2,What is the yield of the reaction with the fol...,73@@@,[CH2:1]([C:5]1([CH2:30][CH2:31][CH2:32][CH3:33...
3,What is the yield of the reaction with the fol...,97@@@,[CH2:1]([N:3]1[C:11]([CH:12]2[CH2:17][CH2:16][...
4,What is the yield of the reaction with the fol...,64@@@,[Cl:1][C:2]1[CH:31]=[CH:30][C:5]([CH2:6][NH:7]...
...,...,...,...
4994,What is the yield of the reaction with the fol...,53@@@,[CH3:1][O:2][C:3]1[CH:8]=[CH:7][C:6]([N:9]2[C:...
4996,What is the yield of the reaction with the fol...,90@@@,[Br:1][C:2]1[CH:3]=[CH:4][C:5]2[C:11](=O)/[C:1...
4997,What is the yield of the reaction with the fol...,113@@@,[CH3:1][O:2][C:3](=[O:20])[C:4]1[CH:9]=[C:8]([...
4998,What is the yield of the reaction with the fol...,61@@@,[Cl:1][C:2]1[CH:19]=[CH:18][C:5]([O:6][CH2:7][...


In [6]:
# One fingerprint per reaction string in the `repr` column, in frame order.
# Only that column is needed, so iterate it directly instead of materialising
# every full row.
fps = [rxnfp_generator.convert(rxn) for rxn in test_prompts_filtered_frame['repr']]

In [7]:
# Absolute difference between prediction and ground truth, capped to [0, 100].
errors = np.abs(predictions - true).clip(min=0, max=100)

In [8]:
# Show the clipped absolute errors.
errors

array([98.,  2.,  4., ..., 58., 24.,  1.])

In [9]:
# Save the fingerprints, the clipped errors and the `repr` column together as
# one tuple in 'for_tmap.pkl'.
save_pickle('for_tmap.pkl', (fps, errors, test_prompts_filtered_frame['repr']))

In [10]:
import tmap as tm
from tqdm import tqdm 
from faerun import Faerun

# LSH forest that the MinHash-encoded fingerprints are later added to and indexed in.
# NOTE(review): the positional arguments (256, 128) are presumably the hash
# dimensionality and the number of prefix trees — confirm against the tmap
# LSHForest signature.
lf = tm.LSHForest(256, 128)
# MinHash encoder constructed without arguments, i.e. with the library defaults.
mh_encoder = tm.Minhash()

In [11]:
# Encode every fingerprint with the encoder's weight-array method ("I2CWS"),
# showing a progress bar while doing so.
mhfps = [
    mh_encoder.from_weight_array(fingerprint, method="I2CWS")
    for fingerprint in tqdm(fps)
]

100%|██████████| 4465/4465 [00:01<00:00, 3769.44it/s]


In [12]:
# slow
# Add all MinHash vectors to the forest and build its index before requesting a layout.
lf.batch_add(mhfps)
lf.index()

# Layout options, grouped by the prefix of the setting they belong to.
cfg = tm.LayoutConfiguration()

# k / kc
cfg.k = 50
cfg.kc = 50

# sl_* options
cfg.sl_repeats = 1
cfg.sl_extra_scaling_steps = 2
cfg.sl_scaling_min = 1.0
cfg.sl_scaling_max = 1.0
cfg.sl_scaling_type = tm.ScalingType.RelativeToDesiredLength

# placer / merger
cfg.placer = tm.Placer.Barycenter
cfg.merger = tm.Merger.LocalBiconnected
cfg.merger_factor = 2.0
cfg.merger_adjustment = 0

# remaining options
cfg.fme_iterations = 1000
cfg.mmm_repeats = 1
cfg.node_size = 1 / 37

# Compute the layout: x/y feed the scatter coordinates and s/t the tree's
# "from"/"to" edge endpoints in the plotting cell; the fifth return value is unused.
x, y, s, t, _ = tm.layout_from_lsh_forest(lf, config=cfg)

In [13]:
# slow
f = Faerun(clear_color="#222222", coords=False, view="front")

# Scatter layer: one point per reaction at the layout coordinates, coloured by
# the clipped prediction error (a single, non-categorical series).
f.add_scatter(
    "ReactionAtlas",
    {
        "x": x,
        "y": y,
        "c": [errors],
        "labels": test_prompts_filtered_frame['repr'],
    },
    shader="smoothCircle",
    point_scale=2.0,
    categorical=[False],
    has_legend=True,
    series_title=["Prediction error"],
    # NOTE(review): the labels passed here are the bare `repr` strings; confirm
    # that title_index=2 matches the label format Faerun expects.
    title_index=2,
    legend_title="",
)

# Tree edges drawn between the points of the "ReactionAtlas" scatter layer.
f.add_tree("reactiontree", {"from": s, "to": t}, point_helper="ReactionAtlas")

In [14]:
# Render the figure using the "reaction_smiles" template.
# NOTE(review): presumably writes output files named "reaction_smiles" to the
# working directory — confirm against the Faerun.plot documentation.
plot = f.plot("reaction_smiles", template="reaction_smiles")