In [2]:
%load_ext autoreload
%autoreload 2

In [6]:
import pandas as pd
from sklearn.model_selection import train_test_split

from gptchem.data import get_matbench_is_metal, get_photoswitch_data
from gptchem.querier import Querier

In [1]:
PROMPT_TEMPLATE = """Is does {smiles} have a large transition wavelength based on the exammples below?

Examples:
---------
{examples}
"""

In [4]:
data = get_photoswitch_data()

In [8]:
data["class"] = pd.qcut(data["E isomer pi-pi* wavelength in nm"], 2, labels=[False, True])

In [9]:
train, test = train_test_split(data, test_size=0.2, random_state=42)

In [10]:
data

Unnamed: 0,SMILES,rate of thermal isomerisation from Z-E in s-1,Solvent used for thermal isomerisation rates,Z PhotoStationaryState,E PhotoStationaryState,E isomer pi-pi* wavelength in nm,Extinction,E isomer n-pi* wavelength in nm,Extinction coefficient in M-1 cm-1,Z isomer pi-pi* wavelength in nm,...,CAM-B3LYP/6-31G** DFT Z isomer n-pi* wavelength in nm,BHLYP/6-31G* DFT E isomer pi-pi* wavelength in nm,BHLYP/6-31G* DFT E isomer n-pi* wavelength in nm,BHLYP/6-31G* Z isomer pi-pi* wavelength in nm,BHLYP/6-31G* DFT Z isomer n-pi* wavelength in nm,name,selfies,wavelength_cat,inchi,class
0,C[N]1C=CC(=N1)N=NC2=CC=CC=C2,2.100000e-07,MeCN,76.0,72.0,310.0,1.67,442.0,0.0373,290.0,...,,,,,,,[C][NH0][N][=N][C][=Branch1][Ring2][=N][Ring1]...,very small,InChI=1S/C10H10N4/c1-14-8-7-10(13-14)12-11-9-5...,False
1,C[N]1C=NC(=N1)N=NC2=CC=CC=C2,3.800000e-07,MeCN,90.0,84.0,310.0,1.87,438.0,0.0505,272.0,...,,,,,,,[C][NH0][C][=N][C][=Branch1][Ring2][=N][Ring1]...,very small,InChI=1S/C9H9N5/c1-14-7-10-9(13-14)12-11-8-5-3...,False
2,C[N]1C=C(C)C(=N1)N=NC2=CC=CC=C2,1.500000e-06,MeCN,96.0,87.0,325.0,1.74,428.0,0.0612,286.0,...,,,,,,,[C][NH0][C][=C][Branch1][C][C][C][=Branch1][Ri...,very small,InChI=1S/C11H12N4/c1-9-8-15(2)14-11(9)13-12-10...,False
3,C[N]1C=C(C=N1)N=NC2=CC=CC=C2,7.600000e-09,MeCN,98.0,70.0,328.0,1.66,417.0,0.0640,275.0,...,401.0,,,,,,[C][NH0][C][=C][Branch1][Branch1][C][=N][Ring1...,very small,InChI=1S/C10H10N4/c1-14-8-10(7-11-14)13-12-9-5...,False
4,C[N]1N=C(C)C(=C1C)N=NC2=CC=CC=C2,7.700000e-07,MeCN,98.0,98.0,335.0,2.27,425.0,0.0963,296.0,...,449.0,,,,,"phenyl-(1,3,5-trimethylpyrazol-4-yl)diazene",[C][NH0][N][=C][Branch1][C][C][C][=Branch1][Br...,very small,InChI=1S/C12H14N4/c1-9-12(10(2)16(3)15-9)14-13...,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
398,OC%38=C%39N=CC=CC%39=C(/N=N/C%40=NC%41=CC(C)=C...,,,,,456.0,,,,,...,,,,,,,[O][C][=C][N][=C][C][=C][C][Ring1][=Branch1][=...,medium,InChI=1S/C18H14N4OS/c1-10-8-14-16(9-11(10)2)24...,True
399,OC%42=C%43N=CC=CC%43=C(/N=N/C%44=NC%45=CC=CC=C...,,,,,437.0,,,,,...,,,,,,,[O][C][=C][N][=C][C][=C][C][Ring1][=Branch1][=...,medium,InChI=1S/C16H11N5O/c22-14-8-7-11(10-4-3-9-17-1...,True
400,N#CC1C(SC(/N=N/C2=NC(C=CC([N+]([O-])=O)=C3)=C3...,,,,,545.0,,,,,...,,,,,,,[N][#C][C][C][Branch2][Ring2][#Branch2][S][C][...,large,InChI=1S/C18H20N6O2S2/c19-9-12-15(10-4-2-1-3-5...,True
401,N#Cc5c(c6ccc(Cl)cc6)c(/N=N/C7=NC(C=CC([N+]([O-...,,,,,535.0,,,,,...,,,,,,,[N][#C][C][C][Branch1][N][C][=C][C][=C][Branch...,large,InChI=1S/C18H9ClN6O2S2/c19-10-3-1-9(2-4-10)15-...,True


In [11]:
def create_example_string(
    data,
    num_examples: int = 10,
    representation_col: str = "SMILES",
    value_col: str = "class",
):
    examples = []
    for i, row in data.sample(num_examples).iterrows():
        examples.append(f"- {row[representation_col]}: {row[value_col]}")
    return "\n".join(examples)

In [12]:
print(create_example_string(data))

- [H]C1=C(C=C([H])C(/N=N/C2=C(C#N)C=C([N+]([O-])=O)C=C2C#N)=C1)N(CC)CC: True
- CN1C(/N=N/C2=CC=CC=C2)=CC=C1: False
- CCN(CC)C1=CC=C(/N=N/C2=CC=CC=C2)C=C1: True
- Cc1[nH]c2c(c1N=Nc1n[nH]c(n1)S)cccc2: True
- CC1=C(C(C)=NN1)/N=N/C2=CC=C(Cl)C=C2: False
- CC1=C(C(C)=NN1)/N=N/C2=CC(OC)=CC=C2: False
- CC1=C(/N=N/C2=CC(F)=CC(F)=C2)C(C)=NO1: False
- CN(C)C1=C(/N=N/C2=CC=CC=C2)C=NC=C1: False
- CC1=C(/N=N/C2=CC=C(N(C)C)C=C2C)C=C([N+]([O-])=O)C=C1: True
- ClC%11=CC([N+]([O-])=O)=CC(C#N)=C%11/N=N/C%12=CC([H])=C(C=C%12OC)N(CC)CC: True


In [16]:
prompt = PROMPT_TEMPLATE.format(
    smiles=test["SMILES"].iloc[2], examples=create_example_string(train, 100)
)

In [14]:
", ".join(test.iloc[:50]["SMILES"].to_list())

'CC1=C(C(C)=NN1)/N=N/C2=CC=C(Br)C=C2, FC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1, N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=CC=C2C#N, FC1=C(F)C=C(F)C(F)=C1/N=N/C2=C(F)C(F)=CC(F)=C2F, NC1=CC(CCC2=C3C=CC(N)=C2)=C(/N=N\\3)C=C1, CC1=C(C(C)=NN1)/N=N/C2=CC=C(OC)C=C2, CCN(CC)C1=CC=C(/N=N/C2=CC=CC=C2)C=C1, OC%14=C%15N=CC=CC%15=C(/N=N/C%16=NC(C%17=CC=CC=C%17)=CS%16)C=C%14, OCCN(CCC#N)C(C=C%13)=CC=C%13/N=N/C%14=CC=CC=C%14, Oc1c(/C=N/c2c(C(C)C)cc(S(=O)(O)=O)cc2C(C)C)cc(/N=N/c3ccc([N+]([O-])=O)cc3)cc1, ClC7=CC=CC=C7N=NC8=NC9=CC=C([N+]([O-])=O)C=C9S8, CC1=C(C(C)=NN1)/N=N/C2=C(C)C=CC=C2, N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=C(C(F)(F)F)C=C2, OC%26=C(N=CC=C%27)C%27=C(/N=N/C%28=CC=C(C(C)=O)C=C%28)C=C%26, FC1=CC=C(/N=N/C2=CC=C(F)C=C2)C=C1, FC1=CC(/N=N/C2=CC=C(NCCC#N)C=C2)=CC=C1, CCN(CC)C1=CC=C(/N=N/C2=CC=C(C#N)C=C2)C=C1, [H]C1=CC([N+]([O-])=O)=CC([H])=C1/N=N/C2=CC([H])=C(C=C2[H])N(CC)CC, CN1C(/N=N/C2=CC=CC=C2)=C(C)C=C1C, O=[N+]([O-])C1=CC=C(/N=N/C2=CC=C(C=C2)[N+]([O-])=O)C=C1, C1(/N=N/C2=CC=CC=C2)=CC=CN=C1, OC%38=C%39N=CC=CC%

In [81]:
", ".join(test.iloc[:50]["is_metal"].astype(str).to_list())

'True, True, False, False, True, False, False, False, False, True, True, False, True, True, False, True, True, True, False, True, False, False, True, False, True, True, True, True, True, True, False, False, True, False, False, True, True, False, True, True, True, True, False, False, False, True, False, False, True, False'

In [17]:
print(prompt)

Is does N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=CC=C2C#N have a large transition wavelength based on the exammples below?

Examples:
---------
- ClC(C=C1)=CC=C1/N=N/C2=CC=C(N(C)C)C=C2: True
- C12=CC=CC=C1CCC3=CC=CC=C3/N=N\2: nan
- FC1=CC(C(OCC)=O)=CC(F)=C1/N=N/C2=CC=C(C(OCC)=O)C=C2: False
- [H]C7=CC=C(N=C(N=NC8=CC=CC(CO)=C8)S9)C9=C7: False
- CC(C=C1)=CC=C1/N=N/C2=CC=C(N(CC)CC)C=C2: True
- C1(/N=N/C2=CC=NN2)=CC=NN1: False
- NC1=CC=C(/N=N/C2=CC=C(N(CC)CC)C=C2)C=C1: True
- IC1=C(/N=N/C2=CC=CC=C2)C=NC=C1: False
- OC(C=C(O)C=C1)=C1/N=N/C2=CC=CC=C2: True
- CC1=C(C(C)=NN1)/N=N/C2=CC(C)=CC=C2: False
- CC(S1)=CC=C1/N=N/C2=CC=C(C(F)(F)F)C=C2: False
- C12=CC=CC=C1N=C(/N=N/C3=NC4=CC=CC=C4O3)O2: True
- CC1=CC=CC=C1/N=N/C2=CC=CC=C2C: False
- O=[N+]([O-])C1=C(/N=N/C2=C([N+]([O-])=O)C=CC=C2)C=CC=C1: False
- CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2: True
- N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=CC(OC)=C2: True
- CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(C#N)C=CC=C2: True
- CC1=NOC(C)=C1/N=N/C2=C(Cl)C=CC=C2: False
- CCN(

In [22]:
querier = Querier("ada", max_tokens=600)

In [26]:
for i, row in test.iterrows():
    print(row["SMILES"], row["E isomer pi-pi* wavelength in nm"])

CN1C(/N=N/C2=CC=CC=C2)=C(C)C=C1C 394.0
COC(C=CN=C1)=C1/N=N/C2=CC=CC=C2 332.0
FC1=C(F)C=C(F)C(F)=C1/N=N/C2=C(F)C(F)=CC(F)=C2F 303.0
CC(S1)=CC=C1/N=N/C2=CC=C(C(F)(F)F)C=C2 378.0
N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=CC=C2OC 394.0
OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7 491.0
CC1=C(C(C)=NN1)/N=N/C2=CC=C(Cl)C=C2 336.0
FC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1 398.0
CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2 417.0
NC1=CC(/N=N/C2=CC=CC=C2)=CC=C1 417.0
ClC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1 404.0
CCN(CC)C(C=C1)=CC=C1/N=N/C2=C(C#N)C=CC=C2 462.0
CC1=NOC(C)=C1/N=N/C2=CC=C(F)C=C2 318.0
OC(C([N+]([O-])=O)=CC([N+]([O-])=O)=C1)=C1/N=N/C2=CC(C)=CC(NC(C)=O)=C2O 400.0
OCCC1=CC=CC=C1/N=N/C2=CC=C(NCCC#N)C=C2 399.0
N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=C(C#N)C=C2 428.0
CC1=NOC(C)=C1/N=N/C2=CC(NC(C)=O)=CC=C2 315.0
COC(C=C1)=CC=C1N=NC2=NNC=C2 344.0
CC1=CC=C(/N=N/C2=CC=CC=C2)C=C1 330.0
[H]C5=CC([N+]([O-])=O)=CC([N+]([O-])=O)=C5/N=N/C6=CC(OC)=C(C=C6)N(CC)CC 540.0
C[N]1C=CC(=N1)N=NC2=CC=

In [None]:
 CN1C(/N=N/C2=CC=CC=C2)=C(C)C=C1C? 362.0
 COC(C=CN=C1)=C1/N=N/C2=CC=CC=C2 332.0
 OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7 443

In [27]:
found = [362, 332, 443]
expected = [394.0, 332, 491.0]

In [28]:
import numpy as np

In [29]:
np.mean(np.abs(np.array(found) - np.array(expected)))

26.666666666666668

In [122]:
import os

import openai

response = openai.Completion.create(
    model="text-davinci-003",
    prompt="What is the transition wavelength of OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7?\n\nExamples:\n---------\n- ClC(C=C1)=CC=C1/N=N/C2=CC=C(N(C)C)C=C2: 416.0 nm\n- FC1=CC=CC(/N=N/C2=CC=CC(F)=C2)=C1: 320.0 nm\n- CC1=C(/N=N/C2=CC=CC=C2)C=NC=C1: 322.0 nm\n- NC1=CC=C(/N=N/C2=CC=C(N(C)C)C=C2)C=C1: 410.0 nm\n- CC1=C(C(C)=NN1)/N=N/C2=CC=C(OC)C=C2: 342.0 nm\n- [H]N(CCC#N)C(C=C7)=CC=C7/N=N/C8=CC=C([N+]([O-])=O)C=C8: 443.0 nm\n- OCCN(CCO)C(C=C%11)=CC=C%11/N=N/C%12=CC=C([N+]([O-])=O)C=C%12: 475.0 nm\n- O=[N+]([O-])C1=C(/N=N/C2=C([N+]([O-])=O)C=CC=C2)C=CC=C1: 323.0 nm\n- CC1=NOC(C)=C1/N=N/C2=C(OC)C=CC=C2: 311.0 nm\n- ClC1=C(/N=N/C2=C(Cl)C=CC=C2)C=CC=C1: 328.0 nm\n- C[N]1N=CC(=C1N=NC2=CC=CC=C2)C: 340.0 nm\n- ClC1=CC=C(/N=N/C2=CC=CC=C2)C=C1: 324.0 nm\n- CC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1: 395.0 nm\n- CN1C(/N=N/C2=CC=CC=C2)=CN=C1: 362.0 nm\n- O=[N+]([O-])C1=CC=CC(/N=N/C2=CC([N+]([O-])=O)=CC=C2)=C1: 311.0 nm\n- CCN(CC)C1=CC=C(/N=N/C2=CC(C#N)=CC=C2)C=C1: 446.0 nm\n- ClC1=CC([N+]([O-])=O)=CC(C#N)=C1/N=N/C2=CC([H])=C(C=C2[H])N(CC)CC: 551.0 nm\n- OC1=C(/N=N/C2=CC=C([N+]([O-])=O)C=C2)C=CC(O)=C1: 432.0 nm\n- COC(S1)=CC=C1/N=N/C2=CC=C(C#N)C=C2: 413.0 nm\n- CCN(CCC#N)C(C=C%21C)=CC=C%21/N=N/C%22=CC=C([N+]([O-])=O)C=C%22: 463.0 nm\n- O=[N+]([O-])C1=CC=C(N=C(N=NC2=CC=CC=C2)S3)C3=C1: 340.0 nm\n- OC1=CC=C(/N=N/C2=CC=C(C=C2)[N+]([O-])=O)C=C1: 382.0 nm\n- [H]N(CCO)C(C=C5)=CC=C5/N=N/C6=CC=CC=C6: 398.0 nm\n- CC1=C(/N=N/C2=C(Cl)C=CC=C2Cl)C(C)=NO1: 297.0 nm\n- O=C1C=CC(N1C2=CC=C(/N=N/C3=CC=CC=C3)C=C2)=O: 329.0 nm\n- CC1=CC=CC(/N=N/C2=CC=CC(C)=C2)=C1: 331.0 nm\n- CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(C)C=CC=C2: 405.0 nm\n- [H]C7=CC=C(N=C(N=NC8=CC=CC(CO)=C8)S9)C9=C7: 328.0 nm\n- OCCC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1: 396.0 nm\n- CSc1nnc(s1)N=Nc1c2ccccc2n(c1C)C: 451.0 nm\n- CSC7=CC=CC=C7N=NC8=NC9=CC=C([N+]([O-])=O)C=C9S8: 340.0 nm\n- CC1=C(C(C)=NN1)/N=N/C2=CC(C(O)=O)=CC=C2: 332.0 nm\n- N#CC(S1)=CC=C1/N=N/C2=CC=CC=C2: 367.0 nm\n- C1(/N=N/C2=CC=CC=C2)=CC=CC=C1: 319.0 nm\n- [H]N(CC)C(C=C3)=CC=C3/N=N/C4=CC=CC=C4: 400.0 nm",
    temperature=0,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0,
)

In [123]:
response

<OpenAIObject text_completion id=cmpl-6fq6kc1vGluj9d7uSOSyiwbH8sDVq at 0x2a3cb46d0> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "\n- OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7: 463.0 nm"
    }
  ],
  "created": 1675429770,
  "id": "cmpl-6fq6kc1vGluj9d7uSOSyiwbH8sDVq",
  "model": "text-davinci-003",
  "object": "text_completion",
  "usage": {
    "completion_tokens": 72,
    "prompt_tokens": 1525,
    "total_tokens": 1597
  }
}

In [6]:
from gptchem.extractor import RegressionExtractor
from gptchem.formatter import FewShotFormatter

In [98]:
dict(response)

{'id': 'cmpl-6fpWfYmQFDhQbCJbBXQrccQvFGdQA',
 'object': 'text_completion',
 'created': 1675427533,
 'model': 'text-davinci-003',
 'choices': [<OpenAIObject at 0x16d3d1a90> JSON: {
    "finish_reason": "stop",
    "index": 0,
    "logprobs": null,
    "text": "\n\nThe transition wavelength of OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7 is 443.0 nm."
  }],
 'usage': <OpenAIObject at 0x16d3d1cc0> JSON: {
   "completion_tokens": 76,
   "prompt_tokens": 2183,
   "total_tokens": 2259
 }}

In [97]:
RegressionExtractor()(dict(response))

AttributeError: split

In [55]:
formatter = FewShotFormatter(
    train.sample(10),
    "transition wavelengths of photoswitch molecules",
    "SMILES",
    "E isomer pi-pi* wavelength in nm",
)

In [56]:
import pandas as pd

In [57]:
prompt = formatter(test.sample(1))

In [58]:
print(prompt.iloc[0]["prompt"])

I am a highly intelligent question answering bot that answers questions about transition wavelengths of photoswitch molecules.
    
Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=CC=C(C(F)(F)F)C=C2
A: 421.0

Q: OC1=C([N+]([O-])=O)C=C([N+]([O-])=O)C=C1/N=N/C2=C(O)C=CC(C)=C2
A: 400.0

Q: O=[N+]([O-])C1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1
A: 455.0

Q: FC1=CC=C(/N=N/C2=CC=CC=C2)C=C1
A: 322.0

Q: CCN(CC)C(C=C%21)=CC=C%21/N=N/C%22=CC=C(N%23CCOCC%23)C([H])=C%22
A: 417.0

Q: CN(C=N1)C=C1/N=N/C2=CC=CC=C2
A: 336.0

Q: CCN(CC)C(C=C1)=CC=C1/N=N/C2=C(C#N)C=C(C#N)C=C2
A: 515.0

Q: CN(C)C(C=C1)=CC=C1/N=N/C2=CC=CC=C2[N+]([O-])=O
A: 440.0

Q: CC1=C(C(C)=NN1)/N=N/C2=CC=C(C(O)=O)C=C2
A: 342.0

Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=CC=C(C(C)=O)C=C2
A: 412.0

Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2


In [59]:
prompt

Unnamed: 0,prompt,completion,label,representation
0,I am a highly intelligent question answering b...,417.0,417.0,CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2


In [82]:
from gptchem.extractor import FewShotClassificationExtractor, FewShotRegressionExtractor

In [83]:
extractor = FewShotRegressionExtractor()

In [84]:
extractor.extract("A: 4566")

4566.0

In [None]:
|