In [1]:
import random
import re
import pandas as pd
import numpy as np

from rdkit import Chem
from rdkit import RDLogger
# Turn off every RDKit log channel matching 'rdApp.*' so that parsing the
# dataset below does not flood the notebook output.
RDLogger.DisableLog('rdApp.*')

#### Import data

Read data frame

In [2]:
# Read the split table, rename its columns to the names used in the rest of
# the notebook and keep an independent copy of the three columns we need.
df = (
    pd.read_csv('data/coconut/coconut_split.csv')
    .rename(columns={'identifier': 'id', 'smiles': 'target', 'split': 'split'})
    [['id', 'target', 'split']]
    .copy()
)

Generate the equivalent achiral (flattened) SMILES. This is done here, before augmentation, so that faulty SMILES can be removed from the dataset. The step is repeated later for the augmented datasets.

In [3]:
def flatten(smiles):
    """Strip stereochemistry notation from a SMILES string.

    The rewrite rules are regex substitutions applied one after another, in
    the order listed:

    * bracket atoms of K, B, Na, C, N, O, S, P, F, Cl, Br and I whose bracket
      holds only '@', 'H' or ',' after the element symbol are replaced by the
      plain atom (K, Na, Cl and Br keep their brackets; the Na rule also
      accepts '+' and '-' inside the bracket);
    * every remaining '@', '/' and '\\' character is deleted;
    * a leftover '[C]' becomes 'C'.

    Parameters
    ----------
    smiles : str
        SMILES string to flatten.

    Returns
    -------
    str
        The SMILES string after all substitutions.
    """
    # (regex, replacement) pairs; order matters because later rules run on
    # the output of earlier ones.
    rewrite_rules = [
        (r'\[K[@,H]*\]', '[K]'),
        (r'\[B[@,H]*\]', 'B'),
        (r'\[Na[@,H,+,-]*\]', '[Na]'),
        (r'\[C[@,H]*\]', 'C'),
        (r'\[N[@,H]*\]', 'N'),
        (r'\[O[@,H]*\]', 'O'),
        (r'\[S[@,H]*\]', 'S'),
        (r'\[P[@,H]*\]', 'P'),
        (r'\[F[@,H]*\]', 'F'),
        (r'\[Cl[@,H]*\]', '[Cl]'),
        (r'\[Br[@,H]*\]', '[Br]'),
        (r'\[I[@,H]*\]', 'I'),
        (r'@', ''),
        (r'/', ''),
        (r'\\', ''),
        (r'\[C\]', 'C'),
    ]

    flat = smiles
    for regex, plain in rewrite_rules:
        flat = re.sub(regex, plain, flat)
    return flat

# The model input ("source") is the flattened form of every target SMILES.
df['source'] = df['target'].map(flatten)

Sanity check

In [4]:
def sanity_check(row):
    """Tell whether a row's target and source describe the same flat molecule.

    Both SMILES are parsed with RDKit and written back as canonical,
    non-isomeric SMILES; the row passes when the two strings are equal.

    Parameters
    ----------
    row : mapping
        Anything indexable by 'target' and 'source' (e.g. a DataFrame row).

    Returns
    -------
    bool
        True when both SMILES parse and give identical canonical
        non-isomeric SMILES, False otherwise.
    """
    target_mol = Chem.MolFromSmiles(row['target'])
    source_mol = Chem.MolFromSmiles(row['source'])

    # MolFromSmiles hands back None for SMILES it cannot parse. Such rows are
    # exactly the faulty entries this check is meant to filter out, so report
    # them as failing instead of letting MolToSmiles raise on None.
    if target_mol is None or source_mol is None:
        return False

    target_smiles = Chem.MolToSmiles(target_mol, canonical=True, isomericSmiles=False)
    source_smiles = Chem.MolToSmiles(source_mol, canonical=True, isomericSmiles=False)
    return target_smiles == source_smiles

# Row-wise check: one boolean per molecule.
checks = df.apply(sanity_check, axis='columns')

Remove problematic SMILES (those where the target and source SMILES do not give the same flat structure)

In [5]:
# Keep only the rows that passed the sanity check and renumber them from 0.
df = df.loc[checks].reset_index(drop=True)

#### Generate dataset with scrambled stereocenters

Define helper function

In [6]:
random.seed(42)

def scramble_stereochemistry(text):
    """Randomise the chirality tags of a SMILES string and swap its slashes.

    Every '@' / '@@' tag is replaced by a fresh, independent draw from
    {'@', '@@'} using the module-level ``random`` generator (seeded above),
    and every '/' is exchanged with '\\' and vice versa.

    Parameters
    ----------
    text : str
        SMILES string to scramble.

    Returns
    -------
    str
        The scrambled SMILES string.
    """
    # One random.choice call per chirality tag, in left-to-right order.
    text = re.sub(r'@@|@', lambda _match: random.choice(["@", "@@"]), text)

    # Swap the two bond-direction characters in a single pass. A translation
    # table needs no placeholder token, so the input can never collide with
    # one.
    # NOTE(review): mirroring *all* '/' and '\' markers at once appears to
    # leave the E/Z configuration of each double bond unchanged per SMILES
    # semantics, so only tetrahedral centres are really scrambled - confirm
    # this is intended.
    return text.translate(str.maketrans('/\\', '\\/'))

# Same rows as df, but with every target SMILES stereo-scrambled.
scrambled_df = df.assign(target=df['target'].apply(scramble_stereochemistry))

#### Augment data by SMILES randomization

Prepare function to randomize SMILES

In [7]:
def randomize_smiles(smiles, seed=None):
    """Write a molecule as a non-canonical SMILES with shuffled atom order.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    seed : int or None, optional
        Seed for the atom shuffle. With an integer the result is
        reproducible; with None the generator is seeded unpredictably.

    Returns
    -------
    str
        Isomeric, non-canonical SMILES of the renumbered molecule.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure with None; fail with a clear
        # message rather than an AttributeError on the next line.
        raise ValueError(f'Cannot parse SMILES: {smiles!r}')

    # A private legacy RandomState produces the same stream as
    # np.random.seed(seed) followed by np.random.shuffle, but leaves the
    # global numpy generator untouched for every other user of np.random.
    rng = np.random.RandomState(seed)

    atoms = list(range(mol.GetNumAtoms()))
    rng.shuffle(atoms)
    new_mol = Chem.RenumberAtoms(mol, atoms)
    return Chem.MolToSmiles(new_mol, canonical=False, isomericSmiles=True)

def augment_dataframe(df, factor):
    """Build an augmented frame with ``factor`` randomized SMILES per row.

    For every input row, ``randomize_smiles`` is called with seeds
    0 .. factor-1 on the row's 'target'; each result becomes one output row
    carrying the original 'id' and 'split'.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with 'id', 'target' and 'split' columns.
    factor : int
        Number of randomized SMILES to generate per input row.

    Returns
    -------
    pandas.DataFrame
        Frame with columns 'id', 'target', 'split' (in that order) and
        ``len(df) * factor`` rows. The columns are present even when no rows
        are produced.
    """
    output_columns = ['id', 'target', 'split']

    # Walk the three columns in lockstep instead of materialising a Series
    # per row with iterrows.
    records = [
        {'id': row_id, 'target': randomize_smiles(target, seed=i), 'split': split}
        for row_id, target, split in zip(df['id'], df['target'], df['split'])
        for i in range(factor)
    ]

    # Naming the columns explicitly keeps the schema stable for empty input
    # or factor <= 0, where an empty record list would otherwise give a frame
    # without any columns.
    return pd.DataFrame(records, columns=output_columns)

Prepare datasets with different augmentations

In [8]:
# Augment the clean dataset at several factors, plus the stereo-scrambled
# dataset at factor 5. The original call order is kept.
augmented_2x, augmented_5x = (augment_dataframe(df, factor) for factor in (2, 5))
scrambled_5x = augment_dataframe(scrambled_df, 5)
augmented_10x, augmented_20x = (augment_dataframe(df, factor) for factor in (10, 20))

Generate equivalent achiral SMILES

In [14]:
# Add the flattened "source" column to every augmented frame (in place).
for frame in (augmented_2x, augmented_5x, scrambled_5x, augmented_10x, augmented_20x):
    frame['source'] = frame['target'].apply(flatten)

Remove duplicates

In [16]:
# Within each frame, keep only the first occurrence of every target SMILES.
for frame in (augmented_2x, augmented_5x, scrambled_5x, augmented_10x, augmented_20x):
    frame.drop_duplicates(subset='target', inplace=True)

Shuffle data

In [19]:
# Shuffle the rows of every frame with the same fixed random_state and
# renumber them from 0.
(
    augmented_2x_shuffled,
    augmented_5x_shuffled,
    scrambled_5x_shuffled,
    augmented_10x_shuffled,
    augmented_20x_shuffled,
) = [
    unshuffled.sample(frac=1, random_state=42).reset_index(drop=True)
    for unshuffled in (augmented_2x, augmented_5x, scrambled_5x, augmented_10x, augmented_20x)
]

#### Export augmented datasets

Save augmented data frames

In [20]:
# Destination file -> frame to write. Every file gets the same four columns
# and no index column.
export_columns = ['id', 'source', 'target', 'split']
exports = {
    'data/dataset_not_augmented.csv': df,
    'data/dataset_augmented_2x.csv': augmented_2x_shuffled,
    'data/dataset_augmented_5x.csv': augmented_5x_shuffled,
    'data/dataset_scrambled_5x.csv': scrambled_5x_shuffled,
    'data/dataset_augmented_10x.csv': augmented_10x_shuffled,
    'data/dataset_augmented_20x.csv': augmented_20x_shuffled,
}
for csv_path, frame in exports.items():
    frame[export_columns].to_csv(csv_path, index=False)