In [1]:
import itertools
import math
import random
import re

import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from eval_functions import *

#### Import data

Read the COCONUT data frame and keep only the identifier, the SMILES (renamed to `target`) and the split columns.

In [2]:
# Keep only the identifier, the SMILES (as `target`) and the split label.
df = (
    pd.read_csv('data/coconut/coconut_split.csv')
    .rename(columns={'identifier': 'id', 'smiles': 'target', 'split': 'split'})
    .loc[:, ['id', 'target', 'split']]
    .copy()
)

Canonicalize the SMILES with RDKit. They are already canonical, but rewriting them with RDKit's default settings keeps the aromaticity information.

In [20]:
# Rewrite every target with RDKit's default canonical SMILES writer.
df['target'] = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in df['target']]

Generate the equivalent achiral SMILES (`source`) with `flatten`. This step is run here so that faulty SMILES can be removed from the dataset by the sanity check below; it is repeated later for each dataset augmentation.

In [21]:
# Achiral counterpart of every target; the sanity check below filters mismatches.
df['source'] = df['target'].map(flatten)

Sanity check

In [22]:
def sanity_check(row):
    """Return True when `row['target']` and `row['source']` match once stereochemistry is ignored.

    Both SMILES are parsed with RDKit and rewritten as canonical, non-isomeric
    SMILES before being compared.

    Parameters
    ----------
    row : mapping
        Anything indexable by ``'target'`` and ``'source'`` (a DataFrame row
        when used through ``df.apply(..., axis=1)``).

    Returns
    -------
    bool
        False when the two structures differ, or when RDKit cannot parse either
        SMILES. Returning False lets the caller filter such rows out instead of
        aborting the whole ``apply``.
    """
    target_mol = Chem.MolFromSmiles(row['target'])
    source_mol = Chem.MolFromSmiles(row['source'])
    # MolFromSmiles returns None on parse failure; passing None on to
    # MolToSmiles would raise and stop the whole DataFrame.apply.
    if target_mol is None or source_mol is None:
        return False
    target_smiles = Chem.MolToSmiles(target_mol, canonical=True, isomericSmiles=False)
    source_smiles = Chem.MolToSmiles(source_mol, canonical=True, isomericSmiles=False)
    return target_smiles == source_smiles

# One boolean per row: does the flattened source match the target?
checks = df.apply(sanity_check, axis='columns')

Remove problematic SMILES: rows whose target and source do not give the same structure once stereochemistry is ignored.

In [23]:
# Keep only rows that passed the sanity check and renumber them from 0.
df = df.loc[checks].reset_index(drop=True)

#### Generate dataset with scrambled stereocenters

Define helper functions and build the dataset with scrambled stereochemistry

In [24]:
# Seed Python's global `random` module; scramble_stereochemistry draws from it.
random.seed(a=42)

def exact_swap_stereochemistry(text):
    """Invert every stereo marker in a SMILES string.

    Tetrahedral tags are swapped ("@" <-> "@@") and bond-direction symbols are
    swapped ("/" <-> "\\"). Each swap is a single pass over the string, so no
    temporary placeholder text is needed. Input that happens to contain
    placeholder-like text (e.g. "TEMP_SLASH") is therefore left intact.

    Parameters
    ----------
    text : str
        SMILES string.

    Returns
    -------
    str
        The SMILES with all "@"/"@@" tags and "/"/"\\" symbols inverted.
    """
    # '@@?' tries the two-character tag first, so '@@' is consumed as one
    # token and every tag is swapped exactly once.
    text = re.sub(r'@@?', lambda m: '@' if m.group() == '@@' else '@@', text)
    return text.translate(str.maketrans('/\\', '\\/'))

def scramble_stereochemistry(text):
    def random_replacement(match):
        return random.choice(["@", "@@"])
    def random_slash_backslash(match):
        return random.choice(["/", "\\"])
    text = re.sub(r'@@|@', random_replacement, text)
    text = re.sub(r'/|\\', random_slash_backslash, text)
    
    return text

# Same rows as `df`; only the stereo markers of `target` are randomized.
scrambled_df = df.assign(target=df['target'].apply(scramble_stereochemistry))

#### Augment data by SMILES randomization

Prepare functions to randomize SMILES by shuffling the atom order

In [25]:
def randomize_smiles(smiles, seed=None):
    """Return a non-canonical SMILES of the same molecule with a shuffled atom order.

    Parameters
    ----------
    smiles : str
        Input SMILES; must be parseable by RDKit.
    seed : int or None
        Seed for the atom-order shuffle. ``None`` draws fresh OS entropy, so the
        result is not reproducible.

    Returns
    -------
    str
        SMILES written with ``canonical=False`` and ``isomericSmiles=True``, so
        the shuffled atom order and the stereo annotations are kept.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

    # A private RandomState gives exactly the same shuffle as seeding the global
    # numpy RNG would. It does so without resetting numpy's global random state
    # (or re-randomizing it when seed is None) for the rest of the notebook.
    rng = np.random.RandomState(seed)
    atoms = list(range(mol.GetNumAtoms()))
    rng.shuffle(atoms)
    new_mol = Chem.RenumberAtoms(mol, atoms)
    return Chem.MolToSmiles(new_mol, canonical=False, isomericSmiles=True)

def randomize_augment_dataframe(df, factor):
    """Return ``factor`` randomized-SMILES copies of every row of ``df``.

    Copy ``i`` of each molecule is produced by ``randomize_smiles(..., seed=i)``,
    so the output is reproducible. Rows are ordered molecule by molecule, then
    by copy index.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``id``, ``target`` and ``split`` columns.
    factor : int
        Number of randomized copies per row.

    Returns
    -------
    pandas.DataFrame
        Columns ``id``, ``target``, ``split``. The columns are present even
        when ``df`` is empty or ``factor`` is 0.
    """
    augmented_data = [
        {'id': row['id'], 'target': randomize_smiles(row['target'], seed=i), 'split': row['split']}
        for _, row in df.iterrows()
        for i in range(factor)
    ]
    # Passing `columns` keeps the schema stable for empty results; without it
    # pd.DataFrame([]) has no columns and later ['target'] lookups fail.
    return pd.DataFrame(augmented_data, columns=['id', 'target', 'split'])

In [26]:
# Regex -> replacement rules used to simplify one matched token: a bracket atom
# such as "[C@@H]", or a single "/" or "\" bond symbol. Rules are applied in
# insertion order, so the element-specific bracket rules run before the generic
# "@", "/" and "\" removals.
# The character classes list their characters without separators; the previous
# "[@,H]" form also matched a literal comma.
substitutions = {
    r'\[K[@H]*\]': '[K]',
    r'\[B[@H]*\]': 'B',
    r'\[Na[@H+-]*\]': '[Na]',
    r'\[C[@H]*\]': 'C',
    r'\[N[@H]*\]': 'N',
    r'\[O[@H]*\]': 'O',
    r'\[S[@H]*\]': 'S',
    r'\[P[@H]*\]': 'P',
    r'\[F[@H]*\]': 'F',
    r'\[Cl[@H]*\]': '[Cl]',
    r'\[Br[@H]*\]': '[Br]',
    r'\[I[@H]*\]': 'I',
    r'@': '',
    r'/': '',
    r'\\': '',
    r'\[C\]': 'C'
}

def apply_substitutions(smiles):
    """Apply every rule in `substitutions`, in order, to `smiles` and return the result."""
    for pattern, replacement in substitutions.items():
        smiles = re.sub(pattern, replacement, smiles)
    return smiles

def _nth_combination(pool_size, r, index):
    """Return the `index`-th r-combination of range(pool_size), in itertools.combinations order."""
    # Adapted from the `nth_combination` recipe in the itertools documentation.
    c = math.comb(pool_size, r)
    n = pool_size
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool_size - 1 - n)
    return tuple(result)


def generate_n_permutations(smiles, matches, num_to_remove, n, seed=None):
    """Generate up to `n` variants of `smiles` with `num_to_remove` matches simplified.

    Each variant passes the text of exactly `num_to_remove` of the given regex
    `matches` through `apply_substitutions`. If there are at most `n` such
    index combinations, all of them are used; otherwise `n` are sampled at random.

    Parameters
    ----------
    smiles : str
        The string the `matches` were found in.
    matches : list of re.Match
        Candidate tokens (presumably from `re.finditer` over `smiles`).
    num_to_remove : int
        How many matches to simplify in each variant.
    n : int
        Maximum number of combinations to use.
    seed : int or None
        Seed for the sampling. With a seed, a private RNG is used, which gives
        the same sample that reseeding the global `random` module would, but
        leaves the global state untouched. ``None`` draws from the global
        `random` module.

    Returns
    -------
    list of str
        Distinct variants in first-generated order. An empty list if
        `num_to_remove` exceeds ``len(matches)``.
    """
    if len(matches) < num_to_remove:
        return []

    # The module-level functions random.sample etc. mirror the methods of a
    # random.Random instance, so either object can serve as the RNG.
    rng = random if seed is None else random.Random(seed)

    total = math.comb(len(matches), num_to_remove)
    if total <= n:
        selected_combinations = itertools.combinations(range(len(matches)), num_to_remove)
    else:
        # Sampling ranks from range(total) picks exactly the positions that
        # sampling the materialized list of combinations would. It avoids
        # building that list (up to ~1e6 tuples for 20 matches).
        selected_combinations = (
            _nth_combination(len(matches), num_to_remove, rank)
            for rank in rng.sample(range(total), n)
        )

    # dict keeps first-seen order; a set of str iterates in an order that
    # changes between processes (hash randomization), so downstream row order
    # and shuffles would not be reproducible.
    permutations = {}
    for indices in selected_combinations:
        chars = list(smiles)
        # Replace right-to-left so earlier offsets stay valid after each splice.
        for index in sorted(indices, reverse=True):
            start, end = matches[index].start(), matches[index].end()
            chars[start:end] = apply_substitutions(smiles[start:end])
        permutations[''.join(chars)] = None

    return list(permutations)

def generate_random_permutations(smiles, matches, max_augmentations=50, seed=None, max_attempts=1000):
    """Generate random variants of `smiles` in which a random subset of `matches` is simplified.

    Each attempt picks a random count k in [1, len(matches)] and k distinct
    matches. Each chosen match's text is passed through `apply_substitutions`.
    Sampling stops after `max_augmentations` distinct variants or
    `max_attempts` attempts, whichever comes first.

    Parameters
    ----------
    smiles : str
        The string the `matches` were found in.
    matches : list of re.Match
        Candidate tokens (presumably from `re.finditer` over `smiles`).
    max_augmentations : int
        Target number of distinct variants.
    seed : int or None
        Seed for the sampling. With a seed, a private RNG is used, which gives
        the same draws that reseeding the global `random` module would, but
        leaves the global state untouched. ``None`` draws from the global
        `random` module.
    max_attempts : int
        Upper bound on sampling attempts.

    Returns
    -------
    list of str
        Distinct variants in first-generated order. An empty list when
        `matches` is empty.
    """
    if not matches:
        # Nothing to replace; random.randint(1, 0) would otherwise raise.
        return []

    rng = random if seed is None else random.Random(seed)

    # dict keeps first-seen order (a set of str does not across processes).
    augmentations = {}
    attempts = 0

    while len(augmentations) < max_augmentations and attempts < max_attempts:
        num_to_remove = rng.randint(1, len(matches))
        selected_indices = rng.sample(range(len(matches)), num_to_remove)
        chars = list(smiles)

        # Replace right-to-left so earlier offsets stay valid after each splice.
        for index in sorted(selected_indices, reverse=True):
            start, end = matches[index].start(), matches[index].end()
            chars[start:end] = apply_substitutions(smiles[start:end])

        augmentations[''.join(chars)] = None
        attempts += 1

    return list(augmentations)

def uniform_augment(smiles, n, seed=None):
    """Augment `smiles` by simplifying subsets of its bracket atoms and bond-direction symbols.

    Candidate tokens are every bracket atom ("[...]") and every "/" or "\\".
    - With no candidates, the input is returned unchanged as the only element.
    - With more than 20 candidates, up to 50 random variants are produced and
      `n` is not used.
    - Otherwise, up to `n` variants are produced for each number of simplified
      tokens, from 1 to all of them.

    Parameters
    ----------
    smiles : str
        SMILES to augment.
    n : int
        Variants per number of simplified tokens (only used for <= 20 candidates).
    seed : int or None
        Forwarded to the permutation generators.

    Returns
    -------
    list of str
        Distinct variants in first-generated order.
    """
    # One match per bracket atom ("[...]") or bond-direction symbol ("/" or "\").
    matches = list(re.finditer(r'(\[.*?\]|[\\/])', smiles))
    if not matches:
        return [smiles]

    # dict keeps first-seen order; a set of str iterates in an order that
    # changes between processes (hash randomization), which would make the
    # augmented datasets' row order irreproducible.
    augmentations = {}
    if len(matches) > 20:
        candidates = generate_random_permutations(smiles, matches, max_augmentations=50, seed=seed)
        augmentations.update(dict.fromkeys(candidates))
    else:
        for num_to_remove in range(1, len(matches) + 1):
            candidates = generate_n_permutations(smiles, matches, num_to_remove, n, seed)
            augmentations.update(dict.fromkeys(candidates))

    return list(augmentations)

def uniform_augment_dataframe(df, smiles_column, id_column, split_column, n=2, seed=None):
    """Pair every SMILES in `df` with its partial augmentations.

    For each row, `uniform_augment` is applied to ``row[smiles_column]``. One
    output row is emitted per returned variant; the variant goes in a new
    ``'source'`` column and the original SMILES is kept in `smiles_column`.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows.
    smiles_column, id_column, split_column : str
        Column names for the SMILES, the identifier and the split label.
    n : int
        Forwarded to `uniform_augment`.
    seed : int or None
        Forwarded to `uniform_augment` (the same seed is used for every row).

    Returns
    -------
    pandas.DataFrame
        Columns ``[id_column, smiles_column, 'source', split_column]``. The
        columns are present even when `df` is empty.
    """
    augmented_data = []
    for _, row in df.iterrows():
        original_smiles = row[smiles_column]
        for augmented_smiles in uniform_augment(original_smiles, n, seed):
            augmented_data.append({
                id_column: row[id_column],
                smiles_column: original_smiles,
                'source': augmented_smiles,
                split_column: row[split_column]
            })

    # Explicit columns keep the schema for empty inputs; pd.DataFrame([]) has
    # none, which would break later drop_duplicates(subset='source').
    return pd.DataFrame(augmented_data, columns=[id_column, smiles_column, 'source', split_column])

Prepare datasets with different augmentations

In [9]:
randomized = randomize_augment_dataframe(df, 1)
scrambled_randomized = randomize_augment_dataframe(scrambled_df, 1)
# Randomized copies of the original dataset at increasing factors.
augmented_2x, augmented_5x, augmented_10x, augmented_20x, augmented_50x = (
    randomize_augment_dataframe(df, factor) for factor in (2, 5, 10, 20, 50)
)

Generate equivalent achiral SMILES

In [10]:
# `flatten` derives the achiral source column from each (randomized) target.
for _frame in (randomized, scrambled_randomized, augmented_2x, augmented_5x,
               augmented_10x, augmented_20x, augmented_50x):
    _frame['source'] = _frame['target'].apply(flatten)

Partial augmentations (sources in which only some of the target's bracket atoms and `/`/`\` bond symbols are simplified)

In [21]:
def _partially_augment(frame, n):
    """Run uniform_augment_dataframe on the 'target'/'id'/'split' columns of `frame` with seed 42."""
    return uniform_augment_dataframe(frame, 'target', 'id', 'split', n, seed=42)

partial_augmented_5x = _partially_augment(df, 5)
randomized_partial_augmented_5x = _partially_augment(randomized, 5)
scrambled_partial_augmented_5x = _partially_augment(scrambled_df, 5)
randomized_scrambled_partial_augmented_5x = _partially_augment(scrambled_randomized, 5)
mixed_augmented = _partially_augment(augmented_10x, 1)

Remove duplicates

In [11]:
# Fully randomized sets: identical randomized targets are redundant.
for _frame in (randomized, augmented_2x, augmented_5x, augmented_10x, augmented_20x, augmented_50x):
    _frame.drop_duplicates(subset='target', inplace=True)

# Partial sets: one target yields many sources, so deduplicate on the source.
for _frame in (partial_augmented_5x, randomized_partial_augmented_5x, scrambled_partial_augmented_5x,
               randomized_scrambled_partial_augmented_5x, mixed_augmented):
    _frame.drop_duplicates(subset='source', inplace=True)

Shuffle data

In [12]:
def _shuffle_frame(frame):
    """Return `frame` with its rows shuffled (random_state=42) and a fresh 0..n-1 index."""
    return frame.sample(frac=1, random_state=42).reset_index(drop=True)

augmented_2x_shuffled = _shuffle_frame(augmented_2x)
augmented_5x_shuffled = _shuffle_frame(augmented_5x)
augmented_10x_shuffled = _shuffle_frame(augmented_10x)
augmented_20x_shuffled = _shuffle_frame(augmented_20x)
augmented_50x_shuffled = _shuffle_frame(augmented_50x)

partial_augmented_5x_shuffled = _shuffle_frame(partial_augmented_5x)
randomized_partial_augmented_5x_shuffled = _shuffle_frame(randomized_partial_augmented_5x)
scrambled_partial_augmented_5x_shuffled = _shuffle_frame(scrambled_partial_augmented_5x)
randomized_scrambled_partial_augmented_5x_shuffled = _shuffle_frame(randomized_scrambled_partial_augmented_5x)
mixed_augmented_shuffled = _shuffle_frame(mixed_augmented)

#### Export augmented datasets

Save augmented data frames

In [34]:
# Every exported file has the same column layout; written in this order.
_export_columns = ['id', 'source', 'target', 'split']
_exports = [
    (df, 'data/dataset_not_augmented.csv'),
    (randomized, 'data/dataset_randomized.csv'),
    (scrambled_df, 'data/dataset_scrambled.csv'),
    (augmented_2x_shuffled, 'data/dataset_augmented_2x.csv'),
    (augmented_5x_shuffled, 'data/dataset_augmented_5x.csv'),
    (augmented_10x_shuffled, 'data/dataset_augmented_10x.csv'),
    (augmented_20x_shuffled, 'data/dataset_augmented_20x.csv'),
    (augmented_50x_shuffled, 'data/dataset_augmented_50x.csv'),
    (partial_augmented_5x_shuffled, 'data/dataset_partial_augmented_5x.csv'),
    (randomized_partial_augmented_5x_shuffled, 'data/dataset_randomized_partial_augmented_5x.csv'),
    (scrambled_partial_augmented_5x_shuffled, 'data/dataset_scrambled_partial_augmented_5x.csv'),
    (randomized_scrambled_partial_augmented_5x_shuffled, 'data/dataset_randomized_scrambled_partial_augmented_5x.csv'),
    (mixed_augmented_shuffled, 'data/dataset_mixed_augmented.csv'),
]
for _export_frame, _export_path in _exports:
    _export_frame[_export_columns].to_csv(_export_path, index=False)