In [1]:
import json
import random
import re

from frozendict import frozendict

from pubchem_scraper.augment import augment
from pubchem_scraper.datatypes import Example, Molecule, merge_molecules
from pubchem_scraper.pubchem_schema import SimpleElement, SimpleStringWithMarkup

In [2]:
# Read both inputs explicitly as UTF-8 so that non-ASCII characters decode the
# same way on every platform instead of depending on the locale's default
# encoding. (JSON is UTF-8 by specification; prompt.md is assumed to be UTF-8
# as well -- TODO confirm.)
with open("./data/selected.json", encoding="utf-8") as f:
    raw_records = json.load(f)

# Validate every raw record into a SimpleElement. `data` is the name the later
# cells iterate over, so it is bound exactly once, to the validated list.
data: list[SimpleElement] = [
    SimpleElement.model_validate(record) for record in raw_records
]

# Prompt text used as the system prompt of every example built below.
with open("./data/prompt.md", encoding="utf-8") as f:
    prompt = f.read()

In [3]:
def is_aliased(name: str) -> bool:
    """Tell whether ``name`` is a generic placeholder such as ``compound 3b``.

    A name qualifies when it consists of one or two generic class words
    (``compound``, ``ligand``, ``derivative``, ...), separated by single
    spaces, followed by a short label: a number from 1 to 99 without a leading
    zero, optionally followed by one lowercase letter.

    Args:
        name: The candidate molecule name.

    Returns:
        ``True`` if ``name`` has the placeholder shape described above,
        ``False`` otherwise.
    """
    # The pattern is assembled from adjacent raw-string literals: first class
    # word, optional second class word, then the trailing label.
    pattern = (
        r"^(compound|ligand|derivative|complex|pyrazole|amide|urea|hydroxyl|ketone|pyridazinone|piperazine|cyclohexyl|ester|acid|analog|conjugate|inhibitor)"  # noqa: E501
        r"( (compound|ligand|derivative|complex|pyrazole|amide|urea|hydroxyl|ketone|pyridazinone|piperazine|cyclohexyl|ester|acid|analog|conjugate|inhibitor))?"  # noqa: E501
        r" [1-9][0-9]?[a-z]?$"
    )
    return bool(re.match(pattern, name))


def create_ft_example(element: SimpleStringWithMarkup):
    """Build one fine-tuning ``Example`` from a marked-up string.

    For every markup entry the hit text is obtained via ``markup.comp_hit``
    (presumably the text of the compound mention -- confirm against
    ``pubchem_schema``). A hit shaped like ``"name (alias)"`` is split into
    the name and one alternative; a placeholder name such as ``compound 3b``
    additionally contributes its trailing label (``3b``) as an alternative.
    The resulting molecules go through ``merge_molecules`` and are serialised
    as compact JSON for the response.

    Args:
        element: Object providing the text (``.string``) and its markup
            entries (``.markup``).

    Returns:
        An ``Example`` whose system prompt is the module-level ``prompt``,
        whose user prompt is the element's text, and whose response is the
        JSON list of molecules.
    """
    text = element.string

    molecules = []
    for markup in element.markup:
        hit = markup.comp_hit(text)
        alternatives = []

        # "name (alias)" -> keep "name" as the name, record "alias".
        parenthesised = re.match(r"(.*) \((.*)\)", hit)
        if parenthesised is not None:
            hit, alias = parenthesised.groups()
            alternatives.append(alias)

        # is_aliased() guarantees the name ends in a label, so the search
        # below always finds a match.
        if is_aliased(hit):
            label = re.search(r"\d+[a-z]?$", hit).group(0)  # type: ignore
            alternatives.append(label)

        molecules.append(Molecule(name=hit, alternatives=alternatives))

    merged = merge_molecules(molecules)
    payload = json.dumps(
        [molecule.model_dump() for molecule in merged],
        separators=(",", ":"),
    )
    return Example(sys_prompt=prompt, user_prompt=text, response=payload)

In [4]:
# Originals first: one example per element, built from its unmodified string.
training_data = [create_ft_example(element.string) for element in data]

# Then one augmented variant per element. Augmentation is best-effort: when
# augment() raises, the original string is used instead, so every element
# still contributes a second entry (identical to its first one in that case).
augment_failures = []
for index, element in enumerate(data):
    original = element.string
    try:
        augmented = augment(original, n=2)
    except Exception as exc:  # deliberately broad: any failure falls back
        augment_failures.append((index, exc))
        augmented = original

    training_data.append(create_ft_example(augmented))

# Surface the fallbacks instead of hiding them; `augment_failures` keeps the
# (index into data, exception) pairs for inspection.
if augment_failures:
    print(
        f"augment() failed for {len(augment_failures)} of {len(data)} elements; "
        "the original string was used for those."
    )

In [5]:
# One (system, user, assistant) message triple per training example. Each
# message is a frozendict so that the whole triple is hashable and can be
# de-duplicated below.
all_conversations = [
    (
        frozendict({"role": "system", "content": ex.sys_prompt}),
        frozendict({"role": "user", "content": ex.user_prompt}),
        frozendict({"role": "assistant", "content": ex.response}),
    )
    for ex in training_data
]

# Drop exact duplicates while keeping first-seen order. A set would remove the
# same duplicates, but its iteration order depends on string hashing, which is
# randomised per interpreter session, so the resulting list order would differ
# from one kernel restart to the next.
conversations = list(dict.fromkeys(all_conversations))

In [6]:
SPLIT_SEED = 0  # fixed seed so the train/valid split is reproducible
TRAIN_FRACTION = 0.8  # share of conversations that goes to the training split

# Put the conversations into a canonical order first (by their message
# contents, assumed to be plain strings), so the seeded shuffle yields the same
# permutation regardless of the order they arrive in. Working on a copy keeps
# this cell idempotent: re-running it does not reshuffle an already shuffled
# list.
shuffled = sorted(
    conversations,
    key=lambda conversation: tuple(message["content"] for message in conversation),
)
random.Random(SPLIT_SEED).shuffle(shuffled)

split_at = int(TRAIN_FRACTION * len(shuffled))
train = shuffled[:split_at]
valid = shuffled[split_at:]

In [7]:
# Write each split as a JSON list of {"messages": [...]} records.
# Messages are converted to plain dicts inside a plain list before dumping, so
# serialisation only relies on types the json module encodes natively.
# NOTE(review): whether json can encode frozendict directly appears to depend
# on the installed frozendict version -- the conversion avoids relying on it.
for path, split in (
    ("./data/conversations_train.json", train),
    ("./data/conversations_valid.json", valid),
):
    # Build the records before opening the file so a conversion error cannot
    # leave a truncated output file behind.
    records = [{"messages": [dict(message) for message in ex]} for ex in split]
    with open(path, "w") as f:
        json.dump(records, f, indent=2)