Part 1: Import libraries and load data

In [1]:
# %%
# Import libraries
import pandas as pd
from rdkit import Chem
import os
import dgl
from dgl.data.utils import save_graphs, load_graphs
import torch
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils import mol_to_bigraph

# %%
# Load data: one pickled DataFrame per model task, all under the same root.
# NOTE: pd.read_pickle can execute arbitrary code, so only load trusted files.
DATA_ROOT = '../data_mvi/ml_datasets_set2'
TASK_DIRS = (
    'model1_rna_bin_non_rna_bin',
    'model2_rna_bin_protein_bin',
    'model3_binders_nonbinders',
)
data_model1, data_model2, data_model3 = (
    pd.read_pickle(os.path.join(DATA_ROOT, task_dir, 'data.pkl'))
    for task_dir in TASK_DIRS
)

# %%
# Show shapes and label value counts (space-separated, one print per summary)
loaded_frames = (data_model1, data_model2, data_model3)
print(*(frame.shape for frame in loaded_frames))
print(*(frame['label'].value_counts() for frame in loaded_frames))

# %%
# Show columns
print(data_model1.columns)


(3952, 6) (3952, 6) (3952, 6)
label
0    1976
1    1976
Name: count, dtype: int64 label
0    1976
1    1976
Name: count, dtype: int64 label
1    1976
0    1976
Name: count, dtype: int64
Index(['mol', 'source', 'smiles', 'ecfp6', 'bit_info_map', 'label'], dtype='object')


Part 2: Define utility functions

In [2]:
# %%
# Define utility functions
def check_mol_consistency(data):
    """Print a warning for every row whose stored Mol or SMILES is unusable.

    For each row, the ``mol`` column is checked first: a missing (``None``) or
    zero-atom Mol is reported. Only when the Mol is usable is the ``smiles``
    string re-parsed with RDKit, and an unparseable SMILES is reported.
    Nothing is returned and ``data`` is not modified.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``smiles`` and ``mol`` columns.
    """
    for row_idx, record in data.iterrows():
        smi = record['smiles']
        stored_mol = record['mol']
        if stored_mol is None or stored_mol.GetNumAtoms() == 0:
            print(f"Invalid Mol object at index {row_idx}: {smi}")
        elif Chem.MolFromSmiles(smi) is None:
            print(f"Invalid SMILES string at index {row_idx}: {smi}")

def repair_and_balance_mol_objects(df, random_state=None):
    """Drop rows with a missing or empty ``mol`` and compensate in the other class.

    Every row whose ``mol`` is ``None`` or has zero atoms is removed and
    reported with a print. Removing rows shifts the class counts, so the same
    *net* number of rows is then randomly removed from the other class. For
    example, two invalid label-0 rows and one invalid label-1 row lead to one
    extra label-1 row being dropped. This preserves whatever balance the input
    had; it does not balance an input that was already imbalanced.

    Assumes binary labels 0 and 1 in the ``label`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``smiles``, ``mol`` and ``label`` columns. It is not
        modified; a new frame is returned.
    random_state : int or None, optional
        Passed to ``DataFrame.sample`` when choosing the compensating rows.
        ``None`` (default) keeps the previous unseeded behavior.

    Returns
    -------
    pandas.DataFrame
        A new frame with the invalid rows (and compensating rows) removed.
        The original index labels are kept.
    """
    invalid_idx = []
    removed_per_label = {}

    for idx, row in df.iterrows():
        mol = row['mol']
        if mol is None or mol.GetNumAtoms() == 0:
            print(f"Removing row at index {idx}: {row['smiles']}")
            invalid_idx.append(idx)
            label = row['label']
            removed_per_label[label] = removed_per_label.get(label, 0) + 1

    # Build a new frame rather than dropping in place, so the caller's data is untouched.
    cleaned = df.drop(index=invalid_idx)

    # Net class shift caused by the removals: > 0 means more label-0 rows were lost,
    # so label-1 rows must be trimmed (and vice versa).
    surplus = removed_per_label.get(0, 0) - removed_per_label.get(1, 0)
    if surplus != 0:
        label_to_trim = 1 if surplus > 0 else 0
        pool = cleaned[cleaned['label'] == label_to_trim]
        # Never ask sample() for more rows than exist (it would raise).
        n_trim = min(abs(surplus), len(pool))
        trim_idx = pool.sample(n=n_trim, random_state=random_state).index
        cleaned = cleaned.drop(index=trim_idx)

    return cleaned


Part 3: Repair and balance datasets

In [3]:
# %%
# Repair and balance each dataset in turn (model1, then model2, then model3)
data_model1, data_model2, data_model3 = [
    repair_and_balance_mol_objects(frame)
    for frame in (data_model1, data_model2, data_model3)
]


Removing row at index 2181: 
Removing row at index 3868: 


Part 4: Define data processing functions

In [4]:
# %%
# Define data processing functions
def process_dataset(data, output_folder, random_state=None):
    """Clean, shuffle, filter and featurize one dataset, then save it to disk.

    Steps:
    1. Remove rows with a missing/empty ``mol`` (with class compensation) via
       ``repair_and_balance_mol_objects`` and print a report via
       ``check_mol_consistency``.
    2. Shuffle the rows and reset the index.
    3. Drop every row whose SMILES cannot be parsed, whose stored ``mol`` is
       missing/empty, or whose parsed SMILES or stored ``mol`` has more than one
       fragment. Disconnected rows are also written to
       ``disconnected_mols_df.json`` in ``output_folder``.
    4. Build DGL graphs from the remaining rows and save them to
       ``graphs.bin`` in ``output_folder``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``smiles``, ``mol``, ``source`` and ``label`` columns.
        The caller's frame is not modified by this function itself.
    output_folder : str
        Existing directory that receives ``graphs.bin`` and
        ``disconnected_mols_df.json``.
    random_state : int or None, optional
        Seed for the shuffle. ``None`` (default) keeps the previous unseeded
        behavior.
    """
    data = repair_and_balance_mol_objects(data)
    check_mol_consistency(data)

    # Shuffle into a new frame (no inplace reset) so the caller's DataFrame is left alone.
    data = data.sample(frac=1, random_state=random_state).reset_index(drop=True)

    disconnected_mols = []
    rows_to_drop = []

    for idx, row in data.iterrows():
        smiles = row['smiles']
        mol = row['mol']
        source = row['source']

        mol_from_smiles = Chem.MolFromSmiles(smiles)
        if mol_from_smiles is None:
            # Must be dropped: graphs are built from the SMILES and would fail on it.
            print(f"Invalid SMILES string at index {idx}: {smiles}")
            rows_to_drop.append(idx)
            continue

        if mol is None or mol.GetNumAtoms() == 0:
            print(f"Invalid Mol object at index {idx}: {smiles}")
            rows_to_drop.append(idx)
            continue

        # The graph is built from the SMILES, so its connectivity matters as well as
        # the stored Mol's (previously a '.' in the SMILES was reported but kept).
        smiles_disconnected = len(Chem.GetMolFrags(mol_from_smiles)) > 1
        mol_disconnected = len(Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=False)) > 1

        if smiles_disconnected:
            print(f"Disconnected components in SMILES string at index {idx}: {smiles}, from source {source}")
        if mol_disconnected:
            print(f"Disconnected components in molecule at index {idx}: {smiles}, from source {source}")
        if smiles_disconnected or mol_disconnected:
            disconnected_mols.append(row)
            rows_to_drop.append(idx)

    # Drop after the loop instead of mutating the frame while iterating over it.
    data = data.drop(index=rows_to_drop)

    disconnected_mols_df = pd.DataFrame(disconnected_mols)
    # NOTE(review): these rows still carry the RDKit 'mol' (and fingerprint) objects;
    # confirm pandas serializes them as intended if this file is ever non-empty.
    disconnected_mols_df.to_json(os.path.join(output_folder, 'disconnected_mols_df.json'))

    if disconnected_mols_df.empty:
        print("No disconnected molecules found!")

    graphs, graph_labels = create_graphs_from_dataframe(data)
    save_graphs(os.path.join(output_folder, "graphs.bin"), graphs, graph_labels)

def create_graphs_from_dataframe(df):
    """Build one DGL bigraph per row of ``df`` from its SMILES string.

    Atoms are featurized with ``CanonicalAtomFeaturizer`` and bonds with
    ``CanonicalBondFeaturizer(self_loop=True)``. Graphs are built without
    explicit hydrogens and with self-loops added.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``smiles`` and ``label`` columns.

    Returns
    -------
    graphs : list
        One graph per row, in row order.
    graph_labels : dict
        ``{'labels': tensor}``, where the tensor is the row labels with a
        trailing unit dimension (shape ``(len(df), 1)``). This is the form
        passed to ``save_graphs`` as graph labels.

    Raises
    ------
    ValueError
        If a row's SMILES cannot be parsed by RDKit.
    """
    # Construct the featurizers once and reuse them for every molecule
    # instead of re-instantiating them on each row.
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer(self_loop=True)

    graphs = []
    labels = []

    for idx, row in df.iterrows():
        mol = Chem.MolFromSmiles(row['smiles'])
        if mol is None:
            # Fail with a clear message rather than inside mol_to_bigraph(None).
            raise ValueError(f"Cannot parse SMILES at index {idx}: {row['smiles']}")

        graph = mol_to_bigraph(
            mol,
            node_featurizer=atom_featurizer,
            edge_featurizer=bond_featurizer,
            explicit_hydrogens=False,
            add_self_loop=True
        )

        graphs.append(graph)
        labels.append(row['label'])

    label_tensor = torch.tensor(labels).unsqueeze(-1)
    graph_labels = {'labels': label_tensor}

    return graphs, graph_labels


Part 5: Process datasets and save graphs

In [5]:
# %%
# Map each model/task name to its cleaned DataFrame
# (despite the name, the values are DataFrames, not file paths).
data_paths = {
    'model1_rna_bin_non_rna_bin': data_model1,
    'model2_rna_bin_protein_bin': data_model2,
    'model3_binders_nonbinders': data_model3
}

# %%
# Iterate through each dataset and write its outputs into a per-model directory.
# NOTE(review): inputs are read from '../data_mvi/...', but outputs go to
# 'data_mvi/...' relative to the working directory. Confirm this is intended.
OUTPUT_ROOT = 'data_mvi/data_for_ml/dataset_set2'
for model_name, data in data_paths.items():
    # process_dataset writes graphs.bin and disconnected_mols_df.json *inside* this
    # directory, so the folder itself must not end in 'graphs.bin' (previously it did,
    # producing .../graphs.bin/graphs.bin).
    output_folder = os.path.join(OUTPUT_ROOT, model_name)
    os.makedirs(output_folder, exist_ok=True)

    process_dataset(data, output_folder)


No disconnected molecules found!
No disconnected molecules found!
No disconnected molecules found!
