In [1]:
from datasets import load_from_disk
from torch_geometric.data import Data
import re
import selfies as sf
from rdkit import Chem
import torch
from tqdm import tqdm

# Location on disk of the test split to load.
testset_path = '/data/text-mol/data/Mol-LLM-v7.1/mistralai-Mistral-7B-Instruct-v0.3_string+graph_q32_test_3.3M_0415'

# Tasks that are dropped from the test split before any further processing.
remove_tasks = [
    'alchemy_homo',
    'alchemy_homo_lumo_gap',
    'alchemy_lumo',
]


def _is_kept_task(example):
    """Return True when the example's task is not listed in ``remove_tasks``."""
    return example["task"] not in remove_tasks


# Load the split and filter it with 50 worker processes; the bare name displays the result.
testset = load_from_disk(testset_path).filter(_is_kept_task, num_proc=50)
testset

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['task', 'x', 'edge_index', 'edge_attr', 'additional_x', 'additional_edge_index', 'additional_edge_attr', 'input_mol_string', 'prompt_text', 'target_text'],
    num_rows: 55757
})

In [2]:
# Distinct task names remaining in the filtered test split (displayed as the cell output).
tasks = {task_name for task_name in testset["task"]}
tasks

{'aqsol-logS',
 'bace',
 'chebi-20-mol2text',
 'chebi-20-text2mol',
 'forward_reaction_prediction',
 'orderly-forward_reaction_prediction',
 'orderly-retrosynthesis',
 'presto-forward_reaction_prediction',
 'presto-retrosynthesis',
 'qm9_homo',
 'qm9_homo_lumo_gap',
 'qm9_lumo',
 'reagent_prediction',
 'retrosynthesis',
 'smol-forward_synthesis',
 'smol-molecule_captioning',
 'smol-molecule_generation',
 'smol-property_prediction-bbbp',
 'smol-property_prediction-clintox',
 'smol-property_prediction-esol',
 'smol-property_prediction-hiv',
 'smol-property_prediction-lipo',
 'smol-property_prediction-sider',
 'smol-retrosynthesis'}

In [3]:
# Task names treated as classification tasks: 'bace' plus four
# 'smol-property_prediction-*' benchmarks that share a common prefix.
classification_tasks = {'bace'} | {
    'smol-property_prediction-' + benchmark
    for benchmark in ('bbbp', 'clintox', 'hiv', 'sider')
}

In [4]:
# Hand-picked subset of task names.
# To cover every task present in the split instead, use: set(testset["task"])
unique_tasks = {
    'bace',
    'chebi-20-mol2text',
    'chebi-20-text2mol',
    'forward_reaction_prediction',
    'reagent_prediction',
}

In [5]:
import numpy as np

def _strip_patterns(text, patterns):
    """Delete every match of each regex in ``patterns`` from ``text``, applying them in order."""
    for pattern in patterns:
        text = re.sub(pattern, '', text)
    return text


def _selfies_to_smiles(selfies):
    """Convert a SELFIES string to canonical, non-isomeric SMILES on a best-effort basis.

    Returns:
        The canonical SMILES when decoding and RDKit parsing both succeed;
        the raw decoded SMILES when RDKit cannot canonicalize it;
        an empty string when the SELFIES string cannot be decoded at all.
        Failures are printed rather than raised so that one bad molecule
        does not abort a whole dataset pass.
    """
    try:
        decoded = sf.decoder(selfies)
    except Exception as err:
        print(f"SELFIES decoding failed ({err!r}): {selfies}")
        return ""

    try:
        mol = Chem.MolFromSmiles(decoded)
        if mol is None:
            # RDKit signals an unparsable SMILES by returning None, not by raising.
            raise ValueError("RDKit could not parse the decoded SMILES")
        return Chem.MolToSmiles(mol, isomericSmiles=False, canonical=True)
    except Exception as err:
        print(f"SMILES canonicalization failed ({err!r}): {selfies}")
        return decoded


def process_instance(instance):
    """Clean one dataset example and add plain-text ``smiles``/``instruction``/``target`` fields.

    Args:
        instance: Mapping with the keys 'x', 'edge_index', 'edge_attr', 'task',
            'input_mol_string', 'prompt_text' and 'target_text'. It is modified
            in place.

    Returns:
        The same ``instance`` object with:
          * every numpy-array value converted to a plain list;
          * 'x' and 'edge_attr' round-tripped through a float32 tensor and
            'edge_index' through an int64 tensor, each stored as nested lists;
          * 'smiles': SMILES derived from 'input_mol_string' with the SELFIES
            tags removed (empty string if the SELFIES cannot be decoded);
          * 'instruction': 'prompt_text' without the system prompt, chat/graph
            markup and newlines, with the SELFIES span replaced by '<INPUT>';
          * 'target': 'target_text' without wrapper tags, with '<|c|>' tokens
            replaced by the character ``c`` they encode.
    """
    # Normalize Arrow->Python variability: numpy arrays become plain lists.
    # Iterate over a snapshot because values are reassigned inside the loop.
    for key, value in list(instance.items()):
        if isinstance(value, np.ndarray):
            instance[key] = value.tolist()

    # Round-trip the graph fields through tensors to obtain regular nested lists.
    x = torch.tensor(instance['x'], dtype=torch.float32).tolist()
    edge_index = torch.tensor(instance['edge_index'], dtype=torch.long).tolist()
    edge_attr = torch.tensor(instance['edge_attr'], dtype=torch.float32).tolist()

    task = instance['task']
    prompt_text = instance['prompt_text']
    target_text = instance['target_text']

    # --- molecule: strip the SELFIES tags, then decode to canonical SMILES ---
    selfies = _strip_patterns(
        instance['input_mol_string'],
        [
            r"<SELFIES>\s*",
            r"\s*</SELFIES>",
        ],
    )
    smiles = _selfies_to_smiles(selfies)

    # --- instruction: remove the system prompt and markup from the prompt ---
    system_prompt = "You are a helpful assistant for molecular chemistry, to address tasks including molecular property classification, molecular property regression, chemical reaction prediction, molecule captioning, molecule generation."
    prompt_text = _strip_patterns(
        prompt_text,
        [
            # The system prompt is literal text, so escape it before use as a regex.
            re.escape(system_prompt),
            r"\n",
            r"<mol>",
            r"<s>",
            r"\[INST\]\s*",
            r"\s*\[/INST\]\s*",
            # NOTE(review): greedy match; with several <GRAPH> spans everything
            # between the first opening and last closing tag is removed -- confirm intended.
            r"<GRAPH>.*</GRAPH>",
            r"^\s*",
            r"\s*$",
            r"<DESCRIPTION>\s*",
            r"\s*</DESCRIPTION>",
        ],
    )
    # NOTE(review): also greedy; multiple SELFIES spans collapse into one <INPUT>.
    prompt_text = re.sub(r"<SELFIES>.*</SELFIES>", "<INPUT>", prompt_text)
    instruction = prompt_text

    # --- target: remove wrapper tags, then decode the per-character tokens ---
    target_text = _strip_patterns(
        target_text,
        [
            r"\s*</s>\s*",
            r"<SELFIES>\s*",
            r"\s*</SELFIES>",
            r"<BOOLEAN>\s*",
            r"\s*</BOOLEAN>",
            r"<FLOAT>\s*",
            r"\s*</FLOAT>",
            r"<DESCRIPTION>\s*",
            r"\s*</DESCRIPTION>",
            r"^\s*",
            r"\s*$",
        ],
    )
    # Tokens of the form <|c|> stand for the single character c (sign, dot or digit).
    for char in "+-.0123456789":
        target_text = target_text.replace(f"<|{char}|>", char)
    output = target_text

    instance["x"] = x
    instance["edge_index"] = edge_index
    instance["edge_attr"] = edge_attr
    instance["task"] = task
    instance["smiles"] = smiles
    instance["instruction"] = instruction
    instance["target"] = output
    return instance

In [6]:
# Destination directory for the cleaned test split.
processed_testset_path = '/data/text-mol/data/Mol-LLM-v7.1/mol_llm_testset_general'

# Run the per-example clean-up with 50 worker processes, then write the result to disk.
processed_testset = testset.map(process_instance, num_proc=50)
processed_testset.save_to_disk(processed_testset_path)

Map (num_proc=50): 100%|██████████| 55757/55757 [00:02<00:00, 19073.10 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 55757/55757 [00:01<00:00, 34087.06 examples/s]


In [7]:
# Spot-check: run the clean-up function on the first raw example and display the returned dict.
process_instance(testset[0])

{'task': 'aqsol-logS',
 'x': [[5.0, 0.0, 4.0, 5.0, 3.0, 0.0, 2.0, 0.0, 0.0],
  [5.0, 0.0, 3.0, 5.0, 0.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 3.0, 5.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 3.0, 5.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 3.0, 5.0, 0.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 3.0, 5.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 3.0, 5.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  [5.0, 0.0, 4.0, 5.0, 0.0, 0.0, 2.0, 0.0, 0.0],
  [5.0, 0.0, 4.0, 5.0, 3.0, 0.0, 2.0, 0.0, 0.0],
  [5.0, 0.0, 4.0, 5.0, 3.0, 0.0, 2.0, 0.0, 0.0],
  [5.0, 0.0, 4.0, 5.0, 3.0, 0.0, 2.0, 0.0, 0.0]],
 'edge_index': [[0,
   1,
   1,
   2,
   2,
   3,
   3,
   4,
   4,
   5,
   5,
   6,
   4,
   7,
   7,
   8,
   7,
   9,
   7,
   10,
   6,
   1],
  [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 4, 8, 7, 9, 7, 10, 7, 1, 6]],
 'edge_attr': [[0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0.0, 1.0],
  [3.0, 0

In [8]:
# Display the first example of the mapped dataset for comparison with the direct call.
processed_testset[0]

{'task': 'aqsol-logS',
 'x': [[5, 0, 4, 5, 3, 0, 2, 0, 0],
  [5, 0, 3, 5, 0, 0, 1, 1, 1],
  [5, 0, 3, 5, 1, 0, 1, 1, 1],
  [5, 0, 3, 5, 1, 0, 1, 1, 1],
  [5, 0, 3, 5, 0, 0, 1, 1, 1],
  [5, 0, 3, 5, 1, 0, 1, 1, 1],
  [5, 0, 3, 5, 1, 0, 1, 1, 1],
  [5, 0, 4, 5, 0, 0, 2, 0, 0],
  [5, 0, 4, 5, 3, 0, 2, 0, 0],
  [5, 0, 4, 5, 3, 0, 2, 0, 0],
  [5, 0, 4, 5, 3, 0, 2, 0, 0]],
 'edge_index': [[0,
   1,
   1,
   2,
   2,
   3,
   3,
   4,
   4,
   5,
   5,
   6,
   4,
   7,
   7,
   8,
   7,
   9,
   7,
   10,
   6,
   1],
  [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 4, 8, 7, 9, 7, 10, 7, 1, 6]],
 'edge_attr': [[0, 0, 0],
  [0, 0, 0],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [3, 0, 1],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [3, 0, 1],
  [3, 0, 1]],
 'additional_x': [[5, 0, 4, 5, 3, 0, 2, 0, 0],
  [5, 0, 3, 5, 0, 0, 1, 1, 1],
  [5, 0, 3, 5, 1, 0, 1, 1, 1],
 