In [None]:
!pip install -q deepchem torch_geometric rdkit biopython

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m552.4/552.4 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import deepchem as dc
import pickle
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from sklearn.metrics import mean_squared_error, r2_score
from torch.nn import Linear, ReLU, Sequential, BatchNorm1d, Dropout
from torch_geometric.nn import GINConv, global_add_pool, HeteroConv, SAGEConv, global_mean_pool, GATv2Conv, Linear
from torch_geometric.data import Data, HeteroData, Dataset
from torch.serialization import add_safe_globals
from torch_geometric.data.storage import BaseStorage
from scipy.spatial import distance_matrix
from rdkit import Chem
from Bio import PDB
# from torch_geometric.utils import scatter
add_safe_globals([BaseStorage, Data])

Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


GraphDock featurization functions: atom features, edge features and heterogeneous graph construction

In [None]:
def get_ligand_atoms_features(ligand_path):
    """Extract ligand atom features and 3D coordinates for GraphDock.

    The reader is chosen from the file extension (case-insensitive):
    ``.pdb`` -> ``Chem.MolFromPDBFile``, ``.mol2`` -> ``Chem.MolFromMol2File``,
    anything else (e.g. ``.sdf`` / ``.mol``) -> ``Chem.MolFromMolFile``.
    Hydrogens present in the file are kept (``removeHs=False``).

    Args:
        ligand_path: Path to the ligand file (``str`` or path-like).

    Returns:
        Tuple ``(coords, features, atom_types)``:
          * ``coords``: ``(N, 3)`` array of positions from the first conformer.
          * ``features``: ``(N, 18)`` float32 array = 11-slot element one-hot
            (C, N, O, S, P, F, Cl, Br, I, H, other) followed by degree / 6,
            formal charge, int(hybridization) / 6, aromatic flag,
            total H count / 4, atomic mass / 100 and ring-membership flag.
          * ``atom_types``: atomic number of each atom, in atom order.
        ``(None, None, None)`` if RDKit cannot parse the file or the parsed
        molecule has no conformer.
    """
    from rdkit import Chem
    from rdkit import RDLogger

    path = str(ligand_path)
    extension = path.lower()

    # Silence RDKit log output while parsing only; `finally` re-enables it so
    # the process-wide RDKit logger is not left switched off after this call.
    RDLogger.DisableLog('rdApp.*')
    try:
        if extension.endswith('.pdb'):
            mol = Chem.MolFromPDBFile(path, removeHs=False)
        elif extension.endswith('.mol2'):
            mol = Chem.MolFromMol2File(path, removeHs=False)
        else:
            mol = Chem.MolFromMolFile(path, removeHs=False)
    finally:
        RDLogger.EnableLog('rdApp.*')

    # GetConformer() raises on conformer-less molecules; report them like a
    # parse failure so callers only need the `is None` check.
    if mol is None or mol.GetNumConformers() == 0:
        return None, None, None

    # Atomic number -> one-hot slot; unlisted elements go to the last slot.
    element_slot = {6: 0, 7: 1, 8: 2, 16: 3, 15: 4, 9: 5, 17: 6, 35: 7, 53: 8, 1: 9}
    num_slots = 11
    feature_dim = num_slots + 7

    conf = mol.GetConformer()
    coords = np.array(
        [list(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())]
    ).reshape(-1, 3)

    features = []
    atom_types = []
    for atom in mol.GetAtoms():
        atom_num = atom.GetAtomicNum()
        atom_types.append(atom_num)

        onehot = [0] * num_slots
        onehot[element_slot.get(atom_num, num_slots - 1)] = 1

        features.append(onehot + [
            atom.GetDegree() / 6.0,
            atom.GetFormalCharge(),
            int(atom.GetHybridization()) / 6.0,
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs() / 4.0,
            atom.GetMass() / 100.0,
            int(atom.IsInRing()),
        ])

    return coords, np.array(features, dtype=np.float32).reshape(-1, feature_dim), atom_types

def _dssp_residue_features(dssp_dict, chain_id, res_id):
    """Return ``(ss_onehot, accessibility)`` for one residue from a DSSP mapping.

    Biopython DSSP entries are tuples whose index 2 is the secondary-structure
    code and index 3 the *relative* accessible surface area (already a 0-1
    fraction, or the string 'NA'). Missing or unusable data falls back to
    coil / 0.5.
    """
    default = ([0, 0, 1], 0.5)
    if not dssp_dict:
        return default
    entry = dssp_dict.get((chain_id, res_id))
    if entry is None or len(entry) < 4:
        return default

    ss = entry[2]
    if ss in ('H', 'G', 'I'):      # helices
        ss_feat = [1, 0, 0]
    elif ss in ('E', 'B'):         # strands / bridges
        ss_feat = [0, 1, 0]
    else:                          # coils and everything else
        ss_feat = [0, 0, 1]

    # Keep the secondary-structure feature even if the ASA value is unusable.
    try:
        rel_acc = float(entry[3])
    except (TypeError, ValueError):
        return ss_feat, 0.5
    if not np.isfinite(rel_acc):
        return ss_feat, 0.5
    # Already relative; clamp defensively into [0, 1].
    return ss_feat, min(max(rel_acc, 0.0), 1.0)


def get_protein_atoms_features(protein_path):
    """Extract per-atom pocket features and 3D coordinates for GraphDock.

    Only the first model of the PDB file is used (the model DSSP runs on), so
    multi-model files do not yield duplicated atoms.

    Args:
        protein_path: Path to the pocket PDB file.

    Returns:
        Tuple ``(coords, features, residue_types)``:
          * ``coords``: ``(N, 3)`` float32 array of atom coordinates.
          * ``features``: ``(N, 16)`` float32 array laid out as
            5 (element one-hot: C, N, O, S, other)
            + 6 (residue props: hydrophobic, charge, polar, aromatic, size, pKa / 14)
            + 1 (backbone atom flag for N/CA/C/O)
            + 3 (secondary structure one-hot: helix, strand, coil)
            + 1 (relative solvent accessibility).
          * ``residue_types``: residue name of each atom.
        Secondary structure / accessibility default to coil / 0.5 when DSSP
        (external ``mkdssp`` executable) fails or has no entry for a residue.
    """
    from Bio import PDB
    from Bio.PDB import DSSP

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("pocket", protein_path)

    # Residue properties: [hydrophobic, charge, polar, aromatic, size, pKa].
    aa_properties = {
        'ALA': [1, 0, 0, 0, 0.3, 0],
        'ARG': [0, 1, 1, 0, 0.9, 12.5],
        'ASN': [0, 0, 1, 0, 0.5, 0],
        'ASP': [0, -1, 1, 0, 0.5, 3.9],
        'CYS': [1, 0, 1, 0, 0.4, 8.3],
        'GLN': [0, 0, 1, 0, 0.6, 0],
        'GLU': [0, -1, 1, 0, 0.6, 4.2],
        'GLY': [0, 0, 0, 0, 0.2, 0],
        'HIS': [0, 0.5, 1, 1, 0.6, 6.0],
        'ILE': [1, 0, 0, 0, 0.7, 0],
        'LEU': [1, 0, 0, 0, 0.7, 0],
        'LYS': [0, 1, 1, 0, 0.7, 10.5],
        'MET': [1, 0, 0, 0, 0.7, 0],
        'PHE': [1, 0, 0, 1, 0.8, 0],
        'PRO': [0, 0, 0, 0, 0.5, 0],
        'SER': [0, 0, 1, 0, 0.4, 0],
        'THR': [0, 0, 1, 0, 0.5, 0],
        'TRP': [1, 0, 0, 1, 1.0, 0],
        'TYR': [0, 0, 1, 1, 0.9, 10.1],
        'VAL': [1, 0, 0, 0, 0.6, 0],
    }
    # Residue names not in the table (e.g. waters / hetero groups, if present).
    default_properties = [0, 0, 0, 0, 0.5, 0]
    element_slot = {'C': 0, 'N': 1, 'O': 2, 'S': 3}
    backbone_atoms = {'N', 'CA', 'C', 'O'}
    feature_dim = 16

    models = list(structure)
    if not models:
        return (np.zeros((0, 3), dtype=np.float32),
                np.zeros((0, feature_dim), dtype=np.float32), [])
    model = models[0]

    # DSSP shells out to `mkdssp`; any failure just disables SS/ASA features.
    try:
        dssp_dict = dict(DSSP(model, protein_path, dssp='mkdssp'))
    except Exception:
        dssp_dict = None

    coords = []
    features = []
    residue_types = []
    for chain in model:
        for residue in chain:
            res_name = residue.get_resname()
            props = aa_properties.get(res_name, default_properties)
            # Normalise pKa onto the 0-14 pH scale.
            residue_feat = props[:5] + [props[5] / 14.0]
            ss_feat, acc_feat = _dssp_residue_features(dssp_dict, chain.id, residue.id)

            for atom in residue:
                coords.append(atom.get_coord())
                element_onehot = [0] * 5
                element_onehot[element_slot.get(atom.element.strip(), 4)] = 1
                is_backbone = int(atom.name in backbone_atoms)
                features.append(
                    element_onehot + residue_feat + [is_backbone] + ss_feat + [acc_feat]
                )
                residue_types.append(res_name)

    return (np.asarray(coords, dtype=np.float32).reshape(-1, 3),
            np.asarray(features, dtype=np.float32).reshape(-1, feature_dim),
            residue_types)


def compute_edge_features_graphdock(coord_i, coord_j, dist):
    """Build the 14-value GraphDock feature vector for the edge i -> j.

    Layout: ``[dist]`` + 10 Gaussian RBF activations (centres
    ``np.linspace(0, 10, 10)``, gamma 0.5) + the unit vector pointing from
    ``coord_i`` to ``coord_j``; the unit vector is all zeros unless
    ``dist > 0``.

    Args:
        coord_i: Source coordinate (numpy array of length 3).
        coord_j: Target coordinate (numpy array of length 3).
        dist: Precomputed distance between the two coordinates.

    Returns:
        A plain Python list of 14 numbers.
    """
    gamma = 0.5
    centres = np.linspace(0, 10, 10)
    rbf = np.exp(-gamma * (dist - centres) ** 2)

    if dist > 0:
        unit_vec = ((coord_j - coord_i) / dist).tolist()
    else:
        unit_vec = [0.0, 0.0, 0.0]

    return [dist, *rbf.tolist(), *unit_vec]


def _graphdock_edge_features_batch(src_coords, dst_coords, dists):
    """Vectorised ``compute_edge_features_graphdock`` for many edges at once.

    Row ``k`` holds the 14 features of the edge ``src_coords[k] -> dst_coords[k]``:
    ``[dist]`` + 10 Gaussian RBFs (centres ``linspace(0, 10, 10)``, gamma 0.5)
    + the src -> dst unit vector (zeros when the distance is not positive).
    """
    dists = np.asarray(dists, dtype=np.float64).reshape(-1)
    src = np.asarray(src_coords, dtype=np.float64).reshape(-1, 3)
    dst = np.asarray(dst_coords, dtype=np.float64).reshape(-1, 3)

    rbf_centers = np.linspace(0, 10, 10)
    rbf = np.exp(-0.5 * (dists[:, None] - rbf_centers[None, :]) ** 2)

    direction = np.zeros((dists.shape[0], 3), dtype=np.float64)
    positive = dists > 0
    direction[positive] = (dst[positive] - src[positive]) / dists[positive, None]

    return np.concatenate([dists[:, None], rbf, direction], axis=1)


def _symmetric_cutoff_edges(coords, dist_matrix, cutoff):
    """Directed edges for every unordered pair ``i < j`` with distance <= cutoff.

    Pairs are visited in row-major order; each yields ``i -> j`` then ``j -> i``,
    and each direction gets its own feature row, so the reverse edge carries
    the reversed unit vector.

    Returns:
        ``(edge_index, edge_attr)``: ``(2, E)`` int64 and ``(E, 14)`` float64 arrays.
    """
    src, dst = np.nonzero(np.triu(dist_matrix <= cutoff, k=1))
    dists = dist_matrix[src, dst]
    forward = _graphdock_edge_features_batch(coords[src], coords[dst], dists)
    backward = _graphdock_edge_features_batch(coords[dst], coords[src], dists)

    num_pairs = src.shape[0]
    edge_index = np.empty((2, 2 * num_pairs), dtype=np.int64)
    edge_index[0, 0::2], edge_index[1, 0::2] = src, dst
    edge_index[0, 1::2], edge_index[1, 1::2] = dst, src

    edge_attr = np.empty((2 * num_pairs, forward.shape[1]), dtype=np.float64)
    edge_attr[0::2] = forward
    edge_attr[1::2] = backward
    return edge_index, edge_attr


def create_graphdock_hetero_graph(ligand_path, protein_path, ligand_cutoff=5.0,
                                   protein_cutoff=6.0, interaction_cutoff=5.0):
    """Create a heterogeneous ligand/pocket graph for GraphDock.

    Node stores ``'ligand'`` and ``'protein'`` get ``x`` (atom features),
    ``pos`` (coordinates) and ``num_nodes``. Edge types (each only added when
    non-empty) carry ``edge_index`` and a 14-dim ``edge_attr``
    (distance, 10 RBFs, source->target unit vector):

    * ``('ligand', 'lig_lig', 'ligand')``: atom pairs within ``ligand_cutoff``,
      stored in both directions.
    * ``('protein', 'prot_prot', 'protein')``: atom pairs within ``protein_cutoff``,
      stored in both directions.
    * ``('ligand', 'interaction', 'protein')`` and its reverse
      ``('protein', 'interaction', 'ligand')``: cross pairs within
      ``interaction_cutoff``.

    Every directed edge's unit vector points from its own source to its own
    target, including the reverse copies of the symmetric edges.

    Returns:
        A ``HeteroData`` object, or ``None`` when the ligand cannot be parsed
        or either side has no atoms.
    """
    lig_coords, lig_feats, _ = get_ligand_atoms_features(ligand_path)
    if lig_coords is None:
        return None
    prot_coords, prot_feats, _ = get_protein_atoms_features(protein_path)
    # distance_matrix cannot handle an empty side; treat it like a parse failure.
    if len(lig_coords) == 0 or len(prot_coords) == 0:
        return None

    lig_coords = np.asarray(lig_coords, dtype=np.float64).reshape(-1, 3)
    prot_coords = np.asarray(prot_coords, dtype=np.float64).reshape(-1, 3)

    data = HeteroData()
    data['ligand'].x = torch.tensor(lig_feats, dtype=torch.float)
    data['ligand'].pos = torch.tensor(lig_coords, dtype=torch.float)
    data['ligand'].num_nodes = len(lig_coords)
    data['protein'].x = torch.tensor(prot_feats, dtype=torch.float)
    data['protein'].pos = torch.tensor(prot_coords, dtype=torch.float)
    data['protein'].num_nodes = len(prot_coords)

    # Intra-molecular edges: distance-cutoff neighbours, both directions.
    for node_type, relation, coords, cutoff in (
        ('ligand', 'lig_lig', lig_coords, ligand_cutoff),
        ('protein', 'prot_prot', prot_coords, protein_cutoff),
    ):
        edge_index, edge_attr = _symmetric_cutoff_edges(
            coords, distance_matrix(coords, coords), cutoff
        )
        if edge_index.shape[1] > 0:
            store = data[node_type, relation, node_type]
            store.edge_index = torch.tensor(edge_index, dtype=torch.long)
            store.edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    # Protein-ligand interaction edges (ligand atom i <-> protein atom j).
    cross_dist = distance_matrix(lig_coords, prot_coords)
    lig_idx, prot_idx = np.nonzero(cross_dist <= interaction_cutoff)
    if lig_idx.size > 0:
        dists = cross_dist[lig_idx, prot_idx]
        data['ligand', 'interaction', 'protein'].edge_index = torch.tensor(
            np.stack([lig_idx, prot_idx]), dtype=torch.long)
        data['ligand', 'interaction', 'protein'].edge_attr = torch.tensor(
            _graphdock_edge_features_batch(lig_coords[lig_idx], prot_coords[prot_idx], dists),
            dtype=torch.float)
        data['protein', 'interaction', 'ligand'].edge_index = torch.tensor(
            np.stack([prot_idx, lig_idx]), dtype=torch.long)
        data['protein', 'interaction', 'ligand'].edge_attr = torch.tensor(
            _graphdock_edge_features_batch(prot_coords[prot_idx], lig_coords[lig_idx], dists),
            dtype=torch.float)

    return data
def process_pdbbind_with_graphdock(dc_dataset, ligand_cutoff=5.0,
                                    protein_cutoff=6.0, interaction_cutoff=5.0,
                                    max_error_reports=5):
    """Convert a DeepChem PDBBind dataset (featurizer='raw') into GraphDock graphs.

    Each row of ``dc_dataset`` must hold ``(ligand_path, pocket_path)`` in ``X``.
    A graph is built per complex with ``create_graphdock_hetero_graph`` and
    its label is attached as ``graph.y`` (float tensor). Complexes that fail
    are skipped, and the reasons are kept for diagnosis instead of being silently
    discarded.

    Args:
        dc_dataset: DeepChem dataset to iterate (one complex per batch).
        ligand_cutoff, protein_cutoff, interaction_cutoff: Forwarded to
            ``create_graphdock_hetero_graph``.
        max_error_reports: How many individual failures (id + reason) to print.

    Returns:
        List of ``HeteroData`` graphs for the complexes that succeeded.
    """
    graph_data = []
    failures = []  # (complex_id, reason)

    batches = dc_dataset.iterbatches(batch_size=1, deterministic=True)
    for X, y, _w, ids in tqdm(batches, desc="Processing with GraphDock"):
        complex_id = ids[0]
        ligand_path = X[0][0]
        protein_path = X[0][1]

        try:
            hetero_data = create_graphdock_hetero_graph(
                ligand_path, protein_path, ligand_cutoff, protein_cutoff, interaction_cutoff
            )
        except Exception as exc:
            failures.append((complex_id, f"{type(exc).__name__}: {exc}"))
            continue

        if hetero_data is None:
            failures.append((complex_id, "graph builder returned None (unparseable ligand or empty pocket)"))
            continue

        hetero_data.y = torch.tensor(y, dtype=torch.float)
        graph_data.append(hetero_data)

    print(f"\nProcessed {len(graph_data)} complexes successfully, Failed: {len(failures)}")
    for complex_id, reason in failures[:max_error_reports]:
        print(f"  failed {complex_id}: {reason}")
    if len(failures) > max_error_reports:
        print(f"  ... and {len(failures) - max_error_reports} more failures")
    return graph_data

GraphDock Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GINConv, global_mean_pool, global_max_pool
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU, Dropout

class GraphDockModel(nn.Module):
    """
    Heterogeneous GIN (Graph Isomorphism Network) predicting ligand-protein
    binding affinity.

    Architecture:
    - Ligand and protein node features are embedded to ``hidden_dim`` by small
      MLPs (Linear-BN-ReLU-Dropout-Linear).
    - Each edge type's ``edge_attr`` is linearly encoded (one encoder per
      relation name; both interaction directions share ``'interaction'``),
      mean-aggregated onto the target nodes, and the per-node results are
      averaged across edge types. That edge context is concatenated to the
      node embedding before every GIN layer. Edge features are not used inside
      the GIN messages themselves.
    - ``num_layers`` ``HeteroConv`` layers (one ``GINConv`` with a 2-layer MLP
      per edge type, summed per node type), each followed by per-node-type
      BatchNorm, a residual add and ReLU.
    - Readout: mean + max pooling of ligand and protein nodes, concatenated
      (``4 * hidden_dim``) and passed through an MLP head.

    Args:
        ligand_feat_dim: Width of ``data['ligand'].x``.
        protein_feature_dim: Width of ``data['protein'].x``.
        edge_feat_dim: Width of every ``edge_attr``.
        hidden_dim: Hidden width used throughout.
        num_layers: Number of heterogeneous GIN layers.
        output_dim: Number of outputs per graph.
        dropout: Dropout probability in all MLPs.
    """

    # Edge types that get their own GIN layer, in registration order.
    _EDGE_TYPES = (
        ('ligand', 'lig_lig', 'ligand'),
        ('protein', 'prot_prot', 'protein'),
        ('ligand', 'interaction', 'protein'),
        ('protein', 'interaction', 'ligand'),
    )
    _NODE_TYPES = ('ligand', 'protein')

    def __init__(self, ligand_feat_dim=18, protein_feature_dim=16, edge_feat_dim=14,
                 hidden_dim=128, num_layers=3, output_dim=1, dropout=0.2):
        super(GraphDockModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # --- Node embeddings ---
        self.ligand_embedding = self._make_embedding(ligand_feat_dim, hidden_dim, dropout)
        self.protein_embedding = self._make_embedding(protein_feature_dim, hidden_dim, dropout)

        # --- Edge feature encoders ---
        self.edge_encoders = nn.ModuleDict({
            'lig_lig': Linear(edge_feat_dim, hidden_dim),
            'prot_prot': Linear(edge_feat_dim, hidden_dim),
            'interaction': Linear(edge_feat_dim, hidden_dim),
        })

        # --- Heterogeneous GNN layers with GIN ---
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleDict()
        for layer in range(num_layers):
            conv_dict = {
                edge_type: GINConv(self._make_gin_mlp(hidden_dim, dropout), train_eps=True)
                for edge_type in self._EDGE_TYPES
            }
            self.convs.append(HeteroConv(conv_dict, aggr='sum'))
            for node_type in self._NODE_TYPES:
                self.batch_norms[f'{node_type}_{layer}'] = BatchNorm1d(hidden_dim)

        # --- Prediction head ---
        # Input: 4 * hidden_dim (ligand_mean + ligand_max + protein_mean + protein_max)
        self.predictor = Sequential(
            Linear(hidden_dim * 4, hidden_dim * 2),
            BatchNorm1d(hidden_dim * 2),
            ReLU(),
            Dropout(dropout),
            Linear(hidden_dim * 2, hidden_dim),
            BatchNorm1d(hidden_dim),
            ReLU(),
            Dropout(dropout),
            Linear(hidden_dim, output_dim)
        )

    @staticmethod
    def _make_embedding(in_dim, hidden_dim, dropout):
        """Node-feature embedding MLP: in_dim -> hidden_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

    @staticmethod
    def _make_gin_mlp(hidden_dim, dropout):
        """GIN update MLP; input is 2 * hidden_dim (node embedding ++ edge context)."""
        return Sequential(
            Linear(hidden_dim * 2, hidden_dim * 2),
            BatchNorm1d(hidden_dim * 2),
            ReLU(),
            Dropout(dropout),
            Linear(hidden_dim * 2, hidden_dim)
        )

    @staticmethod
    def _edge_encoder_key(relation):
        """Map a relation name to its encoder; any other relation uses 'interaction'."""
        if relation in ('lig_lig', 'prot_prot'):
            return relation
        return 'interaction'

    @staticmethod
    def _batch_vector(store, x):
        """Graph-assignment vector of a node store (all zeros for an unbatched graph)."""
        if hasattr(store, 'batch'):
            return store.batch
        return torch.zeros(store.num_nodes, dtype=torch.long, device=x.device)

    def _mean_incoming_edge_features(self, data, edge_attr_dict, node_type, reference):
        """Per-node mean of encoded incoming edge features, averaged over edge types.

        ``reference`` supplies the node count and device; nodes without incoming
        edges get zeros. Returns zeros shaped like ``reference`` if no edge
        type targets ``node_type``.
        """
        num_nodes = reference.size(0)
        device = reference.device
        per_edge_type = []
        for edge_type in data.edge_types:
            if edge_type[2] != node_type or edge_type not in edge_attr_dict:
                continue
            target_nodes = data[edge_type].edge_index[1]
            edge_attr = edge_attr_dict[edge_type]

            edge_sum = torch.zeros(num_nodes, self.hidden_dim, device=device)
            # Cast so index_add_ sees matching dtypes (e.g. reduced-precision
            # edge_attr produced under autocast).
            edge_sum.index_add_(0, target_nodes, edge_attr.to(edge_sum.dtype))

            edge_count = torch.zeros(num_nodes, 1, device=device)
            edge_count.index_add_(0, target_nodes,
                                  torch.ones(target_nodes.size(0), 1, device=device))
            # clamp avoids division by zero for nodes with no incoming edges.
            per_edge_type.append(edge_sum / edge_count.clamp(min=1))

        if per_edge_type:
            return torch.stack(per_edge_type).mean(dim=0)
        return torch.zeros_like(reference)

    def forward(self, data):
        """Predict ``(num_graphs, output_dim)`` values for a (batched) HeteroData."""
        # --- Embed nodes ---
        x_dict = {
            'ligand': self.ligand_embedding(data['ligand'].x),
            'protein': self.protein_embedding(data['protein'].x),
        }

        # --- Encode edge features ---
        edge_attr_dict = {
            edge_type: self.edge_encoders[self._edge_encoder_key(edge_type[1])](
                data[edge_type].edge_attr
            )
            for edge_type in data.edge_types
            if hasattr(data[edge_type], 'edge_attr')
        }

        # The edge context depends only on the input graph, not on the layer,
        # so compute it once rather than once per GIN layer.
        edge_context = {
            node_type: self._mean_incoming_edge_features(
                data, edge_attr_dict, node_type, x_dict[node_type]
            )
            for node_type in self._NODE_TYPES
        }

        # --- Message passing with residual connections ---
        for layer_idx, conv in enumerate(self.convs):
            conv_input = {
                node_type: torch.cat([x_dict[node_type], edge_context[node_type]], dim=1)
                for node_type in self._NODE_TYPES
            }
            conv_out = conv(conv_input, data.edge_index_dict)
            normed = {
                node_type: self.batch_norms[f'{node_type}_{layer_idx}'](conv_out[node_type])
                for node_type in self._NODE_TYPES
            }
            x_dict = {
                node_type: F.relu(normed[node_type] + x_dict[node_type])
                for node_type in self._NODE_TYPES
            }

        # --- Readout: mean + max pooling per node type ---
        ligand_batch = self._batch_vector(data['ligand'], x_dict['ligand'])
        protein_batch = self._batch_vector(data['protein'], x_dict['protein'])
        combined = torch.cat([
            global_mean_pool(x_dict['ligand'], ligand_batch),
            global_max_pool(x_dict['ligand'], ligand_batch),
            global_mean_pool(x_dict['protein'], protein_batch),
            global_max_pool(x_dict['protein'], protein_batch),
        ], dim=1)

        return self.predictor(combined)

In [None]:
# Core PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler

# PyTorch Geometric
from torch_geometric.nn import global_mean_pool, global_max_pool
import os

# Enable better CUDA error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

def train_graphdock_model(model, train_loader, val_loader, num_epochs=200,
                          patience=25, lr=0.001, device='cpu',
                          save_path='best_model.pth'):
    """
    Training loop for GraphDock model using mixed precision.
    """
    print(f"Using device: {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10
    )

    scaler = GradScaler('cuda' if device == 'cuda' else 'cpu')

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}")

        # --- Training step ---
        model.train()
        train_losses = []

        for batch_idx, batch in enumerate(train_loader):
            try:
                batch = batch.to(device)

                optimizer.zero_grad()

                with autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
                    preds = model(batch).squeeze()
                    targets = batch.y.squeeze().float().to(device)

                    # Check for NaN/Inf
                    if torch.isnan(preds).any() or torch.isinf(preds).any():
                        print(f"Warning: NaN/Inf in predictions at batch {batch_idx}")
                        continue

                    loss = F.mse_loss(preds, targets)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                train_losses.append(loss.item())

            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                raise e

        if not train_losses:
            print("Warning: No valid training batches in this epoch")
            continue

        # --- Validation step ---
        model.eval()
        val_losses = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    batch = batch.to(device)

                    with autocast(device_type='cuda' if device == 'cuda' else 'cpu'):
                        preds = model(batch).squeeze()
                        targets = batch.y.squeeze().float().to(device)

                        if torch.isnan(preds).any() or torch.isinf(preds).any():
                            print(f"Warning: NaN/Inf in validation predictions at batch {batch_idx}")
                            continue

                        loss = F.mse_loss(preds, targets)
                        val_losses.append(loss.item())

                except Exception as e:
                    print(f"Error in validation batch {batch_idx}: {str(e)}")
                    continue

        if not val_losses:
            print("Warning: No valid validation batches in this epoch")
            continue

        # Compute average losses
        train_loss_avg = sum(train_losses) / len(train_losses)
        val_loss_avg = sum(val_losses) / len(val_losses)

        # Adjust learning rate
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss_avg)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != old_lr:
            print(f"Learning rate reduced to {new_lr}")

        # --- Early stopping logic ---
        if val_loss_avg < best_val_loss:
            best_val_loss = val_loss_avg
            patience_counter = 0
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
            }, save_path)

            print(f"  ✓ New best model saved (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                model.load_state_dict(best_model_state)
                break

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss_avg:.4f} - Val Loss: {val_loss_avg:.4f}")

    return model

Evaluation Function

In [None]:
def evaluate_model(model, loader, device='cpu'):
    """Score a regression model on every batch of ``loader``.

    The model is switched to eval mode and run without gradients. Outputs and
    ``batch.y`` are squeezed and flattened into per-sample arrays, kept in
    loader order.

    Returns:
        ``(mse, rmse, r2, actuals, predictions)`` where ``actuals`` and
        ``predictions`` are numpy arrays.
    """
    model.eval()
    pred_values = []
    target_values = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # atleast_1d: a single-sample batch squeezes to a 0-d array.
            batch_preds = np.atleast_1d(model(batch).squeeze().cpu().numpy())
            batch_targets = np.atleast_1d(batch.y.squeeze().cpu().numpy())
            pred_values.extend(batch_preds)
            target_values.extend(batch_targets)

    actuals = np.array(target_values)
    predictions = np.array(pred_values)

    mse = mean_squared_error(actuals, predictions)
    return mse, np.sqrt(mse), r2_score(actuals, predictions), actuals, predictions

### Load Dataset ###

In [None]:
# Load the PDBBind "refined" set as raw file paths (ligand file, pocket PDB)
# with a random train/valid/test split.
# NOTE(review): `transformers` is not passed, so DeepChem's default label
# transforms for PDBBind apply (per its docs, y-normalisation; the example label
# of ~0.06 printed below is consistent with that). Metrics on `y` are then in
# transformed units — TODO confirm.
PDBBIND_OPTIONS = dict(
    featurizer='raw',
    set_name='refined',
    splitter='random',
    reload=True,
)
tasks, datasets, transformers = dc.molnet.load_pdbbind(**PDBBIND_OPTIONS)
(train_dataset, valid_dataset, test_dataset) = datasets
print(f"Train dataset length: {len(train_dataset)}\nValidate dataset length: {len(valid_dataset)}\nTest dataset length: {len(test_dataset)}")


Train dataset length: 3881
Validate dataset length: 485
Test dataset length: 486


In [None]:
# Peek at the first training example: each row of X is (ligand file, pocket PDB).
first_example = next(iter(train_dataset.iterbatches(batch_size=1, deterministic=True)), None)
if first_example is not None:
    X, y, w, ids = first_example
    print("Features (X):", X)
    print("Protein pocket:",X[0][1])
    print("Label (y):", y)
    print("Weight (w):", w)
    print("ID:", ids)

Features (X): [['/tmp/refined-set/4zeb/4zeb_ligand.sdf'
  '/tmp/refined-set/4zeb/4zeb_pocket.pdb']]
Protein pocket: /tmp/refined-set/4zeb/4zeb_pocket.pdb
Label (y): [0.06175972]
Weight (w): [1.]
ID: ['4zeb']


In [None]:
from torch.serialization import add_safe_globals
from torch_geometric.data import Batch, Data, HeteroData
from torch_geometric.data.storage import BaseStorage, EdgeStorage, GlobalStorage, NodeStorage

# torch.load may default to weights_only=True (PyTorch >= 2.6), which only
# unpickles allow-listed classes; register the PyG containers used by saved graphs.
add_safe_globals([BaseStorage, NodeStorage, EdgeStorage, GlobalStorage, Data, HeteroData])


class GraphDataset(Dataset):
    """PyG dataset of pre-built GraphDock graphs stored one-per-file.

    Expects ``<graph_dir>/<dataset_name>_metadata.pkl`` to be a pickled dict
    with a ``'successful_ids'`` sequence. Item ``i`` is
    ``torch.load(<graph_dir>/<successful_ids[i]>.pt)``.

    Implements PyG's ``len()`` / ``get()`` API, so indexing, slicing and
    ``transform`` are handled by the base ``Dataset``.
    """

    def __init__(self, graph_dir, dataset_name='train'):
        # No root: nothing to download or process; this only initialises the
        # base-class state (transform, index bookkeeping).
        super().__init__()
        self.graph_dir = graph_dir
        metadata_path = os.path.join(graph_dir, f"{dataset_name}_metadata.pkl")
        # SECURITY: pickle.load can execute arbitrary code; only load trusted metadata.
        with open(metadata_path, 'rb') as f:
            metadata = pickle.load(f)
        self.complex_ids = list(metadata['successful_ids'])

    def len(self):
        """Number of stored complexes."""
        return len(self.complex_ids)

    def get(self, idx):
        """Load the graph of the ``idx``-th complex from disk."""
        complex_id = self.complex_ids[idx]
        graph_path = os.path.join(self.graph_dir, f"{complex_id}.pt")
        return torch.load(graph_path)


def collate_fn(batch):
    """Merge a list of (Hetero)Data graphs into a single PyG ``Batch``."""
    return Batch.from_data_list(batch)


In [None]:
# Graph-construction distance cutoffs (same units as the structure coordinates),
# shared by all three splits so they cannot drift apart.
LIGAND_CUTOFF = 5.0
PROTEIN_CUTOFF = 6.0
INTERACTION_CUTOFF = 5.0

graphs_by_split = {}
for split_name, split_dataset in (
    ("training", train_dataset),
    ("validation", valid_dataset),
    ("test", test_dataset),
):
    print(f"Processing {split_name} set...")
    graphs_by_split[split_name] = process_pdbbind_with_graphdock(
        split_dataset,
        ligand_cutoff=LIGAND_CUTOFF,
        protein_cutoff=PROTEIN_CUTOFF,
        interaction_cutoff=INTERACTION_CUTOFF,
    )

train_graphs = graphs_by_split["training"]
valid_graphs = graphs_by_split["validation"]
test_graphs = graphs_by_split["test"]

# Can be smaller than the splits: complexes that fail to featurize are dropped.
print(f"Train graphs have size of {len(train_graphs)}")
print(f"Valid graphs have size of {len(valid_graphs)}")
print(f"Test graphs have size of {len(test_graphs)}")



Processing with GraphDock: 3881it [12:00,  5.39it/s]



Processed 2212 complexes successfully, Failed: 1669
Processing validation set...


Processing with GraphDock: 485it [01:38,  4.93it/s]



Processed 284 complexes successfully, Failed: 201
Processing test set...


Processing with GraphDock: 486it [01:36,  5.01it/s]


Processed 291 complexes successfully, Failed: 195
Train graphs have size of 2212
Valid graphs have size of 284
Test graphs have size of 291





In [None]:
BATCH_SIZE = 16
# Seeded generator: the training shuffle order is reproducible across runs.
SHUFFLE_SEED = 42

train_loader = DataLoader(
    train_graphs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
valid_loader = DataLoader(valid_graphs, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Pull one training batch to sanity-check node/edge tensor shapes.
data_iter = iter(train_loader)
batch = next(data_iter)

# Inspect batch type
print("Batch type:", type(batch))

# Inspect ligand nodes
print("Ligand node features shape:", batch['ligand'].x.shape)
print("First 5 ligand node features:\n", batch['ligand'].x[:5])

# Inspect protein nodes
print("Protein node features shape:", batch['protein'].x.shape)
print("First 5 protein node features:\n", batch['protein'].x[:5])

# Inspect edges. A missing edge_attr raises AttributeError on a PyG
# EdgeStorage rather than returning None, so use getattr with a default.
print("Edge types:", batch.edge_index_dict.keys())
for etype, edge_index in batch.edge_index_dict.items():
    print(f"{etype} edge_index shape:", edge_index.shape)
    edge_attr = getattr(batch[etype], 'edge_attr', None)
    if edge_attr is not None:
        print(f"{etype} edge_attr shape:", edge_attr.shape)

# Optional: look at batch indices
if hasattr(batch['ligand'], 'batch'):
    print("Ligand batch indices:", batch['ligand'].batch[:10])
if hasattr(batch['protein'], 'batch'):
    print("Protein batch indices:", batch['protein'].batch[:10])


Batch type: <class 'abc.HeteroDataBatch'>
Ligand node features shape: torch.Size([809, 18])
First 5 ligand node features:
 tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.3333, 0.0000, 0.6667, 0.0000, 0.0000, 0.1600, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.6667, 0.0000, 0.6667, 0.0000, 0.0000, 0.1201, 1.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.6667, 0.0000, 0.6667, 0.0000, 0.0000, 0.1201, 1.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.6667, 1.0000, 0.6667, 0.0000, 0.0000, 0.1401, 1.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.6667, 0.0000, 0.6667, 0.0000, 0.0000, 0.1201, 1.0000]])
Protein node features shape: torch.Size([7703, 16])
First 5 p

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}.")

# Seed the global torch RNG so weight initialisation is reproducible.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# --- Model ---
model = GraphDockModel(
    ligand_feat_dim=18,
    protein_feature_dim=16,
    edge_feat_dim=14,
    hidden_dim=128,
    num_layers=3,
    output_dim=1,
    dropout=0.2
).to(device)

Using cuda.


In [None]:
# --- Training ---
# Short run: at most 50 epochs, stopping after 10 epochs without val improvement.
NUM_EPOCHS = 50
PATIENCE = 10
LEARNING_RATE = 0.001
CHECKPOINT_PATH = 'graphdock_bestmodel.pth'

trained_model = train_graphdock_model(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
    lr=LEARNING_RATE,
    device=device,
    save_path=CHECKPOINT_PATH,
)

Using device: cuda
Starting epoch 1




  ✓ New best model saved (val_loss: 0.6497)
Epoch 1/50 - Train Loss: 0.8057 - Val Loss: 0.6497
Starting epoch 2
Epoch 2/50 - Train Loss: 0.6609 - Val Loss: 0.7550
Starting epoch 3
Epoch 3/50 - Train Loss: 0.6193 - Val Loss: 0.8508
Starting epoch 4
  ✓ New best model saved (val_loss: 0.6231)
Epoch 4/50 - Train Loss: 0.5768 - Val Loss: 0.6231
Starting epoch 5
  ✓ New best model saved (val_loss: 0.5949)
Epoch 5/50 - Train Loss: 0.5488 - Val Loss: 0.5949
Starting epoch 6
  ✓ New best model saved (val_loss: 0.5654)
Epoch 6/50 - Train Loss: 0.5084 - Val Loss: 0.5654
Starting epoch 7
Epoch 7/50 - Train Loss: 0.4875 - Val Loss: 0.7560
Starting epoch 8
  ✓ New best model saved (val_loss: 0.5474)
Epoch 8/50 - Train Loss: 0.4663 - Val Loss: 0.5474
Starting epoch 9
  ✓ New best model saved (val_loss: 0.5363)
Epoch 9/50 - Train Loss: 0.4610 - Val Loss: 0.5363
Starting epoch 10
Epoch 10/50 - Train Loss: 0.4312 - Val Loss: 0.7729
Starting epoch 11
Epoch 11/50 - Train Loss: 0.4395 - Val Loss: 0.5697
S

Model Evaluation

In [None]:
# Evaluate on validation set (metrics on the labels as stored in the DeepChem datasets)
val_mse, val_rmse, val_r2, val_actuals, val_preds = evaluate_model(
    trained_model, valid_loader, device=device
)

print(f"Validation MSE: {val_mse:.4f}")
print(f"Validation RMSE: {val_rmse:.4f}")
print(f"Validation R²: {val_r2:.4f}")

# Evaluate on test set
test_mse, test_rmse, test_r2, test_actuals, test_preds = evaluate_model(
    trained_model, test_loader, device=device
)

print(f"Test MSE: {test_mse:.4f}")
print(f"Test RMSE: {test_rmse:.4f}")
print(f"Test R²: {test_r2:.4f}")

# `dc.molnet.load_pdbbind` returned `transformers`. If any of them transform y
# (e.g. normalisation), the MSE/RMSE above are in transformed units. Undo them
# to also report RMSE in the original label units; undo_transforms is a no-op
# when no transformer touches y. R² is unchanged by a purely affine
# (normalisation) transform, so it is not repeated.
for split_name, split_actuals, split_preds in (
    ("Validation", val_actuals, val_preds),
    ("Test", test_actuals, test_preds),
):
    y_true = dc.trans.undo_transforms(split_actuals.reshape(-1, 1), transformers).ravel()
    y_pred = dc.trans.undo_transforms(split_preds.reshape(-1, 1), transformers).ravel()
    rmse_original = np.sqrt(mean_squared_error(y_true, y_pred))
    print(f"{split_name} RMSE (original label units): {rmse_original:.4f}")


Validation MSE: 0.5140
Validation RMSE: 0.7169
Validation R²: 0.4709
Test MSE: 0.4252
Test RMSE: 0.6521
Test R²: 0.5744
