In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import inchi
from tqdm import tqdm
from time import sleep
from tqdm.notebook import tqdm

from chembl_structure_pipeline import standardizer as ChEMBL_standardizer
from papyrus_structure_pipeline import standardize



In [None]:
#Load the raw Cornelissen et al (2021) data
# The master file is tab-separated, so the field separator is given explicitly.
df_raw = pd.read_csv('Cornelissen_master_file.tsv', sep='\t')

In [None]:
#Change column names
# Each 'Status_<Transport>' column becomes 'status_<transport>' (lower-case),
# which is the naming the per-transport helpers build with f'status_{t}'.
df_raw.rename(
    columns={
        f'Status_{name}': f'status_{name.lower()}'
        for name in ('Influx', 'Efflux', 'PAMPA', 'BBB')
    },
    inplace=True,
)

In [None]:
#Papyrus Standardization

def create_sd_smiles(sd_mol):
    """Return the SMILES string RDKit writes for ``sd_mol``.

    Any exception raised during the conversion is reported on stdout and
    ``None`` is returned instead of propagating the error.
    """
    try:
        smiles_out = Chem.MolToSmiles(sd_mol)
    except Exception as e:
        # Best-effort conversion: report the problem and signal failure with None.
        print(f"An sd_smiles error occurred: {str(e)}")
        return None
    else:
        return smiles_out
    
#Create InChI keys from standardized molecules
def mol_to_inchi_key(sd_mol):
    """Return the InChIKey for ``sd_mol``, or ``None`` when ``sd_mol`` is ``None``.

    The molecule is first converted to an InChI string, which is then
    converted to its InChIKey.
    """
    if sd_mol is None:
        return None
    return inchi.InchiToInchiKey(inchi.MolToInchi(sd_mol))

def standardize_molecule(mol):
    """Run the Papyrus ``standardize`` routine on ``mol`` with ``raise_error=False``.

    Returns whatever ``standardize`` returns for this molecule.
    """
    return standardize(mol, raise_error=False)

#Standardize 

def standardize_workflow(df_raw):
    """Standardize every ``ParentSmiles`` entry and store SMILES / InChIKey columns.

    Each SMILES is parsed with RDKit, passed through ``standardize_molecule``
    and converted with ``create_sd_smiles`` and ``mol_to_inchi_key``.  A
    missing (non-string) or unparsable SMILES skips the standardization step,
    and the two helpers are then called with ``None``.

    Parameters
    ----------
    df_raw : pandas.DataFrame
        Frame with a ``ParentSmiles`` column.  It is modified in place: the
        columns ``papyrus_SMILES`` and ``papyrus_inchi_key`` are added (or
        overwritten).

    Returns
    -------
    pandas.DataFrame
        The same ``df_raw`` object, after it has also been written to
        ``datasets/cornelissen_all_papyrus_standardized.csv``.
    """
    sd_smiles_values = []
    sd_inchi_key_values = []

    # Iterate over the column values instead of using range(len(df)) as index
    # labels, so the function does not depend on the index being 0..n-1.
    for smiles in df_raw['ParentSmiles']:
        # A missing SMILES (NaN) is not a string and must not be handed to RDKit.
        mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
        # MolFromSmiles gives None for unparsable input; nothing to standardize then.
        sd_mol = standardize_molecule(mol) if mol is not None else None
        sd_smiles_values.append(create_sd_smiles(sd_mol))
        sd_inchi_key_values.append(mol_to_inchi_key(sd_mol))

    # Assign each column once rather than enlarging the frame cell by cell.
    df_raw['papyrus_SMILES'] = sd_smiles_values
    df_raw['papyrus_inchi_key'] = sd_inchi_key_values

    print(f'df length after standardization: {len(df_raw)}')

    #Save the dataset with standardized information
    df_raw.to_csv('datasets/cornelissen_all_papyrus_standardized.csv', index=False)

    return df_raw

In [None]:
#Check for missing inchi key

def missing_inchi(df_raw):
    """Report missing standardized identifiers and drop rows without an InChIKey.

    Prints the frame length together with the number of missing values in
    ``papyrus_SMILES`` and ``papyrus_inchi_key``, then the length after
    removing the rows whose ``papyrus_inchi_key`` is missing.

    Parameters
    ----------
    df_raw : pandas.DataFrame
        Frame containing ``papyrus_SMILES`` and ``papyrus_inchi_key`` columns.
        It is not modified.

    Returns
    -------
    pandas.DataFrame
        An independent copy holding only the rows with a non-missing
        ``papyrus_inchi_key``; the original index labels are kept.
    """
    smiles_nan = df_raw['papyrus_SMILES'].isna().sum()
    inchikey_nan = df_raw['papyrus_inchi_key'].isna().sum()
    print(f'DB length: {len(df_raw)},        SMILES nan: {smiles_nan},        inchi key nan: {inchikey_nan}')

    #Remove rows with missing inchikey
    # Return an explicit copy: callers add columns to the result, and doing so
    # on a boolean-mask slice raises pandas' SettingWithCopyWarning.
    df_valid_inchi = df_raw[df_raw['papyrus_inchi_key'].notna()].copy()
    print('-----remove missing inchikey----')
    print(f'updated length: {len(df_valid_inchi)}')

    return df_valid_inchi

In [None]:
#Create connectivity inchi column

def inchi_first_part(inchi):
    """Return the part of the string ``inchi`` that precedes its first ``'-'``.

    The whole string is returned when it contains no ``'-'``.  Inside this
    function the parameter name hides the imported ``inchi`` module.
    """
    head, _, _ = inchi.partition('-')
    return head

def create_connectivity_inchi(df):
    """Add an ``inchi_connectivity`` column derived from ``papyrus_inchi_key``.

    Every key is passed through ``inchi_first_part``.  The column is written
    onto ``df`` itself, and that same frame is returned.
    """
    connectivity_blocks = df['papyrus_inchi_key'].apply(inchi_first_part)
    df['inchi_connectivity'] = connectivity_blocks
    return df

In [None]:
#Filter dataset for transport information

def filter_transport(df,t):
    """Keep only the rows that have a value in the ``status_{t}`` column.

    Returns a new frame with a fresh 0..n-1 index; ``df`` is left untouched.
    """
    return df.dropna(subset=[f'status_{t}']).reset_index(drop=True)

In [None]:
#Check for duplicates

def remove_duplicates(df,t):
    """Drop contradicting and redundant duplicates for transport ``t``.

    Rows are grouped by ``inchi_connectivity``.  Groups whose ``status_{t}``
    column holds more than one distinct value are removed entirely; of every
    remaining group only the first row is kept.  Row and unique-key counts
    are printed before and after the contradicting groups are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``inchi_connectivity`` and ``status_{t}`` columns.
    t : str
        Transport name used to build the status column and the file name.

    Returns
    -------
    pandas.DataFrame
        The de-duplicated frame with a fresh 0..n-1 index.  It is also
        written to ``datasets/cornelissen_{t}_no_duplicates.csv``.
    """
    status_col = f"status_{t}"

    print(f'length: {len(df)}')
    inchi_un = df['inchi_connectivity'].nunique()
    print(f'unique_inchi: {inchi_un}')

    # Number of distinct status values seen for each connectivity key.
    unique_counts = df.groupby('inchi_connectivity')[status_col].nunique()
    duplicates_diff_class = unique_counts[unique_counts > 1].index

    print(f'Contradicting duplicates: {len(duplicates_diff_class)}')

    #Remove duplicates
    df = df[~df['inchi_connectivity'].isin(duplicates_diff_class)]
    print(len(df))
    print(df['inchi_connectivity'].nunique())

    df = df.drop_duplicates(subset=['inchi_connectivity'], keep="first").reset_index(drop=True)

    # The path must be an f-string; otherwise every transport is written to
    # the same literally named 'cornelissen_{t}_no_duplicates.csv' file.
    df.to_csv(f'datasets/cornelissen_{t}_no_duplicates.csv', index=False)

    return df

In [None]:
#Run all the preprocessing to create dataset per transport
df = df_raw

transports =["influx",'efflux','pampa','bbb']

df_sd = standardize_workflow(df)        #Standardize molecules
df_valid = missing_inchi(df_sd)         #Check for missing inchi key
df_connectivity_inchi = create_connectivity_inchi(df_valid) #Create connectivity inchi column

#Create separate datasets per transport mechanism
for t in transports:
    df_transport_filter = filter_transport(df_connectivity_inchi,t)     #filter for transport mechanism

    # De-duplicate the transport-filtered rows; passing the unfiltered frame
    # here would keep rows that have no status for this transport.
    df_no_duplicates = remove_duplicates(df_transport_filter,t)         #remove duplicates
    file_name = f'datasets/cornelissen_{t}.csv'                         #provide file name
    df_no_duplicates.to_csv(file_name)                              #save datasets into csv

Test-train split

In [None]:
#Rename column

def rename_col(df):
    """Rename the ``TestValTrain_<Transport>`` columns to ``TrainTestVal_<transport>``.

    The rename is applied to ``df`` in place and the same object is returned.
    Columns that are not present are simply left out of the rename.
    """
    split_column_map = {
        f'TestValTrain_{name}': f'TrainTestVal_{name.lower()}'
        for name in ('Influx', 'Efflux', 'PAMPA', 'BBB')
    }
    df.rename(columns=split_column_map, inplace=True)
    return df



In [None]:
#Filtered dataset: to keep only Train, or only Test/Val compounds

def filter_dataset(df,dataset,transport):
    """Select the rows of ``df`` that belong to one split of one transport.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing a ``TrainTestVal_{transport}`` column.
    dataset : str
        ``"train"`` keeps rows labelled ``'Train'``; ``"testval"`` keeps rows
        labelled ``'Test'`` or ``'Val'``.
    transport : str
        Transport name used to build the split column name.

    Returns
    -------
    pandas.DataFrame
        The matching rows, with their original index labels.

    Raises
    ------
    ValueError
        If ``dataset`` is neither ``"train"`` nor ``"testval"``.
    """
    filter_col = f"TrainTestVal_{transport}"

    if dataset == "train":
        return df[df[filter_col] == 'Train']
    if dataset == "testval":
        return df[df[filter_col].isin(['Test', 'Val'])]

    # Without this, an unknown name reached the return statement with no
    # result assigned and failed with an UnboundLocalError.
    raise ValueError(f"dataset must be 'train' or 'testval', got {dataset!r}")

In [None]:
#Run the train-test split for all transport

transports = ['influx','efflux','pampa','bbb']
datasets = ['testval',"train"]

for t in transports:
    # Load this transport's own de-duplicated dataset, as saved by the
    # preprocessing loop (datasets/cornelissen_{t}.csv, index in column 0).
    # Reusing the frame left over from that loop would apply the last
    # transport's data to every transport.
    df = rename_col(pd.read_csv(f'datasets/cornelissen_{t}.csv', index_col=0))  #Rename columns
    stat_col = f'status_{t}'

    for dataset in datasets:
        df_filtered = filter_dataset(df,dataset,t)  #Filter for transport, and dataset type

        save_file = f'datasets/cornelissen_{t}_{dataset}_raw.csv' #Define raw file name
        df_filtered.to_csv(save_file,index=True)        #Save raw dataset into csv

        print(f'{t} - {dataset}: {len(df_filtered)}')       #print info for process check

        #Check that only rows with such transport information are present
        nan_values = df_filtered[stat_col].isna().any()

        if nan_values:
            print(f'{stat_col} in {dataset} dataset contains NaN values.')
        else:
            print(f'{stat_col} in {dataset} dataset contains NO NaN values.')