In [22]:
from rdkit import Chem
from rdkit import RDLogger
import numpy as np
import augchem

RDLogger.DisableLog('rdApp.*')

In [23]:
loader = augchem.Loader('QM9')
data = loader.loadDataset('QM9', [])
data = np.array(data)

print(data)

[[['C' 'C']]

 [['N' 'N']]

 [['O' 'O']]

 [['C#C' 'C#C']]

 [['C#N' 'C#N']]

 [['C=O' 'C=O']]

 [['CC' 'CC']]

 [['CO' 'CO']]

 [['CC#C' 'CC#C']]

 [['CC#N' 'CC#N']]

 [['CC=O' 'CC=O']]

 [['NC=O' 'NC=O']]

 [['CCC' 'CCC']]

 [['CCO' 'CCO']]

 [['COC' 'COC']]

 [['C1CC1' 'C1CC1']]

 [['C1CO1' 'C1CO1']]

 [['CC(C)=O' 'CC(=O)C']]

 [['CC(N)=O' 'CC(=O)N']]

 [['NC(N)=O' 'NC(=O)N']]

 [['CC(C)C' 'CC(C)C']]

 [['CC(C)O' 'CC(C)O']]

 [['C#CC#C' 'C(#C)C#C']]

 [['C#CC#N' 'C(#C)C#N']]

 [['N#CC#N' 'N#CC#N']]

 [['O=CC#C' 'O=CC#C']]

 [['O=CC#N' 'O=CC#N']]

 [['O=CC=O' 'O=CC=O']]

 [['CC#CC' 'CC#CC']]

 [['CCC#C' 'CCC#C']]

 [['CCC#N' 'CCC#N']]

 [['NCC#N' 'NCC#N']]

 [['OCC#C' 'OCC#C']]

 [['OCC#N' 'OCC#N']]

 [['CCC=O' 'CCC=O']]

 [['CNC=O' 'CNC=O']]

 [['COC=O' 'COC=O']]

 [['OCC=O' 'OCC=O']]

 [['CCCC' 'CCCC']]

 [['CCCO' 'CCCO']]

 [['CCOC' 'CCOC']]

 [['OCCO' 'OCCO']]

 [['CC1CC1' 'CC1CC1']]

 [['CC1CO1' 'C[C@H]1CO1']]

 [['CN1CC1' 'CN1CC1']]

 [['OC1CC1' 'OC1CC1']]

 [['C1CCC1' 'C1CCC1'

In [24]:
valid_mols, invalid_mols = loader.verifyMolecules(data)
print('Valid molecules:', len(valid_mols))
print('Invalid molecules:', len(invalid_mols))

Valid molecules: 50
Invalid molecules: 0


In [32]:
aug = augchem.Augmentator(seed=432)

augment_set = []

# Augment the dataset
for item in data:
    for unit in item:
        for smiles in unit:
            mol = aug.mask(smiles, mask_ratio=0.3)

            if mol not in valid_mols and mol not in augment_set:
                augment_set.append(mol)

            mol = aug.delete(smiles, delete_ratio=0.3)

            if mol not in valid_mols and mol not in augment_set:
                augment_set.append(mol)

            mol = aug.swap(smiles)

            if mol not in valid_mols and mol not in augment_set:
                augment_set.append(mol)

            mol = aug.enumerateSmiles(smiles)

            if mol not in valid_mols and mol not in augment_set:
                augment_set.append(mol)

print('Augmented dataset:', len(augment_set) + len(data))
print('Augmented dataset:', augment_set)


Augmented dataset: 258
Augmented dataset: [['C'], ['N'], ['O'], ['C#C'], ['C#N', 'N#C'], 'O=C', ['O=C', 'C=O'], ['CC'], 'OC', ['CO', 'OC'], 'CC[M]C', ['CC#C', 'C#CC', 'C(#C)C'], '[M]C#C', 'CC[M]N', 'CN#C', ['N#CC', 'C(C)#N', 'C(#N)C', 'CC#N'], 'C[M]#N', 'CC#', ['N#CC', 'C(#N)C', 'CC#N'], 'CC=[M]', ['CC=O', 'C(C)=O', 'O=CC', 'C(=O)C'], 'C[M]=O', 'CO=C', ['CC=O', 'C(C)=O', 'O=CC'], '[M]C=O', 'N=O', 'CN=O', ['C(N)=O', 'O=CN', 'C(=O)N'], 'N[M]=O', ['NC=O', 'O=CN', 'C(=O)N'], ['CCC', 'C(C)C'], ['C(C)C', 'CCC'], ['CCO', 'OCC', 'C(C)O'], ['CCO', 'OCC', 'C(O)C'], ['COC', 'O(C)C'], ['O(C)C', 'COC'], 'CCC11', ['C1CC1'], ['C1OC1', 'O1CC1', 'C1CO1'], 'C[M](C)=O', ['CC(=O)C', 'C(=O)(C)C', 'O=C(C)C', 'C(C)(C)=O'], 'CC(=O)[M]', 'CO(=C)C', ['C(=O)(C)C', 'C(C)(=O)C', 'O=C(C)C', 'C(C)(C)=O', 'CC(C)=O', 'CC(=O)C'], 'CC[M]=O', 'CC(N)=', ['CC(=O)N', 'C(=O)(C)N', 'O=C(N)C', 'NC(=O)C', 'C(C)(N)=O', 'C(C)(=O)N', 'O=C(C)N'], 'C[M](=O)N', 'C(=O)N', 'CC(=O)N', ['CC(=O)N', 'NC(C)=O', 'CC(N)=O', 'C(=O)(N)C', 'O=C(