In [None]:
import os
import random
import re
import pandas as pd
import numpy as np
import itertools
from tqdm import tqdm
tqdm.pandas()

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from eval_functions import flatten  

Utility & Preprocessing Functions

In [None]:
def canonicalize_smiles(smiles):
    """Return RDKit's default SMILES rendering of `smiles`.

    The string is parsed with `Chem.MolFromSmiles` and written back with
    `Chem.MolToSmiles` using its default options. When RDKit cannot parse
    the input, the original string is handed back untouched.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Unparseable input: pass it through rather than failing.
        return smiles
    return Chem.MolToSmiles(parsed)

def sanity_check(row):
    """Check that a row's `target` and `source` agree once stereo is ignored.

    Both SMILES are parsed with RDKit and re-written as canonical,
    non-isomeric SMILES. The row passes (True) only when the two strings
    are equal; it fails (False) if either SMILES cannot be parsed.
    """
    parsed = [Chem.MolFromSmiles(row[column]) for column in ('target', 'source')]
    if any(mol is None for mol in parsed):
        return False
    target_form, source_form = (
        Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) for mol in parsed
    )
    return target_form == source_form

Stereochemistry Modification Functions

In [None]:
def exact_swap_stereochemistry(text):
    """Invert every stereo marker in `text` in a single pass.

    `@` becomes `@@`, `@@` becomes `@`, `/` becomes `\\` and `\\` becomes `/`.
    All other characters are copied through unchanged.

    The swap is done with one regex substitution instead of chained
    `str.replace` calls with placeholder strings, so the result no longer
    depends on the input being free of the placeholder text.

    Parameters:
      - text: String containing SMILES-style stereo markers.

    Returns:
      The string with every stereo marker replaced by its opposite.
    """
    opposite = {'@@': '@', '@': '@@', '/': '\\', '\\': '/'}
    # '@@' is listed before '@' so a double marker is consumed as one token.
    return re.sub(r'@@|@|/|\\', lambda found: opposite[found.group(0)], text)

def scramble_stereochemistry(text):
    """Replace every stereo marker in `text` with a randomly chosen one.

    Runs two passes, each drawing from the global `random` module:
    first every `@`/`@@` token is replaced by a random pick of `@` or `@@`,
    then every `/` or `\\` is replaced by a random pick of `/` or `\\`.
    Characters that are not stereo markers are left as they are.
    """
    chirality_choices = ["@", "@@"]
    bond_choices = ["/", "\\"]
    # Pass 1: tetrahedral markers.
    centres_scrambled = re.sub(r'@@|@', lambda _hit: random.choice(chirality_choices), text)
    # Pass 2: double-bond direction markers.
    return re.sub(r'/|\\', lambda _hit: random.choice(bond_choices), centres_scrambled)

SMILES Randomization & Augmentation Functions

In [None]:
def randomize_smiles(smiles, seed_val=None):
    """Write `smiles` as a non-canonical SMILES with a shuffled atom order.

    Parameters:
      - smiles: Input SMILES string.
      - seed_val: Seed for the atom-order shuffle. With None the generator
        is seeded from fresh entropy, so the result is not reproducible.

    Returns:
      A SMILES string produced from the renumbered molecule with
      `canonical=False, isomericSmiles=True`. If RDKit cannot parse the
      input, the input string is returned unchanged.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    # Use a private generator instead of reseeding NumPy's global RNG:
    # reseeding here would silently overwrite any seed the caller set with
    # np.random.seed(...). A RandomState seeded with the same value yields
    # the same shuffle as the legacy global generator.
    rng = np.random.RandomState(seed_val)
    atom_order = list(range(mol.GetNumAtoms()))
    rng.shuffle(atom_order)
    new_mol = Chem.RenumberAtoms(mol, atom_order)
    return Chem.MolToSmiles(new_mol, canonical=False, isomericSmiles=True)

def randomize_augment_dataframe(df, factor):
    """Generate `factor` randomized SMILES per row.

    Parameters:
      - df: DataFrame with `id`, `target` and `split` columns.
      - factor: Number of randomized SMILES to produce for each row.

    Returns:
      A new DataFrame with columns `id`, `target`, `split`, holding
      `factor` rows per input row. The columns are present even when the
      result has no rows (empty input or `factor` of 0), so callers can
      index `['target']` or de-duplicate on it without a KeyError.

    Replicate `i` of every row is generated with `seed_val=i`, so the
    seeds 0..factor-1 are reused for each molecule.
    """
    output_columns = ['id', 'target', 'split']
    augmented_data = []
    for _, row in df.iterrows():
        for i in range(factor):
            augmented_data.append({
                'id': row['id'],
                'target': randomize_smiles(row['target'], seed_val=i),
                'split': row['split']
            })
    # Explicit columns keep the schema stable for an empty record list.
    return pd.DataFrame(augmented_data, columns=output_columns)

Partial Stereochemistry Removal

In [None]:
# Ordered table of regex -> replacement rules used by apply_substitutions().
# The rules are applied in the order written here, so the generic '@', '/'
# and '\' removals only see what the bracket-atom rules left behind.
#
# In the bracket-atom patterns the commas inside [...] are literal members
# of the character class, not separators: `[@,H]*` matches any run of '@',
# ',' and 'H'. A bracket atom is rewritten only when nothing else (digits,
# charges, isotopes, ...) appears after the element symbol.
substitutions = {
    r'\[K[@,H]*\]': '[K]',
    r'\[B[@,H]*\]': 'B',
    # The Na class also accepts '+' and '-', so e.g. '[Na+]' becomes '[Na]'.
    # NOTE(review): this drops the charge, not just stereo -- confirm intended.
    r'\[Na[@,H,+,-]*\]': '[Na]',
    r'\[C[@,H]*\]': 'C',
    r'\[N[@,H]*\]': 'N',
    r'\[O[@,H]*\]': 'O',
    r'\[S[@,H]*\]': 'S',
    r'\[P[@,H]*\]': 'P',
    r'\[F[@,H]*\]': 'F',
    r'\[Cl[@,H]*\]': '[Cl]',
    r'\[Br[@,H]*\]': '[Br]',
    r'\[I[@,H]*\]': 'I',
    # Strip any stereo markers that remain (e.g. inside brackets the rules
    # above did not match).
    r'@': '',
    r'/': '',
    r'\\': '',
    # NOTE(review): looks redundant -- '\[C[@,H]*\]' above already matches '[C]'.
    r'\[C\]': 'C'
}

def apply_substitutions(smiles):
    """Run every rule of the module-level `substitutions` table over `smiles`.

    Rules are applied one after another in table order, each on the output
    of the previous one, and the final string is returned.
    """
    rewritten = smiles
    for regex in substitutions:
        rewritten = re.sub(regex, substitutions[regex], rewritten)
    return rewritten

def generate_n_permutations(smiles, matches, num_to_remove, n, seed_val=None):
    """Build up to `n` variants of `smiles` with `num_to_remove` matches rewritten.

    Parameters:
      - smiles: The string the matches were found in.
      - matches: Regex match objects locating the rewritable spans.
      - num_to_remove: How many of the matches are rewritten per variant.
      - n: Upper bound on the number of match combinations that are tried.
      - seed_val: If not None, the global `random` module is seeded with it.

    Returns:
      A list of distinct variant strings. It is empty when there are fewer
      matches than `num_to_remove`, and can be shorter than `n` when fewer
      combinations exist or several combinations give the same string.
    """
    if seed_val is not None:
        random.seed(seed_val)

    match_count = len(matches)
    if match_count < num_to_remove:
        return []

    index_sets = list(itertools.combinations(range(match_count), num_to_remove))
    if len(index_sets) > n:
        # Too many combinations: keep a random subset of size n.
        index_sets = random.sample(index_sets, n)

    variants = set()
    for index_set in index_sets:
        chars = list(smiles)
        # Rewrite right-to-left so earlier spans keep their offsets even
        # when a replacement changes the length.
        for position in sorted(index_set, reverse=True):
            hit = matches[position]
            span_start, span_end = hit.start(), hit.end()
            chars[span_start:span_end] = apply_substitutions(smiles[span_start:span_end])
        variants.add(''.join(chars))
    return list(variants)

def generate_random_permutations(smiles, matches, max_augmentations=50, seed_val=None, max_attempts=1000):
    """Randomly sample variants of `smiles` with a random subset of matches rewritten.

    Each attempt picks a random number of matches (between 1 and all of
    them), rewrites those spans with `apply_substitutions`, and adds the
    result to a set. Sampling stops once `max_augmentations` distinct
    strings were collected or `max_attempts` attempts were made.

    Parameters:
      - smiles: The string the matches were found in.
      - matches: Regex match objects locating the rewritable spans.
      - max_augmentations: Target number of distinct variants.
      - seed_val: If not None, the global `random` module is seeded with it.
      - max_attempts: Hard cap on sampling attempts.

    Returns:
      A list of distinct variant strings; empty when `matches` is empty.
    """
    if seed_val is not None:
        random.seed(seed_val)
    if not matches:
        # Nothing to rewrite. Without this guard random.randint(1, 0)
        # below would raise ValueError.
        return []
    augmentations = set()
    attempts = 0
    while len(augmentations) < max_augmentations and attempts < max_attempts:
        num_to_remove = random.randint(1, len(matches))
        selected_indices = random.sample(range(len(matches)), num_to_remove)
        modified_smiles = list(smiles)
        # Rewrite right-to-left so earlier spans keep their offsets.
        for index in sorted(selected_indices, reverse=True):
            match = matches[index]
            match_start, match_end = match.start(), match.end()
            match_str = smiles[match_start:match_end]
            modified_smiles[match_start:match_end] = apply_substitutions(match_str)
        augmentations.add(''.join(modified_smiles))
        attempts += 1
    return list(augmentations)

def uniform_augment(smiles, n, seed_val=None):
    """Produce variants of `smiles` with some bracket atoms / bond markers rewritten.

    Every bracket expression `[...]` and every `/` or `\\` in the string is
    a candidate span. With no candidates the input is returned as the only
    element. With more than 20 candidates, variants are sampled randomly
    (at most 50; `n` is not used on that path). Otherwise, for each
    possible number of spans to rewrite, up to `n` combinations are taken.

    Returns a list of distinct strings.
    """
    candidates = list(re.finditer(r'(\[.*?\]|[\\/])', smiles))
    if len(candidates) == 0:
        return [smiles]

    collected = set()
    if len(candidates) <= 20:
        # Small enough to walk through every removal count.
        for removal_count in range(1, len(candidates) + 1):
            collected.update(
                generate_n_permutations(smiles, candidates, removal_count, n, seed_val)
            )
    else:
        collected.update(
            generate_random_permutations(smiles, candidates, max_augmentations=50, seed_val=seed_val)
        )
    return list(collected)

def uniform_augment_dataframe(df, smiles_column, id_column, split_column, n=2, seed_val=None):
    """Expand each row of `df` into one row per `uniform_augment` variant.

    Parameters:
      - df: Input DataFrame.
      - smiles_column: Column holding the SMILES to augment; copied through
        unchanged into the output under the same name.
      - id_column: Column copied through as the row identifier.
      - split_column: Column copied through as the split label.
      - n: Passed to `uniform_augment`.
      - seed_val: Passed to `uniform_augment`; the same value is used for
        every row.

    Returns:
      A new DataFrame with columns `id_column`, `smiles_column`, `source`,
      `split_column`, where `source` holds the augmented string. The
      columns are present even when no rows are produced, so callers can
      de-duplicate on `source` without a KeyError.
    """
    # Same key order as the record dicts below; dict.fromkeys drops a repeated
    # name (e.g. smiles_column == 'source') exactly as the dict literal does.
    output_columns = list(dict.fromkeys([id_column, smiles_column, 'source', split_column]))
    augmented_data = []
    for _, row in df.iterrows():
        original_smiles = row[smiles_column]
        for augmented_smiles in uniform_augment(original_smiles, n, seed_val):
            augmented_data.append({
                id_column: row[id_column],
                smiles_column: original_smiles,
                'source': augmented_smiles,
                split_column: row[split_column]
            })
    # Explicit columns keep the schema stable for an empty record list.
    return pd.DataFrame(augmented_data, columns=output_columns)

Processing of One Input Dataset

In [None]:
def process_split_dataset(input_file, output_dir, aug_seed):
    """
    Build every augmented variant of one split CSV and write them to `output_dir`.

    Parameters:
      - input_file: Path to the original split CSV file. It must provide
        `identifier`, `smiles` and `split` columns (or already-renamed
        `id` / `target` columns).
      - output_dir: Directory where output files will be saved (created if
        missing).
      - aug_seed: Seed used for scrambling, partial augmentation and
        shuffling; also embedded in every output filename.

    Output files (each with columns id, source, target, split):
      c1, nc1, r1                  - base, randomized, scrambled
      a2, a5, a10, a20, a50        - randomized SMILES at the given factor
      npstereo, ncnpstereo, rp,
      ncrp, m65                    - partial augmentations via uniform_augment
    """
    export_columns = ['id', 'source', 'target', 'split']

    def reshuffle(frame):
        # Full-row shuffle, reproducible through aug_seed.
        return frame.sample(frac=1, random_state=aug_seed).reset_index(drop=True)

    os.makedirs(output_dir, exist_ok=True)

    # Global seeds for reproducibility.
    random.seed(aug_seed)
    np.random.seed(aug_seed)

    # Load and normalise column names, keeping only what is needed.
    base = pd.read_csv(input_file).rename(
        columns={'identifier': 'id', 'smiles': 'target', 'split': 'split'}
    )
    base = base[['id', 'target', 'split']].copy()

    # Canonical target plus its flattened counterpart (as produced by the
    # imported `flatten` helper).
    base['target'] = base['target'].apply(canonicalize_smiles)
    base['source'] = base['target'].apply(flatten)

    # Keep only rows whose source/target pass the sanity check.
    passes = base.apply(sanity_check, axis=1)
    base = base[passes].reset_index(drop=True)

    # Same rows, but with randomly re-assigned stereo markers in the target.
    scrambled = base.copy()
    scrambled['target'] = scrambled['target'].apply(scramble_stereochemistry)

    # Randomized-SMILES augmentation at several factors.
    randomized = randomize_augment_dataframe(base, factor=1)
    scrambled_randomized = randomize_augment_dataframe(scrambled, factor=1)
    multiples = {
        factor: randomize_augment_dataframe(base, factor=factor)
        for factor in (2, 5, 10, 20, 50)
    }

    # The randomized frames have new targets, so their sources are rebuilt.
    for frame in [randomized, scrambled_randomized, *multiples.values()]:
        frame['source'] = frame['target'].apply(flatten)

    # Partial augmentation via uniform_augment.
    partial_npstereo = uniform_augment_dataframe(base, 'target', 'id', 'split', n=5, seed_val=aug_seed)
    partial_ncnpstereo = uniform_augment_dataframe(randomized, 'target', 'id', 'split', n=5, seed_val=aug_seed)
    partial_rp = uniform_augment_dataframe(scrambled, 'target', 'id', 'split', n=5, seed_val=aug_seed)
    partial_ncrp = uniform_augment_dataframe(scrambled_randomized, 'target', 'id', 'split', n=5, seed_val=aug_seed)
    partial_mixed = uniform_augment_dataframe(multiples[10], 'target', 'id', 'split', n=1, seed_val=aug_seed)

    # De-duplicate: randomized sets by target, partial sets by source.
    for frame in [randomized, *multiples.values()]:
        frame.drop_duplicates(subset='target', inplace=True)
    for frame in [partial_npstereo, partial_ncnpstereo, partial_rp, partial_ncrp, partial_mixed]:
        frame.drop_duplicates(subset='source', inplace=True)

    # Everything except the three 1x sets is shuffled before export.
    exports = [
        (f'c1-{aug_seed}.csv', base),
        (f'nc1-{aug_seed}.csv', randomized),
        (f'r1-{aug_seed}.csv', scrambled),
        (f'a2-{aug_seed}.csv', reshuffle(multiples[2])),
        (f'a5-{aug_seed}.csv', reshuffle(multiples[5])),
        (f'a10-{aug_seed}.csv', reshuffle(multiples[10])),
        (f'a20-{aug_seed}.csv', reshuffle(multiples[20])),
        (f'a50-{aug_seed}.csv', reshuffle(multiples[50])),
        (f'npstereo-{aug_seed}.csv', reshuffle(partial_npstereo)),
        (f'ncnpstereo-{aug_seed}.csv', reshuffle(partial_ncnpstereo)),
        (f'rp-{aug_seed}.csv', reshuffle(partial_rp)),
        (f'ncrp-{aug_seed}.csv', reshuffle(partial_ncrp)),
        (f'm65-{aug_seed}.csv', reshuffle(partial_mixed)),
    ]
    for filename, frame in exports:
        frame[export_columns].to_csv(os.path.join(output_dir, filename), index=False)

    print(f"Finished processing {input_file} with augmentation seed {aug_seed}.")

Generate datasets

In [None]:
# Seeds identifying the pre-computed split files to process.
seeds = [0, 1, 42]

# For each seed: read that seed's split CSV, write all augmented variants to
# a per-seed output directory, and reuse the same seed for the augmentation.
for seed_val in seeds:
    input_file = f'data/coconut/coconut-split-{seed_val}.csv'
    output_dir = f'data/augmented/seed-{seed_val}'
    process_split_dataset(input_file, output_dir, aug_seed=seed_val)