In [84]:
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_regression_prediction

from glob import glob 
import pandas as pd 

import time 
import numpy as np 
from rdkit import Chem
from guacamol.utils.chemistry import is_valid
from gpt3forchem.output import test_inverse_bandgap


from fastcore.xtras import save_pickle

In [65]:
# Templates for the inverse-design task: the prompt states a target bandgap in eV and
# ends with "###"; the completion is filled with a SMILES string and ends with "@@@".
# NOTE(review): "###" / "@@@" presumably serve as the prompt separator and completion
# stop sequence for fine-tuning — confirm against gpt3forchem.api_wrappers.
PROMPT_TEMPLATE_bandgap_inverse = "What is a molecule with a bandgap of {} eV###"
COMPLETION_TEMPLATE_bandgap_inverse = "{}@@@"


def generate_inverse_photoswitch_prompts(
    data: pd.DataFrame
) -> pd.DataFrame:
    """Build prompt/completion pairs for inverse bandgap design.

    Despite the "photoswitch" in its name, this function reads the ``smiles``
    and ``gap`` columns of ``data`` and fills the bandgap templates.

    Rows whose SMILES fails ``is_valid`` are skipped. For every remaining row
    the gap is rounded to one decimal and inserted into the prompt template,
    and the SMILES is round-tripped through RDKit (``MolFromSmiles`` ->
    ``MolToSmiles``) before being inserted into the completion template.

    Returns:
        A DataFrame with the columns ``prompt``, ``completion`` and ``SMILES``
        (the round-tripped SMILES), one row per kept input row.
    """
    records = {"prompt": [], "completion": [], "SMILES": []}

    for _, row in data.iterrows():
        raw_smiles = row["smiles"]
        if not is_valid(raw_smiles):
            # Unparseable molecules cannot be used as completions.
            continue

        rounded_gap = np.round(row["gap"], 1)
        prompt_text = PROMPT_TEMPLATE_bandgap_inverse.format(rounded_gap)
        rdkit_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(raw_smiles))

        records["prompt"].append(prompt_text)
        records["completion"].append(
            COMPLETION_TEMPLATE_bandgap_inverse.format(rdkit_smiles)
        )
        records["SMILES"].append(rdkit_smiles)

    return pd.DataFrame(records)

In [3]:
def load_gaps(filename):
    """Load computed gaps and pair them with the SMILES they belong to.

    ``filename`` is a text file whose usable lines consist of two
    whitespace-separated tokens, ``<path> <gap>``. The last ``/``-separated
    component of ``<path>`` is parsed as an integer and used as a list index
    into the lines of the companion SMILES file, whose name is ``filename``
    with every ``"_gaps"`` removed.

    Lines that do not fit this layout (wrong number of tokens, non-numeric
    index or gap, index beyond the end of the SMILES file) are skipped as a
    whole, so the returned columns always stay aligned.

    The temperature is parsed from the file name: the text after the last
    ``"sets"`` up to the next ``"_"``.

    Args:
        filename: Path of the gaps file.

    Returns:
        A DataFrame with the columns ``smiles``, ``gap``, ``temperature``
        (the same value on every row) and ``path``.

    Raises:
        OSError: If the gaps file or the companion SMILES file cannot be opened.
        ValueError: If no temperature can be parsed from ``filename``.
    """
    with open(filename) as f:
        lines = f.readlines()

    smiles_file = filename.replace("_gaps", "")
    with open(smiles_file) as f:
        smiles_ = [line.strip() for line in f]

    gaps = []
    smiles = []
    path = []
    for line in lines:
        # Parse everything first and append afterwards: appending while parsing
        # could leave the three lists with different lengths when a later step
        # (e.g. the SMILES lookup) fails for this line.
        try:
            smile, gap = line.split()
            num = int(smile.split("/")[-1])
            gap_value = float(gap)
            smiles_value = smiles_[num]
        except (ValueError, IndexError):
            # Malformed line or index outside the SMILES file: skip it.
            continue
        gaps.append(gap_value)
        smiles.append(smiles_value)
        path.append(smile)

    temperature = float(filename.split("sets")[-1].split("_")[0])
    return pd.DataFrame(
        {"smiles": smiles, "gap": gaps, "temperature": temperature, "path": path}
    )

In [30]:
def compile_res(files):
    """Load several gap files and combine them into one table.

    Each entry of ``files`` is passed to ``load_gaps``. Loading is best effort:
    a file that raises is skipped, but a warning naming the file and the error
    is emitted so skipped inputs do not go unnoticed.

    Args:
        files: Iterable of gap-file paths.

    Returns:
        The concatenation of all successfully loaded tables, sorted by the
        ``temperature`` column.

    Raises:
        ValueError: If not a single file could be loaded.
    """
    import warnings  # local import keeps this cell runnable on its own

    res = []
    for f in files:
        try:
            res.append(load_gaps(f))
        except Exception as exc:
            warnings.warn(f"Skipping {f!r}: {type(exc).__name__}: {exc}")

    if not res:
        # pd.concat would raise a ValueError here as well, just a less helpful one.
        raise ValueError("No gap files could be loaded; nothing to concatenate.")

    df = pd.concat(res)
    return df.sort_values("temperature")

In [32]:
def train_inverse_model(train_prompts, representation):
    """Write the training prompts to a JSONL file and start a fine-tune on it.

    The file is created under ``run_files/`` and its name encodes the current
    local time, ``representation`` and the number of rows in ``train_prompts``.

    Args:
        train_prompts: DataFrame of prompts; written as JSON lines (one record per row).
        representation: Label that is inserted into the file name.

    Returns:
        A tuple ``(modelname, train_filename)`` where ``modelname`` is whatever
        ``fine_tune`` returns and ``train_filename`` is the path that was written.
    """
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    n_train = len(train_prompts)
    train_filename = f"run_files/{timestamp}_iterative_train_prompts_bandgap_inverse_{representation}_{n_train}.jsonl"

    train_prompts.to_json(train_filename, orient="records", lines=True)

    # The training file is passed as both the first and the second argument.
    # NOTE(review): the second argument is presumably the validation file — confirm
    # that reusing the training file there is intended.
    return fine_tune(train_filename, train_filename, "ada"), train_filename

In [15]:
# Gap files of the initial sample. The second pattern also matches everything the
# first one matches, so the combined list is de-duplicated; otherwise every
# "_sets*" file would be listed (and later loaded) twice. Sorting makes the order
# reproducible, since glob returns matches in arbitrary order.
initial_sample_for_bias = sorted(
    set(
        glob('for_more_xtb_opt/2022-12-07-15-51-30_smiles_extrapolation_sets*gaps.txt')
        + glob('for_more_xtb_opt/2022-12-07-15-51-30_smiles_extrapolation*gaps.txt')
    )
)

# Gap files of the two biased iterations, in a reproducible order.
iteration_1 = sorted(glob("for_more_xtb_opt/2022-12-09-23-20-30_smiles_iteration_1_biased_sets*gaps.txt"))
iteration_2 = sorted(glob("for_more_xtb_opt/2022-12-11-23-54-28_smiles_iteration_2_biased_sets*gaps.txt"))

In [16]:
# Load and combine the gap files of each stage, in the order: initial sample,
# iteration 1, iteration 2.
initial_sample_res, iteration_1_res, iteration_2_res = map(
    compile_res, (initial_sample_for_bias, iteration_1, iteration_2)
)

In [18]:
# Stack the results of all stages row-wise; the row labels of the inputs are kept as-is.
all_res = pd.concat(
    (initial_sample_res, iteration_1_res, iteration_2_res)
)

In [73]:
# Keep only the rows whose gap exceeds 4.2.
large_gap = all_res.loc[all_res["gap"] > 4.2]

In [74]:
# One row per distinct SMILES (the first occurrence is kept).
large_gap = large_gap.drop_duplicates(subset="smiles")

In [75]:
# Turn the large-gap molecules into prompt/completion pairs for fine-tuning.
large_gap_prompts = generate_inverse_photoswitch_prompts(data=large_gap)

In [76]:
# Display the prompt/completion table (pandas truncates long frames in the output).
large_gap_prompts

Unnamed: 0,prompt,completion,SMILES
0,What is a molecule with a bandgap of 17.4 eV###,C@@@,C
1,What is a molecule with a bandgap of 4.3 eV###,FC(F)(F)n1ccnc1@@@,FC(F)(F)n1ccnc1
2,What is a molecule with a bandgap of 7.1 eV###,C#CC@@@,C#CC
3,What is a molecule with a bandgap of 4.6 eV###,[H]/N=C(\N)N1CCNCC1@@@,[H]/N=C(\N)N1CCNCC1
4,What is a molecule with a bandgap of 4.4 eV###,C/C=C\N1CCCC1@@@,C/C=C\N1CCCC1
...,...,...,...
378,What is a molecule with a bandgap of 11.5 eV###,NCN@@@,NCN
379,What is a molecule with a bandgap of 8.8 eV###,CON@@@,CON
380,What is a molecule with a bandgap of 4.5 eV###,Oc1c[nH]cn1@@@,Oc1c[nH]cn1
381,What is a molecule with a bandgap of 5.8 eV###,NP@@@,NP


In [78]:
# Keep the returned model name and training-file path in variables instead of
# only displaying them, so they remain available after this cell has run.
iteration_3_modelname, iteration_3_train_filename = train_inverse_model(large_gap_prompts, "smiles")
iteration_3_modelname, iteration_3_train_filename

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

('ada:ft-lsmoepfl-2022-12-13-15-11-46',
 'run_files/2022-12-13-16-07-07_iterative_train_prompts_bandgap_inverse_smiles_383.jsonl')

In [36]:
def test_inverse_model(
    modelname,
    test_prompts,
    df_train,
    max_tokens: int = 250,
    temperatures=None,
    representation="SMILES",
):
    temperatures = temperatures or [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5]
    train_smiles = df_train["SMILES"].to_list()
    results = []
    for temperature in temperatures:
        try:
            print(f"Testing temperature {temperature} for {representation}")
            result = test_inverse_bandgap(
                test_prompts,
                modelname,
                train_smiles=train_smiles,
                temperature=temperature,
                max_tokens=max_tokens,
                representation=representation,
            )

            results.append(result)
        except Exception as e:
            print(e)
            pass

    return results

In [81]:
# Seeded generator so the sampled target bandgaps are reproducible across runs.
BANDGAP_SAMPLING_SEED = 42
bandgap_rng = np.random.default_rng(BANDGAP_SAMPLING_SEED)

# Three stacked copies of the large-gap molecules (pd.concat returns a new frame,
# so large_gap itself is not modified by the column assignment below).
test_set_biased = pd.concat([large_gap] * 3)

# One target gap per row of the test set, drawn from N(5.0, 0.2). The size is taken
# from test_set_biased itself: large_gap_prompts can have fewer rows than large_gap
# (invalid SMILES are dropped when prompts are generated), which would make the
# assignment below fail with a length mismatch.
random_bandgaps = bandgap_rng.normal(5.0, 0.2, size=len(test_set_biased))

test_set_biased['gap'] = random_bandgaps
filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

test_prompts = generate_inverse_photoswitch_prompts(test_set_biased)
valid_filename_random_biased = f"run_files/{filename_base}_iterative_valid_prompts_bandgap_inverse_smiles_random_biased.jsonl"
# NOTE(review): this writes test_set_biased, not test_prompts, although the file
# name says "valid_prompts" — confirm which table is meant to be saved.
test_set_biased.to_json(valid_filename_random_biased, orient="records", lines=True)

In [83]:
# Query the fine-tuned model with the biased test prompts over the default
# temperature grid; large_gap_prompts supplies the training SMILES.
iteration_3_test_results = test_inverse_model(
    modelname="ada:ft-lsmoepfl-2022-12-13-15-11-46",
    test_prompts=test_prompts,
    df_train=large_gap_prompts,
)

Testing temperature 0 for SMILES
Internal server error
Testing temperature 0.25 for SMILES


2022-12-13 16:55:37.393 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: CC(C)Cc1ccccc1
2022-12-13 16:55:37.395 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: CC(C)Cc1ccccc1
2022-12-13 16:55:41.132 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 383 samples
2022-12-13 16:55:41.253 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 16:55:42.035568: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-12-13 16:55:47.035 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (17.62690471042294, 0.029440590072918955)
2022-12-13 16:55:54.760 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:973 - Computed KL div score: 0.45009408235468606


Testing temperature 0.5 for SMILES


2022-12-13 17:02:43.924 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: Cc1cccc(C(F)(F)F)c1
2022-12-13 17:02:43.948 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: Cc1cccc(C(F)(F)F)c1
2022-12-13 17:02:44.118 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 383 samples
2022-12-13 17:02:44.220 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 17:02:49.086 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (11.186543144741442, 0.10674540962720445)
2022-12-13 17:02:50.433 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:973 - Computed KL div score: 0.837534923388844


Testing temperature 0.75 for SMILES


2022-12-13 17:09:34.379 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: CCCCCCCCCCCCCCCCCCCCCC(=O)NCc1ccccc1
2022-12-13 17:09:34.381 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: CCCCCCCCCCCCCCCCCCCCCC(=O)NCc1ccccc1
2022-12-13 17:09:34.482 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 383 samples
2022-12-13 17:09:34.945 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 17:09:51.905 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (6.121723241495822, 0.2939502780282441)
2022-12-13 17:09:53.396 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:973 - Computed KL div score: 0.9315006115979767


Testing temperature 1.0 for SMILES


2022-12-13 17:16:40.220 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: CCCCCCCCCCC(=O)NCc1ccccc1
2022-12-13 17:16:40.222 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: CCCCCCCCCCC(=O)NCc1ccccc1
2022-12-13 17:16:40.331 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 383 samples
2022-12-13 17:16:40.392 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 17:16:45.694 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (3.766299699694244, 0.47082917339362973)
2022-12-13 17:16:47.406 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:973 - Computed KL div score: 0.9489100808942522


Testing temperature 1.25 for SMILES


2022-12-13 17:23:36.324 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: FC(=O)O
2022-12-13 17:23:36.325 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: FC(=O)O
2022-12-13 17:23:36.505 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 383 samples
2022-12-13 17:23:36.566 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 17:23:41.978 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (3.7839300068054698, 0.4691719243069576)
2022-12-13 17:23:43.372 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:973 - Computed KL div score: 0.9420475770785324


Testing temperature 1.5 for SMILES


2022-12-13 17:30:34.883 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:931 - Got predictions, example: COc1cpc[nH]c1
2022-12-13 17:30:34.884 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:942 - Loaded predictions. Example: COc1cpc[nH]c1
2022-12-13 17:30:35.118 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:953 - Calculating Frechet ChemNet distance for 345 samples
2022-12-13 17:30:35.172 | INFO     | gpt3forchem.output:_load_chemnet:218 - Saved ChemNet model to '/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ChemNet_v0.13_pretrained.h5'
2022-12-13 17:30:38.276 | DEBUG    | gpt3forchem.output:test_inverse_bandgap:962 - Computed frechet score: (9.607042235652898, 0.14640061932079423)
Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/rdkit/ML/Descriptors/MoleculeDescriptors.py", line 88, in CalcDescriptors
    res[i] = fn(mol)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-pack

In [85]:
# Persist the per-temperature results; the file name reuses the filename_base timestamp.
save_pickle(f"run_files/{filename_base}-iteration_3_results_extrapolation_smiles_more.pkl", iteration_3_test_results)

In [86]:
# For every tested temperature, collect the distinct valid predictions together
# with the positions (and expected values) of the predictions they came from.
smiles_random_biased_sets = []

for res in iteration_3_test_results:
    # Distinct predictions selected by the "valid_smiles" entry of the result.
    unique_valid = set(res["predictions"][res["valid_smiles"]])
    # Positions of all predictions that belong to that set (repeats included).
    kept_positions = [
        pos for pos, pred in enumerate(res["predictions"]) if pred in unique_valid
    ]
    smiles_random_biased_sets.append(
        {
            "temperature": res["meta"]["temperature"],
            "smiles": unique_valid,
            "original_prediction_indices": kept_positions,
            "expected": [res["expectations"][pos] for pos in kept_positions],
        }
    )

# Write one SMILES file per temperature: one SMILES per line, no trailing newline.
for res in smiles_random_biased_sets:
    temp = res["temperature"]
    # Sets have no reproducible iteration order; sorting makes the line order (and
    # with it the line index that load_gaps uses to look up a SMILES) deterministic.
    smiles_lines = sorted(res["smiles"])
    with open(f"for_more_xtb_opt/{filename_base}_smiles_iteration_3_biased_sets{temp}.txt", "w") as f:
        f.write("\n".join(smiles_lines))