# 03 Graph convolutional network


## Data and processing

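TODO: describe this section in more detail. In short: load the train and test Parquet files as PyArrow datasets, shuffle and split the binder / non-binder row indices into training and validation sets, and wrap them in a PyTorch `Dataset` that yields (binder, non-binder) molecule pairs.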

In [1]:
import pyarrow.dataset as ds
import pyarrow.compute as pc

# Locations of the train / test parquet files, relative to this notebook.
PATH_TRAIN_DATA = "../../../data/train.parquet"
PATH_TEST_DATA = "../../../data/test.parquet"

# Open both files as PyArrow datasets.
DATA = ds.dataset(PATH_TRAIN_DATA, format="parquet")
DATA_TEST = ds.dataset(PATH_TEST_DATA, format="parquet")

In [2]:
import random
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import numpy.typing as npt

In [3]:
def split_indices(n_rows, train_split: float = 0.8) -> (npt.NDArray[np.uint64], npt.NDArray[np.uint64]):
    # Generate indices and shuffle them in place
    indices = np.arange(n_rows)
    np.random.shuffle(indices)

    # Split indices into training and validation sets
    train_size = int(n_rows * train_split)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
    return train_indices, val_indices

# Hard-coded row counts; presumably the number of rows with binds == 0 and
# binds == 1 in the training data -- verify against the parquet file.
N_NO_BIND = 293_656_924
N_BIND = 1_589_906
TRAIN_SPLIT = 0.8

# Each class is shuffled and split on its own (non-binders first, then binders).
train_indices_no_bind, valid_indices_no_bind = split_indices(N_NO_BIND, TRAIN_SPLIT)
train_indices_bind, valid_indices_bind = split_indices(N_BIND, TRAIN_SPLIT)

In [4]:
import torch

In [5]:
# Pick the device used for training: CUDA if present, otherwise Apple MPS,
# otherwise the CPU.
# NOTE(review): the traceback recorded at the end of this notebook shows DGL
# failing with "Device API cuda is not enabled"; PyTorch reporting CUDA here
# does not mean the installed DGL build supports it.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [6]:
from torch.utils.data import Dataset

In [7]:
class MolDataset(Dataset):
    """Dataset of (binder, non-binder) molecule pairs.

    Each item pairs the binder row selected by ``indices_bind`` with a
    randomly drawn non-binder row from ``indices_no_bind``. Every sample is a
    ``(cleaned_smiles, protein_sequence, label)`` triple.
    """

    # Amino-acid sequence for each protein name found in the data.
    protein_seq = {
        "sEH": "TLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGATTRLMKGEITLSQWIPLMEENCRKCSETAKVCLPKNFSIKEIFDKAISARKINRPMLQAALMLRKKGFTTAILTNTWLDDRAERDGLAQLMCELKMHFDFLIESCQVGMVKPEPQIYKFLLDTLKASPSEVVFLDDIGANLKPARDLGMVTILVQDTDTALKELEKVTGIQLLNTPAPLPTSCNPSDMSHGYVTVKPRVRLHFVELGSGPAVCLCHGFPESWYSWRYQIPALAQAGYRVLAMDMKGYGESSAPPEIEEYCMEVLCKEMVTFLDKLGLSQAVFIGHDWGGMLVWYMALFYPERVRAVASLNTPFIPANPNMSPLESIKANPVFDYQLYFQEPGVAEAELEQNLSRTFKSLFRASDESVLSMHKVCEAGGLFVNSPEEPSLSRMVTEEEIQFYVQQFKKSGFRGPLNWYRNMERNWKWACKSLGRKILIPALMVTAEKDFVLVPQMSQHMEDWIPHLKRGHIEDCGHWTQMDKPTEVNQILIKWLDSDARNPPVVSKM",
        "BRD4": "NPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPDYYKIIKTPMDMGTIKKRLENNYYWNAQECIQDFNTMFTNCYIYNKPGDDIVLMAEALEKLFLQKINELPTEETEIMIVQAKGRGRGRKETGTAKPGVSTVPNTTQASTPPQTQTPQPNPPPVQATPHPFPAVTPDLIVQTPVMTVVPPQPLQTPPPVPPQPQPPPAPAPQPVQSHPPIIAATPQPVKTKKGVKRKADTTTPTTIDPIHEPPSLPPEPKTTKLGQRRESSRPVKPPKKDVPDSQQHPAPEKSSKVSEQLKCCSGILKEMFAKKHAAYAWPFYKPVDVEALGLHDYCDIIKHPMDMSTIKSKLEAREYRDAQEFGADVRLMFSNCYKYNPPDHEVVAMARKLQDVFEMRFAKMPDE",
        "HSA": "DAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTESLVNRRPCFSALEVDETYVPKEFNAETFTFHADICTLSEKERQIKKQTALVELVKHKPKATKEQLKAVMDDFAAFVEKCCKADDKETCFAEEGKKLVAASQAALGL",
    }

    def __init__(
        self,
        scanner_no_bind,
        scanner_bind,
        indices_no_bind,
        indices_bind,
        *args,
        **kwargs
    ):
        """Store the scanners and the row indices this dataset may draw from.

        Args:
            scanner_no_bind: Object whose ``take([row])`` returns the
                non-binder row at that position.
            scanner_bind: Same, for binder rows.
            indices_no_bind: Row positions of the non-binders in this split.
            indices_bind: Row positions of the binders in this split; its
                length is the dataset length.
            *args: Accepted for compatibility and ignored.
            **kwargs: Accepted for compatibility and ignored.
        """
        super().__init__()
        self.scanner_no_bind = scanner_no_bind
        self.scanner_bind = scanner_bind
        self.indices_no_bind = indices_no_bind
        self.indices_bind = indices_bind

    def get_protein_seq(self, key: str) -> str:
        """Return the amino-acid sequence for protein name ``key``.

        Raises:
            KeyError: If ``key`` is not one of the known protein names.
        """
        return self.protein_seq[key]

    @staticmethod
    def clean_smiles(smiles: str) -> str:
        """Return a canonical SMILES of the largest fragment of ``smiles``.

        The ``[Dy]`` token is stripped from the string before parsing.

        Raises:
            ValueError: If the stripped string cannot be parsed.
        """
        smiles = smiles.replace("[Dy]", "")
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES string")
        mol = Chem.RemoveHs(mol)
        fragments = Chem.GetMolFrags(mol, asMols=True)
        largest_fragment = max(fragments, default=mol, key=lambda m: m.GetNumAtoms())
        # NOTE(review): the 2D coordinates do not appear to be needed to write
        # the SMILES below; kept as in the original -- confirm before removing.
        AllChem.Compute2DCoords(largest_fragment)
        cleaned_smiles = Chem.MolToSmiles(largest_fragment, canonical=True)
        return cleaned_smiles

    @staticmethod
    def get_mol(smiles: str) -> "Chem.rdchem.Mol":
        """Build an RDKit molecule with hydrogens and an optimised conformer.

        Raises:
            ValueError: If ``smiles`` cannot be parsed or no conformer could
                be embedded.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES string")
        mol = Chem.AddHs(mol)
        # EmbedMolecule reports failure with a negative conformer id; stop
        # here with a clear message instead of failing inside the optimiser.
        if AllChem.EmbedMolecule(mol) < 0:
            raise ValueError("Could not embed a 3D conformer for the molecule")
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        return mol

    def get_sample(self, idx, kind):
        """Fetch one row and return ``(cleaned_smiles, protein_sequence, label)``.

        Args:
            idx: Row position passed to the scanner's ``take``.
            kind: ``"bind"`` to read from the binder scanner, ``"no-bind"``
                to read from the non-binder scanner.

        Raises:
            ValueError: If ``kind`` is neither ``"bind"`` nor ``"no-bind"``,
                or the row's SMILES cannot be parsed.
            KeyError: If the row's protein name has no known sequence.
        """
        if kind not in ("bind", "no-bind"):
            raise ValueError("kind must be `bind` or `no-bind`")

        scanner = self.scanner_bind if kind == "bind" else self.scanner_no_bind
        record = scanner.take([idx]).to_pydict()

        smiles = record['molecule_smiles'][0]
        protein_name = record['protein_name'][0]
        label = record['binds'][0]

        smiles = self.clean_smiles(smiles)
        amino_acids = self.get_protein_seq(protein_name)
        return smiles, amino_acids, label

    def __getitem__(
        self, bind_idx: int
    ) -> tuple[tuple[str, str, int], tuple[str, str, int]]:
        """Return the ``bind_idx``-th binder of this split and a random non-binder.

        ``bind_idx`` is a position within ``indices_bind`` (``0`` to
        ``len(self) - 1``). It is translated to the actual row position
        before reading, so that each split only reads its own rows.
        """
        bind_row = int(self.indices_bind[bind_idx])
        bind_data = self.get_sample(bind_row, kind="bind")
        # TODO: Fix random choice to select one with the same type of protein.
        no_bind_idx = int(random.choice(self.indices_no_bind))
        no_bind_data = self.get_sample(no_bind_idx, kind="no-bind")
        return bind_data, no_bind_data

    def __len__(self):
        """Return the number of binder rows in this split."""
        return len(self.indices_bind)

In [8]:
# Number of values produced per atom by ``featurize_atoms``.
_ATOM_FEATURE_DIM = 5


def _hybridization_code(hybridization) -> int:
    """Map an RDKit hybridization to 1 (SP), 2 (SP2), 3 (SP3) or 0 (other)."""
    if hybridization == Chem.HybridizationType.SP:
        return 1
    if hybridization == Chem.HybridizationType.SP2:
        return 2
    if hybridization == Chem.HybridizationType.SP3:
        return 3
    return 0


def featurize_atoms(mol):
    """Build the per-atom feature matrix of ``mol``.

    Each atom contributes five values, in this order: atomic number, degree,
    formal charge, hybridization code (see ``_hybridization_code``) and an
    aromaticity flag.

    Args:
        mol: Molecule exposing ``GetAtoms()``.

    Returns:
        ``{'atomic': tensor}`` where the tensor is float32 with shape
        ``(number_of_atoms, 5)``. A molecule without atoms yields a
        ``(0, 5)`` tensor rather than a 1-D empty one.
    """
    feats = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            _hybridization_code(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
        ]
        for atom in mol.GetAtoms()
    ]
    # The explicit reshape keeps the feature axis even when there are no atoms.
    atomic = torch.tensor(feats, dtype=torch.float32).reshape(
        len(feats), _ATOM_FEATURE_DIM
    )
    return {'atomic': atomic}

def featurize_edges(mol, add_self_loop=False):
    """Build a directed edge list and per-edge features from the bonds of ``mol``.

    Every bond is emitted twice, once per direction, with identical features:
    bond-type code (1 single, 2 double, 3 triple, 4 aromatic, 0 otherwise),
    conjugation flag and ring-membership flag.

    Args:
        mol: Molecule exposing ``GetBonds()`` and ``GetNumAtoms()``.
        add_self_loop: When true, append one ``i -> i`` edge per atom with
            all-zero features.

    Returns:
        A dict with ``'src'`` and ``'dst'`` (int64 tensors of edge endpoints)
        and ``'bond'`` (float32 tensor of shape ``(number_of_edges, 3)``).
        A molecule without bonds yields a ``(0, 3)`` feature tensor rather
        than a 1-D empty one.
    """
    bond_types = {
        Chem.BondType.SINGLE: 1,
        Chem.BondType.DOUBLE: 2,
        Chem.BondType.TRIPLE: 3,
        Chem.BondType.AROMATIC: 4
    }

    src = []
    dst = []
    feats = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_feats = [
            bond_types.get(bond.GetBondType(), 0),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]

        # One entry per direction so the graph is effectively undirected.
        src.extend((i, j))
        dst.extend((j, i))
        feats.append(bond_feats)
        feats.append(list(bond_feats))

    if add_self_loop:
        for i in range(mol.GetNumAtoms()):
            src.append(i)
            dst.append(i)
            feats.append([0, 0, 0])

    return {
        'src': torch.tensor(src, dtype=torch.long),
        'dst': torch.tensor(dst, dtype=torch.long),
        # The explicit reshape keeps the feature axis even when there are no edges.
        'bond': torch.tensor(feats, dtype=torch.float32).reshape(len(feats), 3),
    }


In [9]:
import os
os.environ["DGLBACKEND"] = "pytorch"
from dgllife.utils import mol_to_complete_graph

In [10]:
def collate_fn(batch):
    """Turn a list of (binder, non-binder) samples into batched DGL graphs.

    Each item is a pair of ``(smiles, protein_sequence, label)`` triples.
    Both molecules of a pair are converted to complete graphs with atom
    features. If either conversion fails, the whole pair is skipped so the
    two batches stay aligned. ``None`` items are skipped as well.

    Args:
        batch: List of dataset items, possibly containing ``None``.

    Returns:
        ``(graphs_bind, labels_bind, graphs_no_bind, labels_no_bind)``, all
        moved to the module-level ``device``.

    Raises:
        ValueError: If no pair in the batch could be converted.
    """
    # Imported here because the notebook only imports dgl in a later cell.
    import dgl

    graphs_bind = []
    graphs_no_bind = []
    labels_bind = []
    labels_no_bind = []
    for item in batch:
        if item is None:
            continue
        (smiles_bind, _, label_bind), (smiles_no_bind, _, label_no_bind) = item

        try:
            graph_bind = mol_to_complete_graph(
                MolDataset.get_mol(smiles_bind),
                node_featurizer=featurize_atoms,
            )
            graph_no_bind = mol_to_complete_graph(
                MolDataset.get_mol(smiles_no_bind),
                node_featurizer=featurize_atoms,
            )
        except Exception as e:
            # Best effort: a molecule that cannot be processed drops its pair.
            print(f"Error processing molecule: {e}")
            continue

        # Guard in case the graph builder hands back no graph instead of raising.
        if graph_bind is None or graph_no_bind is None:
            print("Error processing molecule: no graph was produced")
            continue

        graphs_bind.append(graph_bind)
        graphs_no_bind.append(graph_no_bind)
        labels_bind.append(label_bind)
        labels_no_bind.append(label_no_bind)

    if not graphs_bind:
        raise ValueError(
            "collate_fn: no molecule pair in the batch could be converted to a graph"
        )

    return (
        dgl.batch(graphs_bind).to(device),
        torch.tensor(labels_bind).to(device),
        dgl.batch(graphs_no_bind).to(device),
        torch.tensor(labels_no_bind).to(device),
    )


In [11]:
# Two filtered views of the training data, one per value of the `binds` column.
scanner_no_bind = DATA.scanner(filter=pc.field("binds") == 0)
scanner_bind = DATA.scanner(filter=pc.field("binds") == 1)

# Both datasets share the scanners and differ only in the index sets they use.
train_data = MolDataset(
    scanner_no_bind,
    scanner_bind,
    train_indices_no_bind,
    train_indices_bind,
)
valid_data = MolDataset(
    scanner_no_bind,
    scanner_bind,
    valid_indices_no_bind,
    valid_indices_bind,
)


In [12]:
from torch.utils.data import DataLoader

In [13]:
# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    valid_data,
    batch_size=24,
    shuffle=False,
    collate_fn=collate_fn,
)


## Model

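TODO: describe this section in more detail. In short: a two-layer graph convolutional network (DGL `GraphConv`) whose node embeddings are mean-pooled into one vector per molecule.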

In [14]:
import torch
import torch.nn as nn
import dgl
from dgl.nn.pytorch import GraphConv

In [15]:
class GCN(nn.Module):
    """Two-layer graph convolutional network producing one embedding per graph."""

    def __init__(self, in_feats, h_feats, out_feats):
        """Create the two graph-convolution layers.

        Args:
            in_feats: Size of the input node features.
            h_feats: Size of the hidden node representation.
            out_feats: Size of the output embedding.
        """
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, out_feats)

    def forward(self, g, in_feat):
        """Embed every graph in ``g``.

        Args:
            g: A (batched) DGL graph.
            in_feat: Node feature tensor for the nodes of ``g``.

        Returns:
            The mean of the final node representations of each graph, as
            computed by ``dgl.mean_nodes``.
        """
        h = self.conv1(g, in_feat)
        h = torch.relu(h)
        h = self.conv2(g, h)
        # Store the node embeddings only for the duration of the readout so
        # the caller's graph is not left with an extra 'h' feature.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

## Training

In [16]:
# Layer sizes of the network.
in_feats = 5  # Number of features per atom
h_feats = 16
out_feats = 8  # Embedding dimension

# Build the network, then move its parameters to the selected device.
model = GCN(in_feats, h_feats, out_feats)
model = model.to(device)

In [17]:
def contrastive_loss(embeddings_bind, embeddings_no_bind, margin=1.0, labels=None):
    """Contrastive loss between paired embeddings.

    For each pair with distance ``d`` and label ``y`` the loss is
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``, averaged over the pairs.
    A label of 1 pushes the pair at least ``margin`` apart; a label of 0
    pulls the pair together.

    Args:
        embeddings_bind: Embeddings of the first element of each pair.
        embeddings_no_bind: Embeddings of the second element of each pair.
        margin: Minimum distance wanted between dissimilar pairs.
        labels: Optional per-pair labels (1 = dissimilar, 0 = similar).
            Defaults to all ones, because every pair produced by this
            notebook is a binder paired with a non-binder.

    Returns:
        A scalar tensor holding the mean loss.
    """
    distance = torch.nn.functional.pairwise_distance(embeddings_bind, embeddings_no_bind)
    if labels is None:
        labels = torch.ones_like(distance)
    else:
        labels = torch.as_tensor(labels, dtype=distance.dtype, device=distance.device)

    pull_together = (1 - labels) * torch.pow(distance, 2)
    push_apart = labels * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    return torch.mean(pull_together + push_apart)

In [18]:
import torch.optim as optim

In [19]:
# Adam over all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=0.01,
)

In [21]:
# Training loop: one optimizer step per batch, with the running loss reported
# every 10 batches and the mean loss reported at the end of each epoch.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader, start=1):
        # Unpack the collated batch and make sure every part is on the device.
        graphs_bind, labels_bind, graphs_no_bind, labels_no_bind = (
            part.to(device) for part in batch
        )

        # Embed the binder graphs, then the non-binder graphs.
        feats_bind = graphs_bind.ndata['atomic'].to(device)
        embeddings_bind = model(graphs_bind, feats_bind)
        feats_no_bind = graphs_no_bind.ndata['atomic'].to(device)
        embeddings_no_bind = model(graphs_no_bind, feats_no_bind)

        loss = contrastive_loss(embeddings_bind, embeddings_no_bind)

        # Clear old gradients, back-propagate and update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Print batch progress
        if batch_idx % 10 == 0:  # Change the frequency as needed
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {total_loss/len(train_loader):.4f}')


DGLError: [12:06:59] /opt/dgl/src/runtime/c_runtime_api.cc:82: Check failed: allow_missing: Device API cuda is not enabled. Please install the cuda version of dgl.
Stack trace:
  [bt] (0) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x67) [0x7b2161f621b7]
  [bt] (1) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPIManager::GetAPI(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool)+0x2a5) [0x7b21623bae85]
  [bt] (2) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPI::Get(DGLContext, bool)+0x1ea) [0x7b21623b789a]
  [bt] (3) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DGLDataType, DGLContext)+0x130) [0x7b21623d0e60]
  [bt] (4) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::CopyTo(DGLContext const&) const+0xb5) [0x7b2162405c45]
  [bt] (5) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::UnitGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0x1e7) [0x7b21624f7e27]
  [bt] (6) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::HeteroGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0xfa) [0x7b2162411a0a]
  [bt] (7) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(+0x6226d6) [0x7b21624226d6]
  [bt] (8) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(DGLFuncCall+0x4c) [0x7b21623b9e2c]

