In [1]:
import re
from pathlib import Path

import pandas as pd

#### Prepare data for training

Import generated data

In [2]:
# Load every generated dataset variant. The names on the left line up
# one-to-one, in order, with the CSV paths listed below.
(
    not_augmented,
    augmented_2x,
    augmented_5x,
    scrambled_5x,
    augmented_10x,
    augmented_20x,
) = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/dataset_not_augmented.csv',
        'data/dataset_augmented_2x.csv',
        'data/dataset_augmented_5x.csv',
        'data/dataset_scrambled_5x.csv',
        'data/dataset_augmented_10x.csv',
        'data/dataset_augmented_20x.csv',
    )
)

Tokenize the data

In [3]:
def smi_tokenizer(smi: str) -> str:
    """
    Tokenize a SMILES molecule or reaction.

    Args:
        smi: SMILES string to split into tokens.

    Returns:
        The tokens of ``smi`` joined by single spaces. An empty input
        yields an empty string.

    Raises:
        AssertionError: if concatenating the tokens does not reproduce
            ``smi``, i.e. the input contains characters that the token
            pattern does not match (for example whitespace).
    """
    # Raw string on purpose: the pattern is full of regex escapes such as \[ \( \.
    # that are *invalid string escapes* in a normal literal and trigger a
    # SyntaxWarning on recent Python versions. The regex itself is unchanged.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\!|\$|\%[0-9]{2}|[0-9])"
    # With a single capturing group, findall returns the matched tokens directly.
    tokens = re.findall(pattern, smi)
    # Round-trip check: anything the pattern skipped would be lost silently otherwise.
    assert smi == ''.join(tokens), f"SMILES could not be fully tokenized: {smi!r}"
    return ' '.join(tokens)

In [4]:
# Replace each dataset's 'source' column with its space-separated token form.
# NOTE: this overwrites the column in place; re-running the cell on data that
# is already tokenized will likely trip the round-trip assertion in smi_tokenizer.
for frame in (
    not_augmented,
    augmented_2x,
    augmented_5x,
    scrambled_5x,
    augmented_10x,
    augmented_20x,
):
    frame['source'] = frame['source'].apply(smi_tokenizer)

In [5]:
# Replace each dataset's 'target' column with its space-separated token form.
# NOTE: this overwrites the column in place; re-running the cell on data that
# is already tokenized will likely trip the round-trip assertion in smi_tokenizer.
for frame in (
    not_augmented,
    augmented_2x,
    augmented_5x,
    scrambled_5x,
    augmented_10x,
    augmented_20x,
):
    frame['target'] = frame['target'].apply(smi_tokenizer)

Export data for training with OpenNMT

In [6]:
def export_to_opennmt(df, path):
    """
    Write the source/target columns of each data split to plain-text files.

    For every split, two files are written into ``path``: ``src-<name>.txt``
    holding the ``source`` column and ``tgt-<name>.txt`` holding the
    ``target`` column, one row per line, without index or header.
    Rows whose ``split`` value is ``'train'``, ``'test'`` or ``'validation'``
    go to the ``train``, ``test`` and ``val`` files respectively; rows with
    any other ``split`` value are not exported.

    Args:
        df: DataFrame with ``source``, ``target`` and ``split`` columns.
        path: Output directory (string or path-like). It is created,
            including missing parents, if it does not exist yet.

    Returns:
        None.
    """
    out_dir = Path(path)
    # Create the destination up front so a fresh checkout / clean run does not
    # fail on the first write because the folder is missing.
    out_dir.mkdir(parents=True, exist_ok=True)

    # (value stored in the 'split' column, suffix used in the output file names)
    splits = (('train', 'train'), ('test', 'test'), ('validation', 'val'))
    for split_value, suffix in splits:
        subset = df[df['split'] == split_value]
        subset['source'].to_csv(out_dir / f'src-{suffix}.txt', index=False, header=False)
        subset['target'].to_csv(out_dir / f'tgt-{suffix}.txt', index=False, header=False)

# Export every dataset variant into its own output folder, in the same order
# as they were loaded.
for dataset, out_dir in (
    (not_augmented, 'data/opennmt/not_augmented'),
    (augmented_2x, 'data/opennmt/augmented_2x'),
    (augmented_5x, 'data/opennmt/augmented_5x'),
    (scrambled_5x, 'data/opennmt/scrambled_5x'),
    (augmented_10x, 'data/opennmt/augmented_10x'),
    (augmented_20x, 'data/opennmt/augmented_20x'),
):
    export_to_opennmt(dataset, out_dir)