## Imports and setup

In [3]:
import os
import huggingface_hub

In [4]:
# Log in to the Hugging Face Hub, reusing the token cached at $HF_HOME/token when
# that file exists. If HF_HOME is unset or no cached token is present, pass
# token=None and leave it to `huggingface_hub.login` to obtain one.
hf_home = os.environ.get('HF_HOME')
token_path = os.path.join(hf_home, 'token') if hf_home else None
if token_path is not None and os.path.exists(token_path):
    with open(token_path, 'r', encoding='utf-8') as f:
        token = f.read().strip()
else:
    token = None
huggingface_hub.login(token=token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token
Login successful


In [26]:
import json
import re
import typing

import datasets
import numpy as np
import pandas as pd
import rdkit
import transformers
from rdkit import Chem as rdChem
from tqdm.auto import tqdm

In [7]:
# Disable RDKit's warning-level log channel ('rdApp.warning') for this session;
# presumably to keep the RDKit featurization below from flooding the output.
# TODO: Update to 2024.03.6 release when available instead of suppressing warning!
#  See: https://github.com/rdkit/rdkit/issues/7625
rdkit.rdBase.DisableLog('rdApp.warning')

## Create dataset

In [None]:
def count_rings_and_bonds(
    mol: rdChem.Mol
) -> typing.Dict[str, int]:
    """Summarize ring and bond statistics of an RDKit molecule.

    Args:
        mol: Parsed RDKit molecule.

    Returns:
        Dict whose keys appear in this order:
          * ``'ring_count'``: number of rings in the symmetrized SSSR;
          * ``'R<size>'``: number of rings of each size that occurs (only sizes
            present in the molecule, in order of first appearance);
          * ``'single_bond'``, ``'double_bond'``, ``'triple_bond'``,
            ``'aromatic_bond'``: bond counts. Aromatic bonds are counted only as
            aromatic; bonds of any other type are not counted at all.
    """
    rings = rdChem.GetSymmSSSR(mol)

    per_size = {}
    for atom_indices in rings:
        size = len(atom_indices)
        per_size[size] = per_size.get(size, 0) + 1

    # Aromaticity is checked before the nominal bond type, so an aromatic bond
    # is never also tallied as single/double.
    bond_label = {
        rdChem.BondType.SINGLE: 'single',
        rdChem.BondType.DOUBLE: 'double',
        rdChem.BondType.TRIPLE: 'triple',
    }
    tallies = dict.fromkeys(('single', 'double', 'triple', 'aromatic'), 0)
    for bond in mol.GetBonds():
        label = 'aromatic' if bond.GetIsAromatic() else bond_label.get(bond.GetBondType())
        if label is not None:
            tallies[label] += 1

    summary = {'ring_count': len(rings)}
    summary.update((f"R{size}", n) for size, n in per_size.items())
    summary.update((f"{label}_bond", n) for label, n in tallies.items())
    return summary

In [None]:
"""
    Download data and validation indices from:
        "Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations"
        https://github.com/harryjo97/GDSS
    > wget wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/zinc250k.csv
    > wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/valid_idx_zinc250k.json
"""
df = pd.read_csv('/Users/yairschiff/Downloads/zinc250k.csv', index_col=0, encoding='utf_8')
feats = []
for i, row in tqdm(df.iterrows(), total=len(df), desc='RDKit feats', leave=False):
    feat = {'smiles': row['smiles']}
    feat['canonical_smiles'] = rdChem.CanonSmiles(feat['smiles'])
    m = rdChem.MolFromSmiles(feat['canonical_smiles'])
    feat.update(count_rings_and_bonds(m))
    feats.append(feat)
df = pd.merge(df, pd.DataFrame.from_records(feats), on='smiles')
df = df.fillna(0)
for col in df.columns:  # recast ring counts as int
    if re.search("^R[0-9]+$", col) is not None:
        df[col] = df[col].astype(int)
# Re-order columns
df = df[
    ['smiles', 'logP', 'qed', 'SAS', 'canonical_smiles',
     'single_bond', 'double_bond', 'triple_bond', 'aromatic_bond',
     'ring_count','R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R12', 'R13', 'R14', 'R15', 'R18', 'R24']]

In [None]:
# Read in validation indices (GDSS `valid_idx_zinc250k.json`); they are matched
# against `df.index` to flag validation rows with 1 and training rows with 0.
ZINC250K_DATA_DIR = os.environ.get('ZINC250K_DATA_DIR', os.path.expanduser('~/Downloads'))
with open(os.path.join(ZINC250K_DATA_DIR, 'valid_idx_zinc250k.json'), 'r') as f:
    valid_idxs = json.load(f)
df['validation'] = df.index.isin(valid_idxs).astype(int)
# Every published index should select exactly one row; a mismatch means
# `df.index` no longer lines up with the indices in the JSON file.
assert df['validation'].sum() == len(set(valid_idxs)), \
    'some validation indices do not match any row of df'

In [None]:
# Create HF dataset: split rows on the 0/1 `validation` flag and drop the flag
# itself from both splits.
split_flags = {'train': 0, 'validation': 1}
dataset = datasets.DatasetDict({
    split_name: datasets.Dataset.from_pandas(
        df.loc[df['validation'] == flag].drop(columns=['validation']))
    for split_name, flag in split_flags.items()
})
# `from_pandas` stores the (filtered, non-contiguous) pandas index as an extra column; discard it.
dataset = dataset.remove_columns('__index_level_0__')

In [None]:
# Publish the train/validation DatasetDict to the Hugging Face Hub.
HUB_DATASET_REPO = 'yairschiff/zinc250k'
dataset.push_to_hub(HUB_DATASET_REPO)

## Create tokenizer

In [8]:
# Compiled once. Raw string so regex escapes such as `\[` and `\(` are not parsed as
# (invalid) Python string escapes, which emit warnings on Python 3.12+.
_SMI_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi):
    """Tokenize a SMILES molecule or reaction.

        Copied from https://github.com/pschwllr/MolecularTransformer.

    Args:
        smi: SMILES string.

    Returns:
        List of token strings whose concatenation is exactly `smi`. Bracket atoms
        (e.g. ``[C@H]``), two-letter halogens (``Br``, ``Cl``) and ``%NN`` ring
        closures are single tokens.

    Raises:
        AssertionError: if some characters of `smi` are not covered by any token.
    """
    tokens = _SMI_TOKEN_REGEX.findall(smi)
    assert smi == ''.join(tokens), f'Could not fully tokenize SMILES: {smi!r}'
    return tokens

In [11]:
# Load the published dataset from the Hub (replaces any in-memory `dataset` built above).
HUB_DATASET_REPO = 'yairschiff/zinc250k'
dataset = datasets.load_dataset(HUB_DATASET_REPO)

In [12]:
# # If vocab file not created yet, uncomment and run this cell

# tokens = []
# for split in dataset.keys():
#     for smi in dataset[split]['canonical_smiles']:
#         tokens.extend(smi_tokenizer(smi))

# with open('zinc250k_vocab.json', 'w', encoding='utf-8') as f:
#     f.write(
#         json.dumps(
#             {t: i for i, t in enumerate(sorted(set(tokens)))},
#             indent=2,
#             sort_keys=True,
#             ensure_ascii=False
#         ) + '\n')

In [14]:
# # If HF tokenizer not yet published, uncomment and run this cell
# import tokenizer

# tokenizer.Zinc250kTokenizer.register_for_auto_class()
# zinc250k_tokenizer = tokenizer.Zinc250kTokenizer(vocab_file='zinc250k_vocab.json')
# zinc250k_tokenizer.push_to_hub('yairschiff/zinc250k-tokenizer')

CommitInfo(commit_url='https://huggingface.co/yairschiff/zinc250k-tokenizer/commit/7a07b0165a8a4f14f09d6137da8cdabf789397fd', commit_message='Upload tokenizer', commit_description='', oid='7a07b0165a8a4f14f09d6137da8cdabf789397fd', pr_url=None, pr_revision=None, pr_num=None)

In [18]:
# Test tokenizer: show one training example, its ids, the decoded string with
# special tokens, and check that a round trip without special tokens is lossless.
zinc250k_tokenizer = transformers.AutoTokenizer.from_pretrained(
    'yairschiff/zinc250k-tokenizer', trust_remote_code=True, resume_download=None)
sample_smiles = dataset['train'][1000]['canonical_smiles']
sample_ids = zinc250k_tokenizer.encode(sample_smiles)
print(sample_smiles)
print(sample_ids)
print(zinc250k_tokenizer.decode(sample_ids))
roundtrip = zinc250k_tokenizer.decode(
    zinc250k_tokenizer.encode(sample_smiles, add_special_tokens=False))
print(roundtrip)
assert roundtrip == sample_smiles, 'tokenizer round trip is lossy'

Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1
[0, 25, 69, 15, 69, 68, 68, 16, 68, 15, 25, 25, 25, 35, 16, 29, 25, 11, 23, 30, 12, 29, 25, 35, 11, 30, 12, 25, 30, 68, 15, 68, 68, 68, 11, 27, 12, 68, 68, 15, 1]
<bos>Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1<eos>
Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1


In [28]:
# Distribution of encoded lengths (including the special tokens `encode` adds)
# over every split. Only the `canonical_smiles` column is read, rather than
# materializing full row dicts.
lengths = []
for split in dataset.keys():
    lengths.extend(
        len(zinc250k_tokenizer.encode(smiles))
        for smiles in tqdm(dataset[split]['canonical_smiles'], desc=split, leave=False))
print(np.histogram(lengths))
print(min(lengths))
print(max(lengths))

  0%|          | 0/224568 [00:00<?, ?it/s]

  0%|          | 0/24887 [00:00<?, ?it/s]

(array([  152,  3351, 21311, 47185, 67972, 70367, 25030, 11778,  2179,
         130]), array([10. , 16.4, 22.8, 29.2, 35.6, 42. , 48.4, 54.8, 61.2, 67.6, 74. ]))
10
74
