# 03 Graph convolutional network


## Data and processing

Please refer to [the previous notebook](../01_molecular_features#data-and-processing) for explanations of the following code block.

In [1]:
import pyarrow.dataset as ds

# Location of the training split, relative to this notebook.
PATH_TRAIN_DATA = "../../../data/train.parquet"

# Open the training parquet file as a pyarrow Dataset.
DATA = ds.dataset(PATH_TRAIN_DATA, format="parquet")

In [4]:
import torch

In [5]:
# Choose the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [10]:
def featurize_ligand(smiles):
    """Turn a SMILES string into node, connectivity and edge tensors.

    The molecule is parsed, explicit hydrogens are added and a 3D conformer
    is embedded (fixed seed 42) so that bond lengths can be measured.

    Expects RDKit's ``Chem`` and ``AllChem`` to already be in scope
    (NOTE(review): their import is not in the visible cells -- confirm).

    Args:
        smiles: SMILES string describing the ligand.

    Returns:
        A tuple ``(node_features, edge_index, edge_features)``:
        ``node_features`` is a float tensor ``[num_atoms, 8]``,
        ``edge_index`` is a long tensor ``[2, 2 * num_bonds]`` holding both
        directions of every bond, and ``edge_features`` is a float tensor
        ``[2 * num_bonds, 3]`` aligned with ``edge_index``.

    Raises:
        ValueError: If the SMILES cannot be parsed or no 3D conformer can be
            generated for it.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)  # Add hydrogen atoms

    # EmbedMolecule returns the new conformer id, or -1 when embedding fails.
    if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
        raise ValueError(f"Could not generate 3D coordinates for SMILES: {smiles!r}")
    conformer = mol.GetConformer()

    # Node features: one row of 8 numeric descriptors per atom.
    # NOTE(review): hydrogens are explicit graph atoms after AddHs, so
    # GetTotalNumHs() may no longer count them -- confirm this is intended.
    node_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),  # enum -> plain int for torch.tensor
            atom.GetTotalNumHs(),
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            int(atom.GetChiralTag() != Chem.ChiralType.CHI_UNSPECIFIED),
        ]
        for atom in mol.GetAtoms()
    ]

    # Edge features and edge index
    edge_features = []
    edge_index = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Bond objects carry no geometry; measure the length on the conformer.
        bond_length = conformer.GetAtomPosition(start).Distance(
            conformer.GetAtomPosition(end)
        )
        features = [
            bond.GetBondTypeAsDouble(),
            int(bond.IsInRing()),
            bond_length,
        ]
        # Store each bond in both directions so the graph is undirected.
        edge_features.extend([features, features])
        edge_index.extend([[start, end], [end, start]])

    # The reshape keeps well-formed [2, 0] / [0, 3] tensors when the molecule
    # has no bonds (e.g. a single ion), instead of a flat empty tensor.
    return torch.tensor(node_features, dtype=torch.float), \
           torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous(), \
           torch.tensor(edge_features, dtype=torch.float).reshape(-1, 3)

In [11]:
# Scanners over the non-binding (binds == 0) and binding (binds == 1) rows.
# `ds.field` is used for the column expression because `pyarrow.dataset`
# (as `ds`) is the pyarrow module this notebook imports; `pc` is not
# imported in any visible cell.
scanner_no_bind = DATA.scanner(filter=(ds.field("binds") == 0))
scanner_bind = DATA.scanner(filter=(ds.field("binds") == 1))



In [12]:
from torch.utils.data import DataLoader

In [13]:
# Batched loaders for training (shuffled) and validation (fixed order).
# NOTE(review): training uses batch_size=2 while validation uses 24 --
# confirm the small training batch is intentional.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    dataset=valid_data,
    batch_size=24,
    shuffle=False,
    collate_fn=collate_fn,
)


## Model

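TODO: Describe the model. The class below encodes the ligand graph with three GCN layers, encodes the protein sequence with a bidirectional LSTM, and fuses the two embeddings to predict binding.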

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

In [15]:
class LigandProteinGCN(nn.Module):
    """Binding classifier that fuses a ligand graph with a protein sequence.

    The ligand graph goes through three GCN layers and is mean-pooled into
    one vector per graph. The protein sequence is summarised by a
    bidirectional LSTM. Both vectors are concatenated, passed through a
    fusion layer, and mapped to a single sigmoid probability.

    Args:
        num_node_features: Number of input features per atom.
        num_edge_features: Number of features per bond. Accepted for interface
            compatibility; GCNConv layers here do not consume edge features.
        hidden_channels: Width of the GCN layers and of the fusion layer.
        protein_embedding_dim: Hidden size of each LSTM direction.
        protein_input_dim: Feature size of each protein sequence element.
            Defaults to 20, the value previously hard-coded.
    """

    def __init__(self, num_node_features, num_edge_features, hidden_channels,
                 protein_embedding_dim, protein_input_dim=20):
        super().__init__()

        # Ligand GCN layers
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        # Protein sequence encoding
        self.protein_encoder = nn.LSTM(
            protein_input_dim,
            protein_embedding_dim,
            batch_first=True,
            bidirectional=True,
        )

        # Fusion and classification layers; the factor 2 accounts for the two
        # LSTM directions being concatenated.
        self.fusion = nn.Linear(hidden_channels + 2 * protein_embedding_dim, hidden_channels)
        self.fc = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict a binding probability for every graph in the batch.

        Args:
            x: Node feature matrix for all graphs in the batch.
            edge_index: Connectivity in COO format, shape ``[2, num_edges]``.
            batch: Graph id of each node, used for pooling.
            protein_seq: Batched protein features,
                ``[num_graphs, seq_len, protein_input_dim]``.
                NOTE(review): padded sequences are not packed here, so padding
                steps feed into the final states -- confirm inputs are
                fixed-length.

        Returns:
            Tensor ``[num_graphs, 1]`` of probabilities in (0, 1).
        """
        # Ligand graph processing
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_channels]

        # Protein sequence processing. For a bidirectional LSTM the last
        # element of `output` holds the reverse direction after one step only,
        # so the final hidden states of both directions are used instead.
        _, (hidden_state, _) = self.protein_encoder(protein_seq)
        protein_embedding = torch.cat([hidden_state[-2], hidden_state[-1]], dim=1)

        # Fusion
        combined = torch.cat([x, protein_embedding], dim=1)
        fused = F.relu(self.fusion(combined))

        # Classification
        out = torch.sigmoid(self.fc(fused))
        return out

## Training

In [16]:
# Hyperparameters passed positionally to the model constructor.
# NOTE(review): `GCN` is not defined in any visible cell; the class defined in
# the Model section is `LigandProteinGCN`, which takes different arguments.
# Confirm which model this cell is meant to build.
# NOTE(review): `featurize_ligand` produces 8 values per atom while in_feats
# is 5 -- confirm these agree with the featurizer feeding this model.
in_feats = 5  # Number of features per atom
h_feats = 16  # presumably the hidden layer width -- confirm against GCN
out_feats = 8  # Embedding dimension

# Build the model and move its parameters to the selected device.
model = GCN(in_feats, h_feats, out_feats).to(device)

In [17]:
def contrastive_loss(embeddings_bind, embeddings_no_bind, margin=1.0, labels=None):
    """Contrastive loss over paired embeddings.

    For each pair the Euclidean distance ``d`` is computed, and the loss is
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2`` averaged over
    pairs. A label of 0 pulls the pair together; a label of 1 pushes it apart
    until it is at least ``margin`` away.

    Args:
        embeddings_bind: Embeddings of binding ligands, ``[N, D]``.
        embeddings_no_bind: Embeddings of non-binding ligands, ``[N, D]``.
        margin: Minimum distance targeted for dissimilar pairs.
        labels: Optional per-pair labels (0 = similar, 1 = dissimilar).
            Defaults to all ones, because every pair couples a binder with a
            non-binder.

    Returns:
        Scalar tensor holding the mean loss.
    """
    distance = torch.nn.functional.pairwise_distance(embeddings_bind, embeddings_no_bind)
    if labels is None:
        # Binder vs. non-binder pairs are dissimilar by construction.
        labels = torch.ones_like(distance)
    else:
        # Accept lists, scalars or integer tensors alongside float tensors.
        labels = torch.as_tensor(labels, dtype=distance.dtype, device=distance.device)

    pull_together = (1 - labels) * torch.pow(distance, 2)
    push_apart = labels * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    loss = torch.mean(pull_together + push_apart)
    return loss

In [18]:
import torch.optim as optim

In [19]:
# Adam optimiser over all model parameters with a learning rate of 0.01.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [21]:
# Training loop
# Each batch supplies a graph batch of binders and one of non-binders; the
# model embeds both and the contrastive loss is computed on the two sets.
# NOTE(review): the saved output of this cell is a DGLError ("Device API cuda
# is not enabled") raised while copying a graph to the device -- the installed
# DGL build appears to lack CUDA support; confirm the environment or device.
# NOTE(review): `.ndata` is a DGL graph attribute, whereas the model class in
# this notebook is built on torch_geometric -- confirm which stack is intended.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    # batch_idx starts at 1 so the progress message reads naturally.
    for batch_idx, batch in enumerate(train_loader, 1):
        graphs_bind, labels_bind, graphs_no_bind, labels_no_bind = batch
        
        # Move data to the appropriate device
        # NOTE(review): labels_bind / labels_no_bind are moved to the device but
        # never passed to the loss below -- confirm whether they should be.
        graphs_bind = graphs_bind.to(device)
        labels_bind = labels_bind.to(device)
        graphs_no_bind = graphs_no_bind.to(device)
        labels_no_bind = labels_no_bind.to(device)
        
        # Forward pass
        # The 'atomic' node data of each graph batch is passed as the node input.
        embeddings_bind = model(graphs_bind, graphs_bind.ndata['atomic'].to(device))
        embeddings_no_bind = model(graphs_no_bind, graphs_no_bind.ndata['atomic'].to(device))
        
        # Compute loss
        loss = contrastive_loss(embeddings_bind, embeddings_no_bind)
        
        # Backward pass
        # Clear gradients from the previous step before backpropagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate the scalar loss for the per-epoch average.
        total_loss += loss.item()

        # Print batch progress
        if batch_idx % 10 == 0:  # Change the frequency as needed
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')
        
    # Average is taken over the number of batches in the loader.
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {total_loss/len(train_loader):.4f}')


DGLError: [12:06:59] /opt/dgl/src/runtime/c_runtime_api.cc:82: Check failed: allow_missing: Device API cuda is not enabled. Please install the cuda version of dgl.
Stack trace:
  [bt] (0) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x67) [0x7b2161f621b7]
  [bt] (1) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPIManager::GetAPI(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool)+0x2a5) [0x7b21623bae85]
  [bt] (2) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPI::Get(DGLContext, bool)+0x1ea) [0x7b21623b789a]
  [bt] (3) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DGLDataType, DGLContext)+0x130) [0x7b21623d0e60]
  [bt] (4) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::CopyTo(DGLContext const&) const+0xb5) [0x7b2162405c45]
  [bt] (5) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::UnitGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0x1e7) [0x7b21624f7e27]
  [bt] (6) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(dgl::HeteroGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0xfa) [0x7b2162411a0a]
  [bt] (7) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(+0x6226d6) [0x7b21624226d6]
  [bt] (8) /home/amm503/conda-envs/leash_bio_kaggle-dev/lib/python3.11/site-packages/dgl/libdgl.so(DGLFuncCall+0x4c) [0x7b21623b9e2c]

