In [1]:
import re
import pandas as pd
from rdkit import Chem
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

In [2]:
# Input directories for the raw USPTO text files. The dataset functions in
# this notebook read "src-*.txt" / "tgt-*.txt" files (train, val and test)
# from inside each of these directories.
USPTO_MIXED_PATH = "../../data/uspto_mit/data/MIT_mixed"
USPTO_SEP_PATH = "../../data/uspto_mit/data/MIT_separated"
USPTO_50_PATH = "../../data/uspto_50"

In [3]:
# Output locations: each processed dataset DataFrame is pickled to the
# matching path below once it has been built.
USPTO_MIXED_PICKLE_PATH = "../../data/uspto_mit/uspto_mixed.pickle"
USPTO_SEP_PICKLE_PATH = "../../data/uspto_mit/uspto_sep.pickle"
USPTO_50_PICKLE_PATH = "../../data/uspto_50/uspto_50.pickle"

In [4]:
# **********************
# *** Util Functions ***
# **********************

In [5]:
def remove_whitespace(path):
    """Read a text file and return its non-empty lines with spaces removed.

    Args:
        path: A ``pathlib.Path``-like object exposing ``read_text()``.

    Returns:
        list[str]: One entry per non-empty line, in file order, with every
        space character (" ") deleted. Other whitespace such as tabs is left
        untouched. A line made up solely of spaces is kept and becomes "".
    """
    stripped = []
    for raw_line in path.read_text().split("\n"):
        # Only truly empty lines are skipped; the check runs before the
        # spaces are removed, so space-only lines still produce an entry.
        if raw_line == "":
            continue
        stripped.append(raw_line.replace(" ", ""))
    return stripped

In [6]:
def create_mols(mol_strs):
    """Parse SMILES strings into molecules using a thread pool.

    Args:
        mol_strs: Iterable of SMILES strings. It is materialised into a list
            once so it can be both parsed and reported on.

    Returns:
        list: The values returned by ``Chem.MolFromSmiles`` for each input
        string, in input order. Strings for which that call returned ``None``
        keep a ``None`` entry at their position.

    Side effects:
        Prints every input string for which no molecule could be constructed.
    """
    mol_strs = list(mol_strs)

    # The context manager shuts the pool down and joins its worker threads,
    # including when a parse call raises, so no threads are left behind.
    with ThreadPoolExecutor() as executor:
        # Executor.map yields results in input order, like the futures list did.
        mols = list(executor.map(Chem.MolFromSmiles, mol_strs))

    err_strs = [mol_str for mol_str, mol in zip(mol_strs, mols) if mol is None]

    if err_strs:
        print("Could not construct mols for the following strings:")
        for err_str in err_strs:
            print(err_str)

    return mols

In [7]:
def create_mols_sep(mol_strs):
    """Parse "reactants>reagents" strings into two parallel molecule tuples.

    Args:
        mol_strs: Iterable of strings, each containing exactly one ">" that
            separates the reactants SMILES from the reagents SMILES.

    Returns:
        tuple: ``(react_mols, reag_mols)``, two tuples in input order holding
        the values returned by ``Chem.MolFromSmiles`` for each side. Either
        side may hold ``None`` where that call returned ``None``. Empty input
        yields ``((), ())``.

    Raises:
        ValueError: If a string does not split into exactly two parts on ">".

    Side effects:
        Prints every input string for which either side could not be built.
    """
    def process_mol_str(mol_str):
        # Split one entry into its two halves and parse each independently.
        splits = mol_str.split(">")
        if len(splits) != 2:
            raise ValueError(f"Error with mol str: {mol_str}")

        reactants, reagents = splits
        return Chem.MolFromSmiles(reactants), Chem.MolFromSmiles(reagents)

    mol_strs = list(mol_strs)

    # The context manager joins the worker threads even if an entry raises.
    with ThreadPoolExecutor() as executor:
        mols = list(executor.map(process_mol_str, mol_strs))

    err_strs = [
        mol_str
        for mol_str, (react, reag) in zip(mol_strs, mols)
        if react is None or reag is None
    ]

    if err_strs:
        print("Could not construct mols for the following strings:")
        for err_str in err_strs:
            print(err_str)

    # zip(*[]) produces nothing to unpack, so empty input is handled explicitly.
    if not mols:
        return (), ()

    react_mols, reag_mols = zip(*mols)
    return react_mols, reag_mols

In [8]:
def load_mols(path, sep=False, uspto_50=False):
    """Load molecule strings from ``path`` and turn them into molecules.

    Args:
        path: File to read; passed to ``remove_whitespace``.
        sep: When true, each line is parsed with ``create_mols_sep`` and a
            ``(react_mols, reag_mols)`` pair is returned.
        uspto_50: When true, ``<RX_n>`` tokens are removed from every line
            before parsing, and the token found at the start of each line
            (or ``None`` when the line does not start with one) is recorded
            as that line's reaction type.

    Returns:
        * ``sep`` true: ``(react_mols, reag_mols)``. Reaction types are not
          returned in this case, even when ``uspto_50`` is also true.
        * ``uspto_50`` true only: ``(mols, reaction_types)``.
        * otherwise: ``mols``.
    """
    mol_strs = remove_whitespace(path)
    reaction_types = None

    if uspto_50:
        prog = re.compile("(<RX_6>|<RX_2>|<RX_1>|<RX_3>|<RX_7>|<RX_9>|<RX_5>|<RX_10>|<RX_4>|<RX_8>)")
        # match() anchors at the start of the line, whereas sub() removes the
        # token wherever it occurs.
        leading_tokens = [prog.match(raw) for raw in mol_strs]
        reaction_types = [None if found is None else found[0] for found in leading_tokens]
        mol_strs = [prog.sub("", raw) for raw in mol_strs]

    if sep:
        reactants, reagents = create_mols_sep(mol_strs)
        return reactants, reagents

    mols = create_mols(mol_strs)
    return (mols, reaction_types) if uspto_50 else mols

In [9]:
def build_df(reacts, prods, set_name, reags=None, reaction_types=None):
    """Assemble one split of a dataset into a DataFrame.

    Args:
        reacts: Values for the "reactants_mol" column.
        prods: Values for the "products_mol" column.
        set_name: Value written to every row of the "set" column.
        reags: Optional values for a "reagents_mol" column.
        reaction_types: Optional values for a "reaction_type" column.

    Returns:
        pandas.DataFrame: Columns in the order reactants, products, any
        optional columns that were supplied, then "set".
    """
    columns = {"reactants_mol": reacts, "products_mol": prods}

    # Optional columns are only added when the caller actually supplied them.
    optional = (("reagents_mol", reags), ("reaction_type", reaction_types))
    columns.update((label, values) for label, values in optional if values is not None)

    frame = pd.DataFrame(data=columns)
    frame["set"] = set_name
    return frame

In [10]:
# ************************
# *** Process Datasets ***
# ************************

In [11]:
def process_uspto_mixed_dataset(path):
    """Load the train, val and test splits under ``path`` into one DataFrame.

    For each split the "src-*.txt" file is loaded as reactants and the
    "tgt-*.txt" file as products. The val files are labelled "valid" in the
    "set" column.

    Args:
        path: ``pathlib.Path`` of the directory containing the split files.

    Returns:
        pandas.DataFrame: Train, valid and test rows concatenated in that
        order with a fresh integer index.
    """
    splits = (
        ("src-train.txt", "tgt-train.txt", "train"),
        ("src-val.txt", "tgt-val.txt", "valid"),
        ("src-test.txt", "tgt-test.txt", "test"),
    )

    frames = []
    for src_name, tgt_name, set_name in splits:
        reacts = load_mols(path / src_name)
        prods = load_mols(path / tgt_name)
        frames.append(build_df(reacts, prods, set_name))

    return pd.concat(frames, ignore_index=True)

In [27]:
# Build the USPTO Mixed DataFrame and report how many rows were read
uspto_mixed_path = Path(USPTO_MIXED_PATH)
uspto_mixed_df = process_uspto_mixed_dataset(uspto_mixed_path)
print(f"Read {len(uspto_mixed_df.index)} rows from USPTO Mixed dataset.")

Read 479035 rows from USPTO Mixed dataset.


In [29]:
# Preview the first rows; the *_mol columns hold RDKit Mol objects.
uspto_mixed_df.head()

Unnamed: 0,reactants_mol,products_mol,set
0,<rdkit.Chem.rdchem.Mol object at 0x0000018613E...,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,train
1,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,train
2,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,train
3,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,train
4,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,<rdkit.Chem.rdchem.Mol object at 0x0000018613F...,train


In [30]:
# Persist the USPTO Mixed frame to disk as a pickle.
uspto_mixed_df.to_pickle(Path(USPTO_MIXED_PICKLE_PATH))

In [42]:
def process_uspto_sep_dataset(path):
    """Load the separated train, val and test splits into one DataFrame.

    Each "src-*.txt" file is loaded with ``sep=True`` and so yields both
    reactants and reagents; each "tgt-*.txt" file is loaded as products. The
    val files are labelled "valid" in the "set" column.

    Args:
        path: ``pathlib.Path`` of the directory containing the split files.

    Returns:
        pandas.DataFrame: Train, valid and test rows concatenated in that
        order with a fresh integer index.
    """
    splits = (
        ("src-train.txt", "tgt-train.txt", "train"),
        ("src-val.txt", "tgt-val.txt", "valid"),
        ("src-test.txt", "tgt-test.txt", "test"),
    )

    frames = []
    for src_name, tgt_name, set_name in splits:
        reacts, reags = load_mols(path / src_name, sep=True)
        prods = load_mols(path / tgt_name)
        frames.append(build_df(reacts, prods, set_name, reags=reags))

    return pd.concat(frames, ignore_index=True)

In [43]:
# Build the USPTO Separated DataFrame and report how many rows were read
uspto_sep_path = Path(USPTO_SEP_PATH)
uspto_sep_df = process_uspto_sep_dataset(uspto_sep_path)
print(f"Read {len(uspto_sep_df.index)} rows from USPTO Separated dataset.")

Read 479035 rows from USPTO Separated dataset.


In [44]:
# Preview the first rows; this frame also carries a reagents_mol column.
uspto_sep_df.head()

Unnamed: 0,reactants_mol,products_mol,reagents_mol,set
0,<rdkit.Chem.rdchem.Mol object at 0x0000018D0B7...,<rdkit.Chem.rdchem.Mol object at 0x00000186A97...,<rdkit.Chem.rdchem.Mol object at 0x0000018D0B7...,train
1,<rdkit.Chem.rdchem.Mol object at 0x00000186AD7...,<rdkit.Chem.rdchem.Mol object at 0x0000018D51E...,<rdkit.Chem.rdchem.Mol object at 0x00000186AD7...,train
2,<rdkit.Chem.rdchem.Mol object at 0x00000187227...,<rdkit.Chem.rdchem.Mol object at 0x0000018D53E...,<rdkit.Chem.rdchem.Mol object at 0x00000187227...,train
3,<rdkit.Chem.rdchem.Mol object at 0x00000187227...,<rdkit.Chem.rdchem.Mol object at 0x0000018D0B7...,<rdkit.Chem.rdchem.Mol object at 0x00000187227...,train
4,<rdkit.Chem.rdchem.Mol object at 0x0000018DA5E...,<rdkit.Chem.rdchem.Mol object at 0x0000018D9EC...,<rdkit.Chem.rdchem.Mol object at 0x0000018DA5E...,train


In [45]:
# Persist the USPTO Separated frame to disk as a pickle.
uspto_sep_df.to_pickle(Path(USPTO_SEP_PICKLE_PATH))

In [12]:
def process_uspto_50_dataset(path):
    """Load the USPTO 50K train, val and test splits into one DataFrame.

    Here the "src-*.txt" files are loaded as products (together with the
    reaction types) and the "tgt-*.txt" files as reactants. The val files are
    labelled "valid" in the "set" column.

    Args:
        path: ``pathlib.Path`` of the directory containing the split files.

    Returns:
        pandas.DataFrame: Train, valid and test rows concatenated in that
        order with a fresh integer index, including a "reaction_type" column.

    Raises:
        AssertionError: If the reaction types, reactants and products of a
            split do not all have the same length.
    """
    splits = (
        ("src-train.txt", "tgt-train.txt", "train"),
        ("src-val.txt", "tgt-val.txt", "valid"),
        ("src-test.txt", "tgt-test.txt", "test"),
    )

    frames = []
    for src_name, tgt_name, set_name in splits:
        prods, types = load_mols(path / src_name, uspto_50=True)
        # The reaction types returned for the tgt file are discarded; only
        # those from the src file are kept.
        reacts, _ = load_mols(path / tgt_name, uspto_50=True)
        assert len(types) == len(reacts) == len(prods)
        frames.append(build_df(reacts, prods, set_name, reaction_types=types))

    return pd.concat(frames, ignore_index=True)

In [13]:
# Build the USPTO 50K DataFrame and report how many rows were read
uspto_50_path = Path(USPTO_50_PATH)
uspto_50_df = process_uspto_50_dataset(uspto_50_path)
print(f"Read {len(uspto_50_df.index)} rows from USPTO 50K dataset.")

Read 50037 rows from USPTO 50K dataset.


In [14]:
# Preview the first rows; this frame also carries a reaction_type column.
uspto_50_df.head()

Unnamed: 0,reactants_mol,products_mol,reaction_type,set
0,<rdkit.Chem.rdchem.Mol object at 0x0000021A123...,<rdkit.Chem.rdchem.Mol object at 0x0000021A121...,<RX_1>,train
1,<rdkit.Chem.rdchem.Mol object at 0x0000021A123...,<rdkit.Chem.rdchem.Mol object at 0x0000021A121...,<RX_6>,train
2,<rdkit.Chem.rdchem.Mol object at 0x0000021A123...,<rdkit.Chem.rdchem.Mol object at 0x0000021A121...,<RX_9>,train
3,<rdkit.Chem.rdchem.Mol object at 0x0000021A123...,<rdkit.Chem.rdchem.Mol object at 0x0000021A121...,<RX_6>,train
4,<rdkit.Chem.rdchem.Mol object at 0x0000021A123...,<rdkit.Chem.rdchem.Mol object at 0x0000021A121...,<RX_1>,train


In [15]:
# Persist the USPTO 50K frame to disk as a pickle.
uspto_50_df.to_pickle(Path(USPTO_50_PICKLE_PATH))

In [4]:
# Generate a text file for USPTO 50K test data for predict.py script

In [15]:
# Location of the pickled USPTO 50K frame. This must be the same file that
# the notebook writes via USPTO_50_PICKLE_PATH
# ("../../data/uspto_50/uspto_50.pickle"); otherwise the reload below fails
# on a clean top-to-bottom run unless the pickle has been moved by hand.
USPTO_50_SAVED_PICKLE_PATH = "../../data/uspto_50/uspto_50.pickle"
# Destination of the plain-text test set consumed by the predict.py script.
USPTO_50_TEST_TEXT_PATH = "../uspto_50_test.txt"

In [8]:
# Reload the pickled USPTO 50K frame.
# NOTE(review): unpickling can execute arbitrary code, so only load pickle
# files that were produced by a trusted run of this notebook.
uspto_50_df = pd.read_pickle(USPTO_50_SAVED_PICKLE_PATH)

In [9]:
# Preview the reloaded frame.
uspto_50_df.head()

Unnamed: 0,reactants_mol,products_mol,reaction_type,set
0,<rdkit.Chem.rdchem.Mol object at 0x000001ED569...,<rdkit.Chem.rdchem.Mol object at 0x000001ED7FC...,<RX_1>,train
1,<rdkit.Chem.rdchem.Mol object at 0x000001ED569...,<rdkit.Chem.rdchem.Mol object at 0x000001ED7FC...,<RX_6>,train
2,<rdkit.Chem.rdchem.Mol object at 0x000001ED569...,<rdkit.Chem.rdchem.Mol object at 0x000001ED7FC...,<RX_9>,train
3,<rdkit.Chem.rdchem.Mol object at 0x000001ED569...,<rdkit.Chem.rdchem.Mol object at 0x000001ED7FC...,<RX_6>,train
4,<rdkit.Chem.rdchem.Mol object at 0x000001ED569...,<rdkit.Chem.rdchem.Mol object at 0x000001ED7FC...,<RX_1>,train


In [10]:
# Pull the reactant molecules and the split labels out as parallel lists
# (same row order, so index i in one corresponds to index i in the other).
reacts_mol = uspto_50_df["reactants_mol"].tolist()
sets = uspto_50_df["set"].tolist()

In [11]:
# Keep only the reactant molecules whose row is labelled as the "test" split.
reacts_mol_test = [mol for mol, set_name in zip(reacts_mol, sets) if set_name == "test"]

In [13]:
# Convert every test reactant molecule back into a SMILES string.
reacts_test = list(map(Chem.MolToSmiles, reacts_mol_test))
print(f"Length of test dataset: {len(reacts_test)}")

Length of test dataset: 5004


In [14]:
# One SMILES string per line; no trailing newline is appended after the last.
output_str = "\n".join(reacts_test)

In [16]:
# Write the test SMILES to the text file. write_text returns the number of
# characters written, which is the value this cell displays.
p = Path(USPTO_50_TEST_TEXT_PATH)
p.write_text(output_str)

259297