# Install RDKit via pip; it provides the Chem, Lipinski and rdMolDescriptors
# modules imported further down.

# NOTE: the leading "!" is an IPython/Jupyter shell escape, so this line only
# runs inside a notebook session and is not valid in a plain Python script.
!pip install rdkit

# imports

In [None]:
import os

import mlx.core as mx
import scipy as sp
import numpy as np
import pandas as pd
from mlx_graphs.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.datasets.utils import download
from mlx_graphs.utils.transformations import to_sparse_adjacency_matrix
from typing import Tuple
from typing import Optional

from rdkit import Chem
from rdkit.Chem import Lipinski
from rdkit.Chem import rdMolDescriptors


# get the ESOL source file

In [None]:
# Remote location of the Delaney (ESOL) CSV and the local file it is saved to.
esol_url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv'
esol_csv = "ESOL.csv"
download(esol_url, path=esol_csv)

# Retrieve only the SMILES and target LogS columns

In [None]:
# Map the two columns we keep (source name -> short name). The dict order is
# also the resulting column order: SMILES first, then the measured solubility.
kept_columns = {
    'smiles': 'Smiles',
    'measured log solubility in mols per litre': 'LogS',
}
data = pd.read_csv('ESOL.csv')[list(kept_columns)].rename(columns=kept_columns)
data.head()

In [None]:
# 



In [None]:

# Basic per-atom features, computed by RDKit's official C++-backed implementation
def _atomFeatures(atom: Chem.Atom,
                 ) -> list:
    """Return RDKit's built-in feature vector for a single atom.

    Thin wrapper around ``rdMolDescriptors.GetAtomFeatures``, which is keyed
    by (molecule, atom index) rather than by atom object.

    Args:
        atom: An atom that belongs to a molecule; the owning molecule is
            looked up through ``atom.GetOwningMol()``.

    Returns:
        The feature sequence produced by RDKit for this atom. The
        dataset-building loop in this file stores it in a row of width 49.
    """
    owner = atom.GetOwningMol()
    return rdMolDescriptors.GetAtomFeatures(owner, atom.GetIdx())

# Adjacency matrix extraction and formatting
def _compute_adjacency(
    molecule: Chem.Mol,
    dtype: np.dtype = np.int32,
) -> np.ndarray:
    """Compute the adjacency matrix of an RDKit molecule.

    Args:
        molecule: The RDKit molecule whose bonds define the adjacency.
        dtype: NumPy dtype the matrix is cast to before it is returned.

    Returns:
        A single array: the result of ``Chem.GetAdjacencyMatrix`` cast to
        ``dtype``.
    """
    return Chem.GetAdjacencyMatrix(molecule).astype(dtype)

# Get Edge_features
def generate_bond_features(
    mol: Chem.Mol) -> list:
    """Build one feature tuple per bond of ``mol``.

    Args:
        mol: The RDKit molecule whose bonds are featurized.

    Returns:
        A list with one entry per bond, in ``mol.GetBonds()`` order. Each
        entry is ``(stereo_code, bond_type_code, is_conjugated,
        is_rotatable)`` where the first two are the integer codes defined
        below and the last two are booleans. The list is empty for a
        molecule without bonds.

    Raises:
        KeyError: If a bond has a bond type or stereo label that is not in
            the lookup tables below.
    """
    # Lookup tables mapping RDKit enum names to integer codes.
    bond_type_dict = {'SINGLE': 1, 'DOUBLE': 2, 'TRIPLE': 3, 'AROMATIC': 4}
    bond_stereo_dict = {'STEREONONE': 0,'STEREOANY': 1, 'STEREOE': 2, 'STEREOZ': 3}

    # Atom-index pairs of the rotatable bonds. Each match is normalised to
    # (low, high) so it compares equal to the sorted pair built per bond
    # below whatever order the match was reported in; a set makes the
    # per-bond membership test O(1).
    rotatable_pairs = {
        tuple(sorted(match)) for match in Lipinski._RotatableBonds(mol)
    }

    bond_features = []
    for bond in mol.GetBonds():
        atom_pair = tuple(sorted((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())))
        bond_features.append((
            bond_stereo_dict[bond.GetStereo().name],
            bond_type_dict[bond.GetBondType().name],
            bond.GetIsConjugated(),
            atom_pair in rotatable_pairs,
        ))

    return bond_features

In [None]:
# generate the dataset

In [None]:
# Build one GraphData object per row of `data`; rows that cannot be converted
# are printed and skipped.
dataset = []
for _, row in data.iterrows():
    try:
        # Get an RDKit molecule object from the SMILES string
        mol = Chem.MolFromSmiles(row.Smiles.strip())
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None
            raise ValueError(f"RDKit could not parse SMILES {row.Smiles!r}")

        # Edge index from the adjacency matrix. The adjacency values returned
        # alongside it are discarded: bond features are used instead.
        edge_index, _ = to_sparse_adjacency_matrix(mx.array(_compute_adjacency(mol)))

        # Node features: one row of 49 values per atom
        atf = np.zeros((mol.GetNumAtoms(), 49))
        for atom_idx, atom in enumerate(mol.GetAtoms()):
            atf[atom_idx, :] = _atomFeatures(atom)
        node_features = mx.array(atf)

        # Edge features. generate_bond_features yields one tuple per bond in
        # mol.GetBonds() order, whereas the adjacency matrix is symmetric and
        # so edge_index lists every bond once per direction. Index the bond
        # features by both (begin, end) and (end, begin), then emit them in
        # edge_index order so row k describes edge k.
        # Assumes edge_index has shape (2, num_edges), the mlx-graphs convention.
        features_by_pair = {}
        for bond, bond_feats in zip(mol.GetBonds(), generate_bond_features(mol)):
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            features_by_pair[(begin, end)] = bond_feats
            features_by_pair[(end, begin)] = bond_feats
        sources, targets = edge_index.tolist()
        edge_features = mx.array(
            [features_by_pair[(int(s), int(t))] for s, t in zip(sources, targets)]
        )

        # Get the target: "LogS"
        label = mx.array(row.LogS)

        # NOTE(review): LogS is one value per molecule, so `graph_labels`
        # looks like the intended field; `node_labels` is kept because cells
        # outside this view may read it - confirm before changing.
        dataset.append(
            GraphData(
                edge_index=edge_index,
                node_features=node_features,
                edge_features=edge_features,
                node_labels=label,
            )
        )
    except Exception as exc:
        # Best effort: report the row and carry on. The single-atom molecule
        # "C" has no bonds and is expected to end up here. `Exception` (not a
        # bare except) lets KeyboardInterrupt still stop the loop.
        print(row)
        print(f"Skipped: {exc!r}")

In [None]:
dataset[0]

In [None]:
from mlx_graphs.loaders import Dataloader

BATCH_SIZE = 64

# The first 150 graphs form the training split; everything after is the test split.
train_dataset, test_dataset = dataset[:150], dataset[150:]

# Both loaders iterate in dataset order (no shuffling).
train_loader = Dataloader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
test_loader = Dataloader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Print every training batch together with its batch indices.
for batch in train_loader:
    print(f"\nGraph batch of size {len(batch)}")
    print(batch)
    print(batch.batch_indices)

In [None]:
# Some useful properties
# `dataset` is a plain Python list of GraphData objects, so it has no
# `num_*_features` attributes; the sizes are read from the first graph instead.
from collections import defaultdict


def _feature_width(array):
    """Return the size of the last axis of *array*, or 0 when it is None."""
    return 0 if array is None else array.shape[-1]


first_graph = dataset[0] if dataset else None

print("Dataset attributes")
print("-" * 20)
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {_feature_width(getattr(first_graph, 'node_features', None))}")
print(f"Number of edge features: {_feature_width(getattr(first_graph, 'edge_features', None))}")
print(f"Number of graph features: {_feature_width(getattr(first_graph, 'graph_features', None))}")
# LogS is a continuous value, so there is no class count to report.
print("Number of graph classes to predict: n/a (LogS is a regression target)\n")

# Statistics of the dataset
stats = defaultdict(list)
for g in dataset:
    stats["Mean node degree"].append(g.num_edges / g.num_nodes)
    stats["Mean num of nodes"].append(g.num_nodes)
    stats["Mean num of edges"].append(g.num_edges)

print("Dataset stats")
print("-" * 20)
# `stats` stays empty for an empty dataset, so no mean is taken over nothing.
for k, v in stats.items():
    mean = mx.mean(mx.array(v)).item()
    print(f"{k}: {mean:.2f}")