# 02 Graph convolutional network


## Data and processing

TODO:

In [1]:
import pyarrow.dataset as ds
import pyarrow.compute as pc

# Relative path of the Parquet file holding the training set.
path_train_data = "../../../data/train.parquet"

# Open the training file as a pyarrow dataset (source passed positionally).
data = ds.dataset(path_train_data, format="parquet")

In [2]:
import random
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import numpy.typing as npt

In [3]:
class DatasetManager:
    """Serve (SMILES, protein sequence, label) samples from a Parquet dataset.

    The code reads the columns ``molecule_smiles``, ``protein_name`` and
    ``binds`` (compared against 0 and 1).  Rows are reached through two
    filtered scanners, one per value of ``binds``; every row index handled by
    this class is a position *within* one of those filtered views.
    """

    # Amino-acid sequence for each value of the `protein_name` column.
    protein_seq = {
        "sEH": "TLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGATTRLMKGEITLSQWIPLMEENCRKCSETAKVCLPKNFSIKEIFDKAISARKINRPMLQAALMLRKKGFTTAILTNTWLDDRAERDGLAQLMCELKMHFDFLIESCQVGMVKPEPQIYKFLLDTLKASPSEVVFLDDIGANLKPARDLGMVTILVQDTDTALKELEKVTGIQLLNTPAPLPTSCNPSDMSHGYVTVKPRVRLHFVELGSGPAVCLCHGFPESWYSWRYQIPALAQAGYRVLAMDMKGYGESSAPPEIEEYCMEVLCKEMVTFLDKLGLSQAVFIGHDWGGMLVWYMALFYPERVRAVASLNTPFIPANPNMSPLESIKANPVFDYQLYFQEPGVAEAELEQNLSRTFKSLFRASDESVLSMHKVCEAGGLFVNSPEEPSLSRMVTEEEIQFYVQQFKKSGFRGPLNWYRNMERNWKWACKSLGRKILIPALMVTAEKDFVLVPQMSQHMEDWIPHLKRGHIEDCGHWTQMDKPTEVNQILIKWLDSDARNPPVVSKM",
        "BRD4": "NPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPDYYKIIKTPMDMGTIKKRLENNYYWNAQECIQDFNTMFTNCYIYNKPGDDIVLMAEALEKLFLQKINELPTEETEIMIVQAKGRGRGRKETGTAKPGVSTVPNTTQASTPPQTQTPQPNPPPVQATPHPFPAVTPDLIVQTPVMTVVPPQPLQTPPPVPPQPQPPPAPAPQPVQSHPPIIAATPQPVKTKKGVKRKADTTTPTTIDPIHEPPSLPPEPKTTKLGQRRESSRPVKPPKKDVPDSQQHPAPEKSSKVSEQLKCCSGILKEMFAKKHAAYAWPFYKPVDVEALGLHDYCDIIKHPMDMSTIKSKLEAREYRDAQEFGADVRLMFSNCYKYNPPDHEVVAMARKLQDVFEMRFAKMPDE",
        "HSA": "DAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTESLVNRRPCFSALEVDETYVPKEFNAETFTFHADICTLSEKERQIKKQTALVELVKHKPKATKEQLKAVMDDFAAFVEKCCKADDKETCFAEEGKKLVAASQAALGL",
    }

    def __init__(
        self, file_path: str, train_split: float = 0.8, *args, **kwargs
    ) -> None:
        """Open the dataset, split the row indices and build the scanners.

        Args:
            file_path: Source handed to ``pyarrow.dataset.dataset``.
            train_split: Fraction of each class assigned to the training set.
            *args: Extra positional arguments for ``pyarrow.dataset.dataset``.
            **kwargs: Extra keyword arguments for ``pyarrow.dataset.dataset``.
        """
        # The source is passed positionally: `source=file_path` combined with
        # `*args` fails with "multiple values for argument 'source'" as soon
        # as one extra positional argument is forwarded.
        self.data = ds.dataset(file_path, *args, **kwargs)
        self.train_split = train_split
        self.set_indices(train_split=train_split)
        # Create the scanners up front so indexing works right after
        # construction instead of failing on a missing attribute.
        self.split_data(train_split=train_split)

    def set_indices(
        self,
        n_no_bind: int = 293656924,
        n_bind: int = 1589906,
        train_split: float = 0.8,
    ) -> None:
        """Create shuffled train/validation index arrays for both classes.

        Args:
            n_no_bind: Number of rows with ``binds == 0``.  NOTE(review): the
                default is hard-coded, presumably for the training file used
                in this notebook -- confirm before reusing on other data.
            n_bind: Number of rows with ``binds == 1`` (same caveat).
            train_split: Fraction of each class assigned to the training set.
        """
        self.train_indices_no_bind, self.valid_indices_no_bind = self._get_indices(
            n_no_bind, train_split
        )
        self.train_indices_bind, self.valid_indices_bind = self._get_indices(
            n_bind, train_split
        )

    @staticmethod
    def _get_indices(
        n_rows: int, train_split: float = 0.8
    ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]:
        """Shuffle ``range(n_rows)`` and cut it into train and validation parts.

        Shuffling uses NumPy's global random state, so ``np.random.seed``
        makes the split reproducible.

        Args:
            n_rows: Number of indices to generate.
            train_split: Fraction of indices placed in the training part.

        Returns:
            ``(train_indices, val_indices)``; the training part holds
            ``int(n_rows * train_split)`` indices, the rest go to validation.

        Raises:
            ValueError: If ``n_rows`` is negative or ``train_split`` lies
                outside ``[0, 1]``.
        """
        if n_rows < 0:
            raise ValueError("n_rows must be non-negative")
        if not 0.0 <= train_split <= 1.0:
            raise ValueError("train_split must be between 0 and 1")

        # Explicit dtype keeps the result identical across platforms.
        indices = np.arange(n_rows, dtype=np.int64)
        np.random.shuffle(indices)

        train_size = int(n_rows * train_split)
        return indices[:train_size], indices[train_size:]

    def split_data(self, train_split: float = 0.8) -> None:
        """Create one filtered scanner per class (``binds == 0`` / ``== 1``).

        Args:
            train_split: Unused; kept so existing calls keep working.
        """
        self.scanner_no_bind = self.data.scanner(
            filter=(pc.field("binds") == 0)
        )
        self.scanner_bind = self.data.scanner(
            filter=(pc.field("binds") == 1)
        )

    def get_protein_seq(self, key: str) -> str:
        """Return the amino-acid sequence stored for protein ``key``.

        Raises:
            KeyError: If ``key`` is not in ``protein_seq``.
        """
        return self.protein_seq[key]

    @staticmethod
    def clean_smiles(smiles: str) -> str:
        """Return the canonical SMILES of the largest fragment of ``smiles``.

        Every ``[Dy]`` token is stripped from the string before parsing.

        Raises:
            ValueError: If the stripped string cannot be parsed by RDKit.
        """
        smiles = smiles.replace("[Dy]", "")
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES string")
        mol = Chem.RemoveHs(mol)
        fragments = Chem.GetMolFrags(mol, asMols=True)
        # `default` covers a molecule without any fragment (no atoms).
        largest_fragment = max(
            fragments, default=mol, key=lambda m: m.GetNumAtoms()
        )
        # No coordinates are computed here: the SMILES string is derived
        # from the molecular graph only.
        return Chem.MolToSmiles(largest_fragment, canonical=True)

    @staticmethod
    def get_mol(smiles: str) -> "Chem.rdchem.Mol":
        """Build a molecule with explicit hydrogens and an optimized conformer.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            The molecule after hydrogen addition, embedding and up to 200
            MMFF optimization iterations.

        Raises:
            ValueError: If ``smiles`` cannot be parsed or no conformer could
                be embedded.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES string")
        mol = Chem.AddHs(mol)
        # EmbedMolecule reports failure with a conformer id of -1; optimizing
        # a molecule without conformers would only fail later and less clearly.
        if AllChem.EmbedMolecule(mol) == -1:
            raise ValueError("Could not embed a 3D conformer")
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        return mol

    def __getitem__(self, idx: int, kind: str = "bind") -> tuple[str, str, int]:
        """Return the cleaned sample at position ``idx`` of one class.

        Args:
            idx: Row position inside the filtered view selected by ``kind``.
            kind: ``"bind"`` for rows with ``binds == 1``, ``"no-bind"`` for
                rows with ``binds == 0``.

        Returns:
            ``(smiles, amino_acids, label)``.

        Raises:
            ValueError: If ``kind`` is invalid or the SMILES cannot be parsed.
            KeyError: If the protein name has no entry in ``protein_seq``.
        """
        if kind not in ("bind", "no-bind"):
            raise ValueError("kind must be `bind` or `no-bind`")

        scanner = self.scanner_bind if kind == "bind" else self.scanner_no_bind
        # Scanners cannot be subscripted; `take` selects rows by position and
        # returns a one-row table.  `int()` unwraps NumPy integer indices.
        row = scanner.take([int(idx)])
        smiles: str = row["molecule_smiles"][0].as_py()
        protein_name: str = row["protein_name"][0].as_py()
        label: int = row["binds"][0].as_py()

        smiles = self.clean_smiles(smiles)
        amino_acids = self.get_protein_seq(protein_name)
        return smiles, amino_acids, label

    def get_contrastive_pair(
        self, idx: int
    ) -> tuple[tuple[str, str, int], tuple[str, str, int]]:
        """Draw one random non-binding and one random binding training sample.

        Args:
            idx: Unused; both samples are drawn at random from the training
                indices.

        Returns:
            ``(no_bind_sample, bind_sample)``.
        """
        no_bind_idx = random.choice(self.train_indices_no_bind)
        bind_idx = random.choice(self.train_indices_bind)
        no_bind_sample = self.__getitem__(no_bind_idx, kind="no-bind")
        bind_sample = self.__getitem__(bind_idx, kind="bind")
        return no_bind_sample, bind_sample

## Model

TODO:

In [4]:
import torch
import torch.nn as nn

# # See https://github.com/pytorch/pytorch/pull/122616#issuecomment-2100569173
# torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = torch.utils._import_utils.dill_available()

import dgl
from dgl.nn.pytorch import GraphConv

In [5]:
# Define the graph neural network
class GCN(nn.Module):
    """Two-layer graph convolutional network with a mean-pooling readout.

    Node features go through two ``GraphConv`` layers with a ReLU in between,
    are averaged with ``dgl.mean_nodes`` and fed to a linear classifier.

    Args:
        in_feats: Size of the input node features.
        h_feats: Size of the hidden node features.
        num_classes: Size of the classifier output.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)
        self.classify = nn.Linear(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Compute the classifier output for graph ``g``.

        Args:
            g: DGL graph (presumably a batched graph -- confirm with caller).
            in_feat: Input node features, one row per node of ``g``.

        Returns:
            Output of the linear classifier applied to the mean-pooled
            node representations.
        """
        h = self.conv1(g, in_feat)
        h = torch.relu(h)
        h = self.conv2(g, h)
        # Store the hidden state inside a local scope so the temporary 'h'
        # feature does not stay on (or overwrite a feature of) the caller's
        # graph once the forward pass returns.
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)