# 01 Example

Here is a starter Jupyter notebook to get started.


In [6]:
import pyarrow.dataset as ds
from loguru import logger
path_train_data = "../../../data/train.parquet"
data_train = ds.dataset(source=path_train_data, format="parquet")

data_train.head(num_rows=5)


pyarrow.Table
id: int64
buildingblock1_smiles: string
buildingblock2_smiles: string
buildingblock3_smiles: string
molecule_smiles: string
protein_name: string
binds: int64
----
id: [[0,1,2,3,4]]
buildingblock1_smiles: [["C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21","C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21","C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21","C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21","C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21"]]
buildingblock2_smiles: [["C#CCOc1ccc(CN)cc1.Cl","C#CCOc1ccc(CN)cc1.Cl","C#CCOc1ccc(CN)cc1.Cl","C#CCOc1ccc(CN)cc1.Cl","C#CCOc1ccc(CN)cc1.Cl"]]
buildingblock3_smiles: [["Br.Br.NCC1CCCN1c1cccnn1","Br.Br.NCC1CCCN1c1cccnn1","Br.Br.NCC1CCCN1c1cccnn1","Br.NCc1cccc(Br)n1","Br.NCc1cccc(Br)n1"]]
molecule_smiles: [["C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1","C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1","C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1","C#

In [7]:
#filter using scanner to get the rows where the protein_name is BRD4
data_brd4 = data_train.scanner(filter=ds.field("protein_name") == "BRD4")

data_brd4=data_brd4.to_table().to_pandas()
#find the number of rows in the dataset
logger.info(f"Number of compounds targeting BRD4 protein {len(data_brd4)}")
#find the rows where the binds is 1
data_brd4_binds = data_brd4[data_brd4["binds"]==1]
logger.info(f"Number of compounds targeting BRD4 protein and binds is 1 {len(data_brd4_binds)}")


[32m2024-06-11 09:55:32.130[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mNumber of compounds targeting BRD4 protein 98415610[0m
[32m2024-06-11 09:55:32.250[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mNumber of compounds targeting BRD4 protein and binds is 1 456964[0m


In [8]:
#filter using scanner to get the rows where the protein_name is protein_name: HSA

data_hsa = data_train.scanner(filter=ds.field("protein_name") == "HSA")

data_hsa=data_hsa.to_table().to_pandas()
#find the number of rows in the dataset
logger.info(f"Number of compounds targeting HSA protein {len(data_hsa)}")
#find the rows where the binds is 1
data_hsa_binds = data_hsa[data_hsa["binds"]==1]
logger.info(f"Number of compounds targeting HSA protein and binds is 1 {len(data_hsa_binds)}")

[32m2024-06-11 09:56:27.576[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mNumber of compounds targeting HSA protein 98415610[0m
[32m2024-06-11 09:56:27.770[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m10[0m - [1mNumber of compounds targeting HSA protein and binds is 1 408410[0m


In [9]:
data_seh = data_train.scanner(filter=ds.field("protein_name") == "sEH")

data_seh=data_seh.to_table().to_pandas()
#find the number of rows in the dataset
logger.info(f"Number of compounds targeting sEH protein {len(data_seh)}")
#find the rows where the binds is 1
data_seh_binds = data_seh[data_seh["binds"]==1]
logger.info(f"Number of compounds targeting sEH protein and binds is 1 {len(data_seh_binds)}")

[32m2024-06-11 09:57:24.692[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mNumber of compounds targeting sEH protein 98415610[0m
[32m2024-06-11 09:57:24.905[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mNumber of compounds targeting sEH protein and binds is 1 724532[0m


In [20]:
import dgl
import torch
from rdkit import Chem
from rdkit.Chem import AllChem


def clean_smi(smi: str | list):
    r""" Clean a SMILES string by removing salts and fragments.
    Parameters
    ----------
    smi : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    str | list
        The cleaned SMILES string.
    """
    if isinstance(smi, list):
        return [clean_smi(s) for s in smi]
    # Remove [Dy] from smiles
    smi = smi.replace("[Dy]", "")

    # Convert SMILES to a RDKit molecule object
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError("Invalid SMILES string")
    
    # Remove any salts or fragments
    mol = Chem.RemoveHs(mol)  # Remove explicit hydrogens
    fragments = Chem.GetMolFrags(mol, asMols=True)
    
    # Keep the largest fragment
    largest_fragment = max(fragments, default=mol, key=lambda m: m.GetNumAtoms())
    
    # Standardize the molecule
    AllChem.Compute2DCoords(largest_fragment)  # Compute 2D coordinates
    
    # Convert the molecule back to a canonical SMILES string
    cleaned_smiles = Chem.MolToSmiles(largest_fragment, canonical=True)
    return cleaned_smiles


def smiles_to_dgl_graph(smiles: str |list):
    r""" Convert a SMILES string to a DGLGraph.
    Parameters
    ----------
    smiles : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    DGLGraph
        A DGLGraph object for the molecule.
    """
    if isinstance(smiles, list):
        return [smiles_to_dgl_graph(s) for s in smiles]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Node features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization(),
            atom.GetIsAromatic(),
            atom.GetTotalNumHs()
        ])
    
    # Edge features and adjacency list
    src, dst = [], []
    bond_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src.append(start)
        dst.append(end)
        bond_features.append([
            bond.GetBondType(),
            bond.GetIsConjugated(),
            bond.IsInRing()
        ])
    
    g = dgl.graph((src, dst))
    g.ndata['h'] = torch.tensor(atom_features, dtype=torch.float)
    g.edata['h'] = torch.tensor(bond_features, dtype=torch.float)
    
    return g

# Placeholder example NOT TESTED
smiles = data_brd4['molecule_smiles'].tolist()
cleaned_smiles = clean_smi(smiles)
mol_graph = smiles_to_dgl_graph(cleaned_smiles)
mol_graph


KeyboardInterrupt: 