In [9]:
from datasets import load_from_disk
from torch_geometric.data import Data
import re
import selfies as sf
from rdkit import Chem
import torch
from tqdm import tqdm

testset_path = '/data/text-mol/data/Mol-LLM-v7.1/mistralai-Mistral-7B-Instruct-v0.3_string+graph_q32_test_3.3M_0415'
# Tasks excluded from every export produced by this notebook.
remove_tasks = [
    'alchemy_homo',
    'alchemy_homo_lumo_gap',
    'alchemy_lumo',
]
testset = load_from_disk(testset_path).filter(
    lambda example: example["task"] not in remove_tasks,
    num_proc=50,
)
testset

Dataset({
    features: ['task', 'x', 'edge_index', 'edge_attr', 'additional_x', 'additional_edge_index', 'additional_edge_attr', 'input_mol_string', 'prompt_text', 'target_text'],
    num_rows: 55757
})

In [2]:
# Distinct task names that remain after filtering.
tasks = {task_name for task_name in testset["task"]}
tasks

{'aqsol-logS',
 'bace',
 'chebi-20-mol2text',
 'chebi-20-text2mol',
 'forward_reaction_prediction',
 'orderly-forward_reaction_prediction',
 'orderly-retrosynthesis',
 'presto-forward_reaction_prediction',
 'presto-retrosynthesis',
 'qm9_homo',
 'qm9_homo_lumo_gap',
 'qm9_lumo',
 'reagent_prediction',
 'retrosynthesis',
 'smol-forward_synthesis',
 'smol-molecule_captioning',
 'smol-molecule_generation',
 'smol-property_prediction-bbbp',
 'smol-property_prediction-clintox',
 'smol-property_prediction-esol',
 'smol-property_prediction-hiv',
 'smol-property_prediction-lipo',
 'smol-property_prediction-sider',
 'smol-retrosynthesis'}

In [3]:
# Binary classification tasks; used to build the separate classification split.
classification_tasks = {'bace'} | {
    f'smol-property_prediction-{dataset_name}'
    for dataset_name in ('bbbp', 'clintox', 'hiv', 'sider')
}

In [4]:
# Hand-picked subset of tasks to spot-check. Intersecting with the tasks that
# actually remain in `testset` (instead of computing that set and discarding
# it) guarantees every listed task has at least one example, so indexing the
# per-task subset with [0] cannot raise IndexError.
unique_tasks = {
    'bace',
    'chebi-20-mol2text',
    'chebi-20-text2mol',
    'forward_reaction_prediction',
    'reagent_prediction',
} & set(testset["task"])

In [5]:
# Import the tqdm *function*, not the module: a bare `import tqdm` rebinds the
# notebook-global name `tqdm` to the module, after which later `tqdm(iterable)`
# calls fail with "'module' object is not callable" on Restart & Run All.
from tqdm import tqdm

# One filtered sub-dataset per spot-check task.
task_specific_instances = {}
for task in tqdm(unique_tasks):
    task_specific_instances[task] = testset.filter(lambda x: x["task"] == task, num_proc=50)

# Show the raw target text of the first example of each task.
task_examples = [task_specific_instances[task][0] for task in unique_tasks]
for example in task_examples:
    print(example["target_text"])

Filter (num_proc=50): 100%|██████████| 55757/55757 [00:01<00:00, 34253.49 examples/s]
Filter (num_proc=50): 100%|██████████| 55757/55757 [00:01<00:00, 36687.21 examples/s]
Filter (num_proc=50): 100%|██████████| 55757/55757 [00:01<00:00, 39025.73 examples/s]
Filter (num_proc=50): 100%|██████████| 55757/55757 [00:01<00:00, 38991.99 examples/s]
Filter (num_proc=50): 100%|██████████| 55757/55757 [00:01<00:00, 38741.57 examples/s]
100%|██████████| 5/5 [00:12<00:00,  2.44s/it]


<SELFIES> [O][C][C][=C][C][=C][Branch1][O][S][C][=C][C][=C][C][=C][Ring1][=Branch1][Br][C][=C][Ring1][=C] </SELFIES> </s>
<DESCRIPTION> The molecule is a L-serine derivative obtained by formal condensation between N-butyl-L-serinamide and 2-thienylacetic acid. It is a member of thiophenes, a monocarboxylic acid amide and a L-serine derivative. </DESCRIPTION> </s>
<SELFIES> [C][C][C][C][=N][C][C][C][N][Ring1][=Branch1][C][C][Ring1][O].[C][N][Branch1][C][C][C][=O].[Cl] </SELFIES> </s>
<BOOLEAN> True </BOOLEAN> </s>
<SELFIES> [N][C][=C][Branch1][=Branch1][N+1][=Branch1][C][=O][O-1][C][=C][C][Branch1][#Branch2][O][C][=C][C][=C][C][=C][Ring1][=Branch1][=C][Ring1][S][Cl] </SELFIES> </s>


In [6]:
# Tags stripped from `input_mol_string` before SELFIES decoding.
_SELFIES_REMOVE_PATTERNS = [
    r"<SELFIES>\s*",
    r"\s*</SELFIES>",
]

# System prompt text removed from `prompt_text`. It is matched literally
# (escaped below) rather than interpreted as a regex.
_SYSTEM_PROMPT = "You are a helpful assistant for molecular chemistry, to address tasks including molecular property classification, molecular property regression, chemical reaction prediction, molecule captioning, molecule generation."

# Applied to `prompt_text` in this order. Order matters: newlines are removed
# before the `.*` patterns run, so those patterns can span the former lines.
# NOTE(review): `.*` is greedy; with several <GRAPH> blocks everything between
# the first opening and last closing tag is removed — confirm that is intended.
_PROMPT_REMOVE_PATTERNS = [
    re.escape(_SYSTEM_PROMPT),
    r"\n",
    r"<mol>",
    r"<s>",
    r"\[INST\]\s*",
    r"\s*\[/INST\]\s*",
    r"<GRAPH>.*</GRAPH>",
    r"^\s*",
    r"\s*$",
    r"<DESCRIPTION>\s*",
    r"\s*</DESCRIPTION>",
]

# The (greedy) span from the first <SELFIES> to the last </SELFIES> in the
# prompt is replaced by a single placeholder token.
_PROMPT_MOLECULE_PATTERN = r"<SELFIES>.*</SELFIES>"
_PROMPT_MOLECULE_PLACEHOLDER = "<INPUT>"

# Applied to `target_text` in this order.
_TARGET_REMOVE_PATTERNS = [
    r"\s*</s>\s*",
    r"<SELFIES>\s*",
    r"\s*</SELFIES>",
    r"<BOOLEAN>\s*",
    r"\s*</BOOLEAN>",
    r"<FLOAT>\s*",
    r"\s*</FLOAT>",  # was r"s*</FLOAT>", which matched literal 's' characters
    r"<DESCRIPTION>\s*",
    r"\s*</DESCRIPTION>",
    r"^\s*",
    r"\s*$",
]

# Numeric targets are tokenized as "<|c|>" per character; collapse back to "c".
_NUMERIC_TOKEN_CHARS = "+-.0123456789"


def _strip_patterns(text, patterns):
    """Return `text` with every regex in `patterns` removed, applied in order."""
    for pattern in patterns:
        text = re.sub(pattern, '', text)
    return text


def _selfies_to_smiles(selfies):
    """Decode SELFIES to canonical, non-isomeric SMILES without raising.

    Fallbacks (the SELFIES string is printed in both cases):
      * decoding fails -> returns "" (previously `smiles` was left unbound and
        the caller crashed with UnboundLocalError);
      * the decoded SMILES cannot be canonicalized by RDKit -> returns the
        decoded, non-canonical SMILES.
    """
    try:
        smiles = sf.decoder(selfies)
    except Exception:
        print(selfies)
        return ""
    if not isinstance(smiles, str):
        # Defensive: treat any non-string decoder result as a decoding failure.
        print(selfies)
        return ""
    try:
        mol = Chem.MolFromSmiles(smiles)
        # MolFromSmiles signals a parse failure by returning None.
        if mol is not None:
            return Chem.MolToSmiles(mol, isomericSmiles=False, canonical=True)
    except Exception:
        pass  # fall through to the non-canonical fallback below
    print(selfies)
    return smiles


def _clean_prompt(prompt_text):
    """Strip chat/system scaffolding from a prompt and mask the molecule as <INPUT>."""
    prompt_text = _strip_patterns(prompt_text, _PROMPT_REMOVE_PATTERNS)
    return re.sub(_PROMPT_MOLECULE_PATTERN, _PROMPT_MOLECULE_PLACEHOLDER, prompt_text)


def _clean_target(target_text):
    """Strip tags/EOS from a target and collapse "<|c|>" numeric tokens to "c"."""
    target_text = _strip_patterns(target_text, _TARGET_REMOVE_PATTERNS)
    for ch in _NUMERIC_TOKEN_CHARS:
        target_text = target_text.replace(f"<|{ch}|>", ch)
    return target_text


def process_instance(instance):
    """Convert one raw dataset record into a PyG ``Data`` object.

    Args:
        instance: mapping with keys 'x', 'edge_index', 'edge_attr', 'task',
            'input_mol_string', 'prompt_text' and 'target_text'.

    Returns:
        ``Data`` with:
          * tensor fields ``x``, ``edge_index``, ``edge_attr``;
          * string field ``task``;
          * ``smiles``: canonical, non-isomeric SMILES decoded from the SELFIES
            input (see ``_selfies_to_smiles`` for the fallbacks);
          * ``instruction``: the prompt with the molecule replaced by "<INPUT>";
          * ``output``: the target with tags removed.
    """
    selfies = _strip_patterns(instance['input_mol_string'], _SELFIES_REMOVE_PATTERNS)
    return Data(
        x=torch.tensor(instance['x']),
        edge_index=torch.tensor(instance['edge_index']),
        edge_attr=torch.tensor(instance['edge_attr']),
        task=instance['task'],
        smiles=_selfies_to_smiles(selfies),
        instruction=_clean_prompt(instance['prompt_text']),
        output=_clean_target(instance['target_text']),
    )

In [11]:


# Convert every test record, then show the first converted graph as a sanity check.
processed_testset = [process_instance(example) for example in tqdm(testset)]
processed_testset[0]

100%|██████████| 55757/55757 [01:11<00:00, 783.99it/s] 


Data(x=[11, 9], edge_index=[2, 22], edge_attr=[22, 3], task='aqsol-logS', smiles='Cc1ccc(C(C)(C)C)cc1', instruction='Predict the log solubility of <INPUT> in water.', output='-4.4720')

In [12]:
from torch_geometric.data import InMemoryDataset, Data
import torch

class MyDataset(InMemoryDataset):
    """In-memory PyG dataset built directly from a list of ``Data`` objects.

    ``root`` is ``None``, so nothing is read from or written to disk. The list
    is collated immediately into the ``(data, slices)`` pair that
    ``InMemoryDataset`` stores.
    """

    def __init__(self, data_list):
        super().__init__(root=None)
        collated, slices = self.collate(data_list)
        self.data = collated
        self.slices = slices

import os

inmemory_testset = MyDataset(processed_testset)
# save the dataset
inmemory_testset_path = '/data/text-mol/data/Mol-LLM-v7.1/llamo_test/test.pt'
inmemory_trainset_path = '/data/text-mol/data/Mol-LLM-v7.1/llamo_test/train.pt'
# The same collated test data is written to both test.pt and train.pt.
# Presumably the downstream loader requires both files — TODO confirm.
inmemory_payload = (inmemory_testset.data, inmemory_testset.slices)
for out_path in (inmemory_testset_path, inmemory_trainset_path):
    # torch.save does not create parent directories; make sure they exist.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(inmemory_payload, out_path)




In [20]:
# Re-display the first processed record as a quick check of the conversion.
processed_testset[0]

Data(x=[11, 9], edge_index=[2, 22], edge_attr=[22, 3], task='aqsol-logS', smiles='Cc1ccc(C(C)(C)C)cc1', instruction='Predict the log solubility of <INPUT> in water.', output='-4.4720')

In [13]:
import datasets


classification_testset_path = '/data/text-mol/data/Mol-LLM-v7.1/mistralai-Mistral-7B-Instruct-v0.3_string+graph_q32_test_3.3M_classification'
# Keep only the classification tasks and persist them as a standalone
# HuggingFace dataset.
classification_testset = testset.filter(
    lambda example: example["task"] in classification_tasks,
    num_proc=50,
)
classification_testset.save_to_disk(classification_testset_path)

Saving the dataset (1/1 shards): 100%|██████████| 7460/7460 [00:00<00:00, 15772.88 examples/s]


In [21]:
# Convert the classification subset with the same per-record pipeline.
processed_classification_testset = [
    process_instance(example) for example in tqdm(classification_testset)
]
processed_classification_testset[0]

100%|██████████| 7460/7460 [00:11<00:00, 632.99it/s]


Data(x=[35, 9], edge_index=[2, 78], edge_attr=[78, 3], task='smol-property_prediction-hiv', smiles='CC1=C(CO)C(=O)OC(C(CO)C2CCC3C4CC(O)C5(O)CC=CC(=O)C5(C)C4CCC23C)C1', instruction='Does <INPUT> inhibit viral replication for HIV?', output='False')

In [33]:
import os

inmemory_classification_testset = MyDataset(processed_classification_testset)

classification_testset_path = '/data/text-mol/data/Mol-LLM-v7.1/llamo_test_classification/test.pt'
classification_trainset_path = '/data/text-mol/data/Mol-LLM-v7.1/llamo_test_classification/train.pt'
# The same collated data is written to both test.pt and train.pt.
# Presumably the downstream loader requires both files — TODO confirm.
classification_payload = (inmemory_classification_testset.data, inmemory_classification_testset.slices)
for out_path in (classification_testset_path, classification_trainset_path):
    # torch.save does not create parent directories; make sure they exist.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(classification_payload, out_path)