In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem import Draw
from IPython.display import display, Image
import io

In [None]:
def remove_nan_smiles(df):
    """Drop every row whose ``IsomericSMILES`` value is missing.

    Prints the row count before filtering, the number of rows dropped and
    the row count afterwards.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column.

    Returns:
        The DataFrame produced by ``dropna`` on the ``IsomericSMILES`` column.
    """
    rows_before = len(df)
    print("--- REMOVING NAN SMILES ---")
    print(f"\tOriginal size before removing NaN SMILES: {rows_before}")

    cleaned = df.dropna(subset=['IsomericSMILES'])
    rows_after = len(cleaned)

    print(f"\tNumber of NaN SMILES: {rows_before - rows_after}")
    print(f"\tSize after removing NaN SMILES: {rows_after}")
    return cleaned

def contains_undesirable_elements(smiles, undesirable_elements_list):
    """Report whether a SMILES string contains any of the given elements.

    Args:
        smiles: SMILES string to inspect.
        undesirable_elements_list: Element symbols (e.g. ``"Na"``) to look for.

    Returns:
        True if RDKit cannot parse the SMILES, or if at least one atom of the
        parsed molecule has a symbol from ``undesirable_elements_list``;
        False otherwise.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"\tWarning: Could not parse SMILES: {smiles}")
        return True

    # Compare element symbols directly rather than compiling one SMARTS
    # pattern per element on every call (this function runs once per row).
    symbols_present = {atom.GetSymbol() for atom in mol.GetAtoms()}

    # Walk the caller's list in order so the reported element is the first
    # listed one that occurs in the molecule.
    for element in undesirable_elements_list:
        if element in symbols_present:
            print(f"\tMatch found: SMILES {smiles} contains {element}")
            return True
    return False

def remove_undesirable_elements(df, undesirable_elements_list=None):
    """Remove rows whose SMILES contains an undesirable element.

    A row is removed when ``contains_undesirable_elements`` returns True for
    its ``IsomericSMILES`` value. The input frame is not modified.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column.
        undesirable_elements_list: Optional iterable of element symbols to
            purge. Defaults to the list this notebook has always used.

    Returns:
        A new DataFrame containing only the rows that were kept.
    """
    if undesirable_elements_list is None:
        undesirable_elements_list = ["He", "Na", "Mg", "Al", "Si", "K", "Ca", "Ti", "V", "Cr", "Fe", "Co", "Cu", "Zn", "Bi"]
    else:
        # Normalise to a list so the log line below keeps its format.
        undesirable_elements_list = list(undesirable_elements_list)

    print("--- REMOVING UNDESIRABLE ELEMENTS ---")
    print(f"\tOriginal size before removing undesirable elements: {len(df)}")
    print(f"\tPurging undesirable elements in {undesirable_elements_list}")

    # Column-wise apply keeps the mask a Series even for an empty frame, and
    # holding it in a local avoids adding a helper column to the caller's data.
    is_undesirable = df['IsomericSMILES'].apply(
        contains_undesirable_elements, args=(undesirable_elements_list,)
    ).astype(bool)

    print(f"\tTotal undesirable element SMILES removed: {int(is_undesirable.sum())}")
    return df[~is_undesirable].copy()


def smiles_to_inchi_key(smiles):
    """Convert a SMILES string to an InChIKey via RDKit.

    Args:
        smiles: SMILES string to convert.

    Returns:
        The value of ``Chem.inchi.MolToInchiKey`` for the parsed molecule, or
        None when RDKit cannot parse the SMILES.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    return Chem.inchi.MolToInchiKey(parsed)

def remove_duplicate_smiles(df):
    """Remove rows that describe a molecule already seen earlier in the frame.

    Molecule identity is the InChIKey returned by ``smiles_to_inchi_key``;
    the first occurrence of each key is kept. Rows for which no key could be
    computed are never treated as duplicates of each other, because a missing
    key says nothing about whether two molecules are the same.

    The input frame is not modified, matching the other filters in this
    notebook.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column.

    Returns:
        A new DataFrame without the duplicate rows.
    """
    print("--- REMOVING DUPLICATES ---")
    print(f"\tOriginal size before removing duplicates: {len(df)}")

    inchi_keys = df['IsomericSMILES'].apply(smiles_to_inchi_key)

    # A None key means the SMILES could not be parsed; an empty key is also
    # treated as "no identity available" (assumption: key generation may
    # yield an empty string on failure - confirm against RDKit).
    has_key = inchi_keys.notna() & (inchi_keys != '')
    is_duplicate = inchi_keys.duplicated(keep='first') & has_key

    for smiles in df.loc[is_duplicate, 'IsomericSMILES']:
        print(f"\tRemoving duplicate smiles: {smiles}")

    df_filtered = df[~is_duplicate].copy()
    print(f"\tNumber of SMILES removed: {int(is_duplicate.sum())}")
    print(f"\tSize after removing duplicates: {len(df_filtered)}")

    return df_filtered

def remove_small_molecules(df, threshold=5):
    """Remove molecules whose total atom count is at or below ``threshold``.

    Atoms are counted after adding explicit hydrogens with ``Chem.AddHs``.
    SMILES that RDKit cannot parse are kept, not removed. The input frame is
    not modified.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column.
        threshold: Largest atom count (hydrogens included) that is still
            considered "small" and therefore removed.

    Returns:
        A new DataFrame containing only the rows that were kept.
    """
    print("--- REMOVING SMALL MOLECULES ---")
    print(f"\tSmall molecule threshold: {threshold}")
    print(f"\tOriginal size before removing small molecules: {len(df)}")

    def is_molecule_below_atom_count_threshold(smiles, threshold):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES are left for other filters to deal with.
            return False

        # Hydrogens are implicit after parsing; make them explicit so that
        # GetNumAtoms counts every atom.
        atom_count = Chem.AddHs(mol).GetNumAtoms()

        if atom_count <= threshold:
            print(f"\tSMILES {smiles} has {atom_count} atoms, below threshold {threshold}")
            return True
        return False

    # Keep the mask in a local so no helper column is added to the input.
    below_threshold = df['IsomericSMILES'].apply(
        is_molecule_below_atom_count_threshold, args=(threshold,)
    ).astype(bool)

    df_filtered = df[~below_threshold].copy()
    print(f"\tNumber of small molecules removed: {int(below_threshold.sum())}")
    # This line reports the size AFTER filtering; it used to say "before".
    print(f"\tSize after removing small molecules: {len(df_filtered)}")
    return df_filtered

def _smiles_has_carbon(smiles):
    """Return True if the SMILES text contains at least one carbon atom.

    A plain ``'C' in smiles`` test is wrong because it also fires on Cl, Ca,
    Cu, Co, Cr, Sc and similar symbols. This helper distinguishes the two
    places an atom can appear in SMILES:

    * inside ``[...]``, the element symbol follows an optional isotope number,
      so carbon is ``C`` or ``c`` NOT followed by another lowercase letter;
    * outside brackets, the only symbol that starts with ``C`` besides carbon
      is ``Cl``.
    """
    import re  # local import keeps this block self-contained

    for bracket_atom in re.findall(r'\[([^\]]*)\]', smiles):
        if re.match(r'\d*[Cc](?![a-z])', bracket_atom):
            return True

    # Drop bracket atoms (handled above) and chlorine, then any remaining
    # 'C' / 'c' is an aliphatic / aromatic carbon.
    outside_brackets = re.sub(r'\[[^\]]*\]', '', smiles).replace('Cl', '')
    return 'C' in outside_brackets or 'c' in outside_brackets


def remove_non_carbon_molecules(df):
    """Remove molecules that contain no carbon atom at all.

    The check is purely textual (see ``_smiles_has_carbon``), so it does not
    require the SMILES to be parseable by RDKit. The input frame is not
    modified.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column of strings.

    Returns:
        A new DataFrame containing only the carbon-containing rows.
    """
    print("--- REMOVING NON-CARBON MOLECULES ---")
    print(f"\tOriginal size before removing non-carbon molecules: {len(df)}")

    # Keep the mask in a local so no helper column is added to the input.
    only_contains_non_carbon = df['IsomericSMILES'].apply(
        lambda smiles: not _smiles_has_carbon(smiles)
    ).astype(bool)

    for smiles in df.loc[only_contains_non_carbon, 'IsomericSMILES']:
        print(f"\tRemoving non-carbon molecule: {smiles}")

    df_filtered = df[~only_contains_non_carbon].copy()
    print(f"\tNumber of non-carbon molecules removed: {int(only_contains_non_carbon.sum())}")
    print(f"\tSize after removing non-carbon molecules: {len(df_filtered)}")
    return df_filtered


def visualize_molecule(smiles):
    """Display a depiction of a SMILES string as a PNG in the notebook.

    Prints an "Invalid SMILES" message instead when RDKit cannot parse the
    input.

    Args:
        smiles: SMILES string of the molecule to draw.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        print(f"\tInvalid SMILES: {smiles}")
        return

    # Render to an in-memory PNG so IPython can show the raw bytes.
    png_buffer = io.BytesIO()
    Draw.MolToImage(molecule).save(png_buffer, format='PNG')
    display(Image(png_buffer.getvalue()))

def remove_molecular_weights(df, column_list, lower=20, higher=600):
    """Keep only molecules whose weight lies within ``[lower, higher]``.

    The weight of each molecule is the value RDKit's ``ExactMolWt`` returns
    for the parsed ``IsomericSMILES``. Every row outside the range is logged
    together with the names of the ``column_list`` columns whose value is 1
    for that row. SMILES that RDKit cannot parse are logged and removed
    instead of raising. The input frame is not modified.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column and every column named
            in ``column_list``.
        column_list: Column names to report for each removed molecule.
        lower: Smallest weight that is kept (inclusive).
        higher: Largest weight that is kept (inclusive).

    Returns:
        A new DataFrame containing only the rows within the weight range.
    """
    print(f"--- REMOVING MOLECULAR WEIGHTS BETWEEN {lower} AND {higher} ---")
    print(f"\tOriginal size before molecular weight thresholding: {len(df)}")

    def exact_weight(smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # NaN fails both range comparisons below, so the row is dropped.
            print(f"\tInvalid SMILES: {smiles}")
            return float('nan')
        return ExactMolWt(mol)

    def report(mask):
        # List each affected SMILES with the labels that are set for it.
        for _, row in df[mask].iterrows():
            positive_columns = [col for col in column_list if row[col] == 1]
            print(f"\t\t{row['IsomericSMILES']}, {', '.join(positive_columns)}")

    # Weights live in a local Series so neither Mol objects nor a weight
    # column are written into the caller's frame.
    weights = df['IsomericSMILES'].apply(exact_weight).astype(float)
    too_light = weights < lower
    too_heavy = weights > higher

    print(f"\tMolecules below MW threshold {lower}: {int(too_light.sum())}")
    report(too_light)

    print(f"\tMolecules above MW threshold {higher}: {int(too_heavy.sum())}")
    report(too_heavy)

    df_filtered = df[(weights >= lower) & (weights <= higher)].copy()
    print(f"\tSize after molecular weight thresholding: {len(df_filtered)}")
    return df_filtered


def contains_salts_and_charges(smiles):
    """Report whether a SMILES string should be rejected as charged.

    Args:
        smiles: SMILES string to inspect.

    Returns:
        True if RDKit cannot parse the SMILES or if ``Chem.GetFormalCharge``
        reports a non-zero value for the molecule; False otherwise.
    """
    # NOTE(review): the charge is evaluated for the molecule as a whole, so a
    # charge-balanced salt presumably passes this check - confirm.
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        print(f"\tInvalid SMILES: {smiles}")
        return True

    is_charged = Chem.GetFormalCharge(parsed) != 0
    if is_charged:
        print(f"\t{smiles} has charges")
    return is_charged

def contains_multimolecule(smiles):
    """Report whether a SMILES string describes more than one molecule.

    Args:
        smiles: SMILES string to inspect.

    Returns:
        True if RDKit cannot parse the SMILES or if the string contains a
        ``.`` separator; False otherwise.
    """
    # Parsing is only a validity check; the component test is textual.
    if Chem.MolFromSmiles(smiles) is None:
        print(f"\tInvalid SMILES: {smiles}")
        return True

    has_separator = "." in smiles
    if has_separator:
        print(f"\t{smiles} has multiple molecules")
    return has_separator

def remove_salts_and_charges_and_multimolecule(df):
    """Remove charged molecules, then multi-component SMILES.

    The filter runs in two stages: rows flagged by
    ``contains_salts_and_charges`` are removed first, and
    ``contains_multimolecule`` is then evaluated only on the rows that
    survived. That way each row is logged once and the per-row log lines agree
    with the reported counts. The input frame is not modified.

    Args:
        df: DataFrame with an ``IsomericSMILES`` column.

    Returns:
        A new DataFrame containing only the rows that passed both checks.
    """
    print("--- REMOVING SALTS, CHARGED MOLECULES AND MULTIMOLECULES ---")
    print(f"\tOriginal size before removing salts, charges and multimolecules: {len(df)}")

    is_charged = df['IsomericSMILES'].apply(contains_salts_and_charges).astype(bool)
    print(f"\tTotal salts and charged SMILES removed: {int(is_charged.sum())}")
    df_uncharged = df[~is_charged]

    # Check only the surviving rows; running this on the full input would
    # re-parse and re-log rows that the charge filter already removed.
    is_multimolecule = df_uncharged['IsomericSMILES'].apply(contains_multimolecule).astype(bool)
    print(f"\tTotal multimolecules removed: {int(is_multimolecule.sum())}")

    df_filtered = df_uncharged[~is_multimolecule].copy()
    print(f"\tFinal size after removing salts, charges and multimolecules: {len(df_filtered)}")
    return df_filtered


In [14]:
# Clean the GS-LF dataset: load it, run every filter in order, save the result.
print("Dataset cleaning report for gs-lf")

# Some versions of the file name the SMILES column 'nonStereoSMILES'; rename
# is a no-op when that column is absent.
df = pd.read_csv("curated_GS_LF_merged_4983.csv").rename(
    columns={'nonStereoSMILES': "IsomericSMILES"}
)

# Every column from the third one onward is passed on as a descriptor column.
descriptor_list = df.columns[2:].to_list()

# Filters run in this order; each takes a frame and returns the filtered frame.
cleaning_steps = (
    remove_nan_smiles,
    remove_undesirable_elements,
    remove_duplicate_smiles,
    remove_salts_and_charges_and_multimolecule,
    lambda frame: remove_molecular_weights(df=frame, column_list=descriptor_list, lower=20, higher=600),
    remove_non_carbon_molecules,
)

df_filtered = df
for step in cleaning_steps:
    df_filtered = step(df_filtered)

print("--- DATASET CLEANING COMPLETE ---")
print(f"\tFinal size of dataframe: {len(df_filtered)}")

df_filtered.to_csv("gs-lf_combined.csv", index=False)

Dataset cleaning report for gs-lf
--- REMOVING NAN SMILES ---
	Original size before removing NaN SMILES: 4983
	Number of NaN SMILES: 0
	Size after removing NaN SMILES: 4983
--- REMOVING UNDESIRABLE ELEMENTS ---
	Original size before removing undesirable elements: 4983
	Purging undesirable elements in ['He', 'Na', 'Mg', 'Al', 'Si', 'K', 'Ca', 'Ti', 'V', 'Cr', 'Fe', 'Co', 'Cu', 'Zn', 'Bi']
	Match found: SMILES [Cl-].[K+] contains K
	Match found: SMILES O=C([O-])CN(CCN(CC(=O)[O-])CC(=O)[O-])CC(=O)[O-].[Ca+2].[Na+].[Na+] contains Na
	Match found: SMILES O=C([O-])CC(O)(CC(=O)[O-])C(=O)[O-].[Na+].[Na+].[Na+] contains Na
	Match found: SMILES CC(C)(CO)C(O)C(=O)NCCC(=O)[O-].CC(C)(CO)C(O)C(=O)NCCC(=O)[O-].[Ca+2] contains Ca
	Match found: SMILES O=C([O-])CC(O)C(=O)[O-].[Na+].[Na+] contains Na
	Match found: SMILES O=C([O-])CN(CCN(CC(=O)[O-])CC(=O)O)CC(=O)O.[Na+].[Na+] contains Na
	Match found: SMILES O=C([O-])CCC(=O)[O-].[Na+].[Na+] contains Na
	Match found: SMILES O=C([O-])C(O)C(O)C(O)C(O)CO.O=C(