In [1]:
# Reload edited project modules (e.g. gptchem) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [28]:
from sklearn.model_selection import train_test_split

from gptchem.data import get_matbench_is_metal, get_photoswitch_data
from gptchem.querier import Querier

In [44]:
# Few-shot prompt for the metal / non-metal classification task.
# NOTE(review): despite its name, the `{smiles}` placeholder is filled with a composition
# formula (e.g. "Ga1Sb0.14As0.86"), and the question "Is X metallicity ...?" is
# ungrammatical. The text is kept verbatim so stored outputs stay comparable; consider
# e.g. "Is X a metal (based on its bandgap)?" for future runs.
PROMPT_TEMPLATE = (
    "Is {smiles} metallicity (based on bandgap)?\n"
    "\n"
    "Examples:\n"
    "---------\n"
    "{examples}\n"
)

In [45]:
# Load the Matbench "is_metal" dataset (displayed below with columns composition, is_metal).
data = get_matbench_is_metal()

In [46]:
# Keep only rows that have an `is_metal` label.
data = data.dropna(subset=["is_metal"])

In [47]:
# 80/20 split with a fixed seed so train/test membership is reproducible.
train, test = train_test_split(data, test_size=0.2, random_state=42)

In [48]:
# Inspect the cleaned dataset (pandas truncates the display).
data

Unnamed: 0,composition,is_metal
0,Ag(AuS)2,True
1,Ag(W3Br7)2,True
2,Ag0.5Ge1Pb1.75S4,False
3,Ag0.5Ge1Pb1.75Se4,False
4,Ag2BBr,True
...,...,...
4916,ZrTaN3,False
4917,ZrTe,True
4918,ZrTi2O,True
4919,ZrTiF6,True


In [49]:
def create_example_string(
    data,
    num_examples: int = 10,
    representation_col: str = "composition",
    value_col: str = "is_metal",
    random_state=None,
):
    """Format a random sample of labelled rows as a bulleted few-shot example list.

    Args:
        data: DataFrame containing ``representation_col`` and ``value_col``.
        num_examples: Number of rows to sample. If ``data`` has fewer rows, all of
            them are used instead of raising (``DataFrame.sample`` without
            replacement raises ``ValueError`` when asked for more rows than exist).
        representation_col: Column whose value appears before the colon.
        value_col: Column whose value appears after the colon.
        random_state: Seed / RandomState forwarded to ``DataFrame.sample`` so the
            selected examples are reproducible. ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        One ``"- <representation>: <value>"`` line per sampled row, joined by newlines
        (an empty string when ``data`` is empty).
    """
    sample_size = min(num_examples, len(data))
    sampled = data.sample(sample_size, random_state=random_state)
    # Iterate the two columns directly: unlike iterrows, this does not upcast values
    # to a common row dtype, and no unused index variable is needed.
    return "\n".join(
        f"- {representation}: {value}"
        for representation, value in zip(sampled[representation_col], sampled[value_col])
    )

In [50]:
# Preview the few-shot example format.
# NOTE(review): this samples from the full `data` (including test rows) without a seed,
# so the printed examples change on every run.
print(create_example_string(data))

- Na11Ti20O40: True
- LiMgSnPd: True
- U5Ge3: True
- KP(HO2)2: False
- Ba4Ga4SnSe12: False
- Ca2Os2O7: True
- CrPbO4: False
- Ni(PO3)4: True
- Ba3Sn0.87Bi2.13Se8: False
- NaSbSe2: False


In [79]:
# Render the prompt for the third held-out composition, using 50 training rows as examples.
# (The template's placeholder is named `smiles` for historical reasons; it receives a composition.)
prompt = PROMPT_TEMPLATE.format(
    smiles=test.iloc[2]["composition"],
    examples=create_example_string(train, num_examples=50),
)

In [80]:
", ".join(test.iloc[:50]["composition"].to_list())

'BaAg2, ZrTe, Ga1Sb0.14As0.86, Nb2Tl4S11, U4S3, CoP2, Li9Ga13(Te7O22)3, CdGeP2, Sm2Se3, Ti2Be17, LiMnP2O7, Si15(TeP2)4, Sc6NiTe2, Cu9Se4(Cl3O7)2, In0.01Ga0.99As1, Fe3Si, Ti3Be, Y(AlSi)2, GeSe2, SrAg, Bi2Se3, Cs2Te, YSn3, Tl2O3, Mg2Zr14O5, Ca5Ir, NdMgNi4, ErB2Ir3, Ba3PN, Th(FeSi)2, Ge(SeO3)2, LiInTe2, HPbI3, CsAg2AsS3, CsNbSe2O7, Hf5Sn4, SmCo2, K2VCuS4, MgCr, ZrTi2O, Ca5Au2, SnB, CoP3, Al0.5In0.5P1, As2S3, Sr2Co2O5, Ga2O3, Zn0.9Ga0.1P0.1Se0.9, LiZn(Fe5O8)2, WO3'

In [81]:
", ".join(test.iloc[:50]["is_metal"].astype(str).to_list())

'True, True, False, False, True, False, False, False, False, True, True, False, True, True, False, True, True, True, False, True, False, False, True, False, True, True, True, True, True, True, False, False, True, False, False, True, True, False, True, True, True, True, False, False, False, True, False, False, True, False'

In [82]:
# Show the fully rendered few-shot prompt.
print(prompt)

Is Ga1Sb0.14As0.86 metallicity (based on bandgap)?

Examples:
---------
- LaVI5O16: False
- Tl2CdTe4: True
- Li3Fe(SbO3)4: True
- InSe: False
- Cs2NaMnF6: True
- CI4: False
- Y2Ge2O7: True
- La2V2IO9: False
- Er2(MoO4)3: False
- Cu9O13: True
- InPS4: False
- LiV3O4: True
- Nd(SiIr)2: True
- CeNbO4: True
- YCu3(WO3)4: True
- VB2: True
- YbZnAu2: True
- In0.1Ga0.9As0.9P0.1: False
- TiGaIr2: True
- Ti2Cd: True
- Ba2B6H4O13: False
- Cs2SCl6F: True
- Sn0.08Te1Pb0.92: False
- GaBi25O39: False
- CsGdO3: True
- ErB6: False
- In2S3: False
- Li8GeN4: False
- Nd5Ge4: True
- Zn0.25Ga0.75P0.75Se0.25: False
- Sr2TiO4: False
- Li5Mn5(SbO6)2: True
- LaCuSeO: False
- GaN: False
- Tm2Ru2O7: True
- Fe3B: True
- Zn0.94Hg0.06Te1: False
- Eu3BWO9: True
- TmSnRh: True
- ZrGe2: True
- Ag7(SI)2: True
- Cd0.06In0.94Te0.06As0.94: False
- MgF2: False
- K2Sn(AuS2)2: False
- SrGe2: False
- Zn0.86Hg0.14Te1: False
- Li2U(MoO5)2: True
- TbTe: True
- TaGaPt: True
- EuCuSeF: True



In [22]:
# gptchem Querier for the "ada" model; max_tokens=600 is presumably the completion-length cap.
# NOTE(review): `querier` is not used by any later cell in this notebook.
querier = Querier("ada", max_tokens=600)

In [26]:
# Print each test molecule with its reference E-isomer pi-pi* wavelength.
# NOTE(review): this expects `test` to be a photoswitch split (columns "SMILES" and
# "E isomer pi-pi* wavelength in nm"), presumably built from get_photoswitch_data() in a
# since-deleted cell (execution count 26 precedes the Matbench cells). Run top-to-bottom,
# `test` is the Matbench split (composition / is_metal), so this raises KeyError on a fresh kernel.
for i, row in test.iterrows():
    print(row["SMILES"], row["E isomer pi-pi* wavelength in nm"])

CN1C(/N=N/C2=CC=CC=C2)=C(C)C=C1C 394.0
COC(C=CN=C1)=C1/N=N/C2=CC=CC=C2 332.0
FC1=C(F)C=C(F)C(F)=C1/N=N/C2=C(F)C(F)=CC(F)=C2F 303.0
CC(S1)=CC=C1/N=N/C2=CC=C(C(F)(F)F)C=C2 378.0
N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=CC=C2OC 394.0
OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7 491.0
CC1=C(C(C)=NN1)/N=N/C2=CC=C(Cl)C=C2 336.0
FC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1 398.0
CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2 417.0
NC1=CC(/N=N/C2=CC=CC=C2)=CC=C1 417.0
ClC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1 404.0
CCN(CC)C(C=C1)=CC=C1/N=N/C2=C(C#N)C=CC=C2 462.0
CC1=NOC(C)=C1/N=N/C2=CC=C(F)C=C2 318.0
OC(C([N+]([O-])=O)=CC([N+]([O-])=O)=C1)=C1/N=N/C2=CC(C)=CC(NC(C)=O)=C2O 400.0
OCCC1=CC=CC=C1/N=N/C2=CC=C(NCCC#N)C=C2 399.0
N#CCCNC(C=C1)=CC=C1/N=N/C2=CC=C(C#N)C=C2 428.0
CC1=NOC(C)=C1/N=N/C2=CC(NC(C)=O)=CC=C2 315.0
COC(C=C1)=CC=C1N=NC2=NNC=C2 344.0
CC1=CC=C(/N=N/C2=CC=CC=C2)C=C1 330.0
[H]C5=CC([N+]([O-])=O)=CC([N+]([O-])=O)=C5/N=N/C6=CC(OC)=C(C=C6)N(CC)CC 540.0
C[N]1C=CC(=N1)N=NC2=CC=

In [None]:
 # Hand-recorded model predictions (nm) for three photoswitch test molecules, stored as data.
# The previously pasted raw text was not valid Python and broke Restart & Run All.
# The first completion echoed a trailing "?" after the SMILES; it is dropped so the SMILES
# matches the test-set entry. These predictions are the values copied into `found` in the next cell.
recorded_predictions = [
    ("CN1C(/N=N/C2=CC=CC=C2)=C(C)C=C1C", 362.0),
    ("COC(C=CN=C1)=C1/N=N/C2=CC=CC=C2", 332.0),
    ("OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7", 443.0),
]

In [27]:
# Recorded predictions vs. reference wavelengths (nm) for the same three molecules.
found, expected = [362, 332, 443], [394.0, 332, 491.0]

In [28]:
import numpy as np

In [29]:
# Mean absolute error (nm) between recorded predictions and reference wavelengths.
np.abs(np.subtract(found, expected)).mean()

26.666666666666668

In [122]:
import os

import openai

# Query text-davinci-003 directly with a hand-built few-shot prompt for the photoswitch task
# (independent of the Matbench `prompt` built earlier). The query molecule and the examples
# are kept as data and the prompt text is assembled from them. The previous single hard-coded
# literal was split across lines (a SyntaxError as exported) and was hard to audit; the
# assembled text is identical to it.
PHOTOSWITCH_QUERY_SMILES = (
    "OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7"
)

# (SMILES, E-isomer pi-pi* wavelength in nm) pairs shown to the model as in-context examples.
PHOTOSWITCH_EXAMPLES = [
    ("ClC(C=C1)=CC=C1/N=N/C2=CC=C(N(C)C)C=C2", 416.0),
    ("FC1=CC=CC(/N=N/C2=CC=CC(F)=C2)=C1", 320.0),
    ("CC1=C(/N=N/C2=CC=CC=C2)C=NC=C1", 322.0),
    ("NC1=CC=C(/N=N/C2=CC=C(N(C)C)C=C2)C=C1", 410.0),
    ("CC1=C(C(C)=NN1)/N=N/C2=CC=C(OC)C=C2", 342.0),
    ("[H]N(CCC#N)C(C=C7)=CC=C7/N=N/C8=CC=C([N+]([O-])=O)C=C8", 443.0),
    ("OCCN(CCO)C(C=C%11)=CC=C%11/N=N/C%12=CC=C([N+]([O-])=O)C=C%12", 475.0),
    ("O=[N+]([O-])C1=C(/N=N/C2=C([N+]([O-])=O)C=CC=C2)C=CC=C1", 323.0),
    ("CC1=NOC(C)=C1/N=N/C2=C(OC)C=CC=C2", 311.0),
    ("ClC1=C(/N=N/C2=C(Cl)C=CC=C2)C=CC=C1", 328.0),
    ("C[N]1N=CC(=C1N=NC2=CC=CC=C2)C", 340.0),
    ("ClC1=CC=C(/N=N/C2=CC=CC=C2)C=C1", 324.0),
    ("CC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1", 395.0),
    ("CN1C(/N=N/C2=CC=CC=C2)=CN=C1", 362.0),
    ("O=[N+]([O-])C1=CC=CC(/N=N/C2=CC([N+]([O-])=O)=CC=C2)=C1", 311.0),
    ("CCN(CC)C1=CC=C(/N=N/C2=CC(C#N)=CC=C2)C=C1", 446.0),
    ("ClC1=CC([N+]([O-])=O)=CC(C#N)=C1/N=N/C2=CC([H])=C(C=C2[H])N(CC)CC", 551.0),
    ("OC1=C(/N=N/C2=CC=C([N+]([O-])=O)C=C2)C=CC(O)=C1", 432.0),
    ("COC(S1)=CC=C1/N=N/C2=CC=C(C#N)C=C2", 413.0),
    ("CCN(CCC#N)C(C=C%21C)=CC=C%21/N=N/C%22=CC=C([N+]([O-])=O)C=C%22", 463.0),
    ("O=[N+]([O-])C1=CC=C(N=C(N=NC2=CC=CC=C2)S3)C3=C1", 340.0),
    ("OC1=CC=C(/N=N/C2=CC=C(C=C2)[N+]([O-])=O)C=C1", 382.0),
    ("[H]N(CCO)C(C=C5)=CC=C5/N=N/C6=CC=CC=C6", 398.0),
    ("CC1=C(/N=N/C2=C(Cl)C=CC=C2Cl)C(C)=NO1", 297.0),
    ("O=C1C=CC(N1C2=CC=C(/N=N/C3=CC=CC=C3)C=C2)=O", 329.0),
    ("CC1=CC=CC(/N=N/C2=CC=CC(C)=C2)=C1", 331.0),
    ("CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(C)C=CC=C2", 405.0),
    ("[H]C7=CC=C(N=C(N=NC8=CC=CC(CO)=C8)S9)C9=C7", 328.0),
    ("OCCC1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1", 396.0),
    ("CSc1nnc(s1)N=Nc1c2ccccc2n(c1C)C", 451.0),
    ("CSC7=CC=CC=C7N=NC8=NC9=CC=C([N+]([O-])=O)C=C9S8", 340.0),
    ("CC1=C(C(C)=NN1)/N=N/C2=CC(C(O)=O)=CC=C2", 332.0),
    ("N#CC(S1)=CC=C1/N=N/C2=CC=CC=C2", 367.0),
    ("C1(/N=N/C2=CC=CC=C2)=CC=CC=C1", 319.0),
    ("[H]N(CC)C(C=C3)=CC=C3/N=N/C4=CC=CC=C4", 400.0),
]

photoswitch_prompt = (
    f"What is the transition wavelength of {PHOTOSWITCH_QUERY_SMILES}?\n\n"
    "Examples:\n"
    "---------\n"
    + "\n".join(
        f"- {smiles}: {wavelength:.1f} nm" for smiles, wavelength in PHOTOSWITCH_EXAMPLES
    )
)

# temperature=0 for (near-)deterministic decoding. No API key is set in this cell; the legacy
# openai client is expected to read it from the environment (e.g. OPENAI_API_KEY) — never hardcode it.
response = openai.Completion.create(
    model="text-davinci-003",
    prompt=photoswitch_prompt,
    temperature=0,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0,
)

In [123]:
# Display the raw completion object returned by the API.
response

<OpenAIObject text_completion id=cmpl-6fq6kc1vGluj9d7uSOSyiwbH8sDVq at 0x2a3cb46d0> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "\n- OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7: 463.0 nm"
    }
  ],
  "created": 1675429770,
  "id": "cmpl-6fq6kc1vGluj9d7uSOSyiwbH8sDVq",
  "model": "text-davinci-003",
  "object": "text_completion",
  "usage": {
    "completion_tokens": 72,
    "prompt_tokens": 1525,
    "total_tokens": 1597
  }
}

In [6]:
from gptchem.extractor import RegressionExtractor
from gptchem.formatter import FewShotFormatter

In [98]:
# Plain-dict view of the completion response.
# NOTE(review): the stored output (id cmpl-6fpWf..., prompt_tokens 2183) comes from an
# earlier request than the `response` displayed above (execution count 98 < 122); re-run to refresh.
dict(response)

{'id': 'cmpl-6fpWfYmQFDhQbCJbBXQrccQvFGdQA',
 'object': 'text_completion',
 'created': 1675427533,
 'model': 'text-davinci-003',
 'choices': [<OpenAIObject at 0x16d3d1a90> JSON: {
    "finish_reason": "stop",
    "index": 0,
    "logprobs": null,
    "text": "\n\nThe transition wavelength of OC7=C8N=CC=CC8=C(/N=N/C9=NC=C(S(=O)(C%10=CC=C([N+]([O-])=O)C=C%10)=O)S9)C=C7 is 443.0 nm."
  }],
 'usage': <OpenAIObject at 0x16d3d1cc0> JSON: {
   "completion_tokens": 76,
   "prompt_tokens": 2183,
   "total_tokens": 2259
 }}

In [97]:
# Attempt to parse the predicted wavelength out of the completion.
# NOTE(review): this raises `AttributeError: split` (see output). The message suggests the
# extractor calls `.split` on an OpenAIObject choice instead of a string, so it probably needs
# the completion text (e.g. response["choices"][0]["text"]) — TODO confirm RegressionExtractor's
# expected input. A cell that raises also breaks Restart & Run All.
RegressionExtractor()(dict(response))

AttributeError: split

In [55]:
# Few-shot formatter for the photoswitch regression task: 10 training rows become the
# in-context Q/A examples. The sample is seeded (same seed as the train/test split) so the
# chosen examples, and therefore the prompt, are reproducible.
# NOTE(review): `train` must be a photoswitch split with "SMILES" and
# "E isomer pi-pi* wavelength in nm" columns; run top-to-bottom, `train` is the Matbench
# split, so this cell only works with stale kernel state.
formatter = FewShotFormatter(
    train.sample(10, random_state=42),
    "transition wavelengths of photoswitch molecules",
    "SMILES",
    "E isomer pi-pi* wavelength in nm",
)

In [56]:
import pandas as pd

In [57]:
# Format one randomly chosen (seeded, hence reproducible) test row into a prompt DataFrame.
# NOTE(review): this rebinds `prompt` (previously the Matbench prompt string) to a DataFrame.
prompt = formatter(test.sample(1, random_state=42))

In [58]:
# Print the rendered prompt text of the single formatted row.
print(prompt["prompt"].iloc[0])

I am a highly intelligent question answering bot that answers questions about transition wavelengths of photoswitch molecules.
    
Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=CC=C(C(F)(F)F)C=C2
A: 421.0

Q: OC1=C([N+]([O-])=O)C=C([N+]([O-])=O)C=C1/N=N/C2=C(O)C=CC(C)=C2
A: 400.0

Q: O=[N+]([O-])C1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1
A: 455.0

Q: FC1=CC=C(/N=N/C2=CC=CC=C2)C=C1
A: 322.0

Q: CCN(CC)C(C=C%21)=CC=C%21/N=N/C%22=CC=C(N%23CCOCC%23)C([H])=C%22
A: 417.0

Q: CN(C=N1)C=C1/N=N/C2=CC=CC=C2
A: 336.0

Q: CCN(CC)C(C=C1)=CC=C1/N=N/C2=C(C#N)C=C(C#N)C=C2
A: 515.0

Q: CN(C)C(C=C1)=CC=C1/N=N/C2=CC=CC=C2[N+]([O-])=O
A: 440.0

Q: CC1=C(C(C)=NN1)/N=N/C2=CC=C(C(O)=O)C=C2
A: 342.0

Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=CC=C(C(C)=O)C=C2
A: 412.0

Q: CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2


In [59]:
# Full formatted row (columns: prompt, completion, label, representation).
prompt

Unnamed: 0,prompt,completion,label,representation
0,I am a highly intelligent question answering b...,417.0,417.0,CC(C=C(N(CCC#N)CCO)C=C1)=C1/N=N/C2=C(F)C=CC=C2


In [82]:
from gptchem.extractor import FewShotClassificationExtractor, FewShotRegressionExtractor

In [83]:
# Extractor that parses a numeric answer from few-shot regression completions (see the "A: <value>" demo below).
extractor = FewShotRegressionExtractor()

In [84]:
# Sanity check: an "A: <value>" answer is parsed to a float (4566.0).
extractor.extract("A: 4566")

4566.0

In [None]:
# (end of notebook — the stray "|" left in this cell was a SyntaxError on Restart & Run All)