# GNN-Based Pipeline for Predicting Small Molecule Inhibition of Protein Complexes - Detailed Explanation

This notebook provides a comprehensive explanation of the Graph Neural Network (GNN) based pipeline for predicting small molecule inhibition of protein-protein interaction (PPI) complexes.

## Table of Contents
1. [Overview](#overview)
2. [Setup and Installation](#setup)
3. [Data Balancing Utilities](#data-balancing)
4. [Protein Structure Processing](#protein-processing)
5. [GNN Model Architecture](#gnn-architecture)
6. [Training Pipeline](#training)
7. [Evaluation and Results](#evaluation)

## 1. Overview {#overview}

### What This Pipeline Does

This pipeline predicts whether a small molecule can inhibit protein-protein interactions (PPIs). The key components are:

- **Input**: 3D protein structures (PDB files) and small molecule SMILES strings
- **Process**: Convert proteins to graphs, extract structural features with a GNN, and combine them with compound fingerprints
- **Output**: Binary prediction (inhibitor or not) with confidence scores

### Key Technical Approach

1. **Graph Representation**: Proteins are represented as graphs where:
   - Nodes = atoms
   - Edges = spatial proximity (within 6Å)
   - Two types of edges: same-residue and different-residue neighbors

2. **Feature Extraction**:
   - Atom type (C, N, O, etc.) - one-hot encoded
   - Residue type (ALA, ARG, etc.) - one-hot encoded
   - Neighborhood information (up to 10 neighbors per atom for each of the two edge types)

3. **Model Architecture**:
   - GNN for protein structure encoding
   - MLP for final prediction combining protein features, interface features, and compound fingerprints

4. **Training Strategy**:
   - Leave-one-complex-out cross-validation (LOCO-CV)
   - Balanced batch sampling to handle class imbalance
   - Class-weighted loss function
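Leave-one-complex-out cross-validation holds out all samples that belong to one PPI complex at a time, so every test fold contains only a complex the model never saw during training. The sketch below shows one way to generate such folds in plain Python, assuming each sample is tagged with the ID of its complex. The helper `loco_splits` and the toy IDs are hypothetical and are not taken from the repository.

```python
from collections import defaultdict

def loco_splits(complex_ids):
    """Yield (held_out_complex, train_indices, test_indices), one fold per complex.

    Hypothetical helper: complex_ids[i] is the complex that sample i belongs to.
    """
    groups = defaultdict(list)
    for i, cid in enumerate(complex_ids):
        groups[cid].append(i)

    for held_out in sorted(groups):
        test_idx = groups[held_out]
        # Training set: every sample whose complex differs from the held-out one
        train_idx = [i for i, cid in enumerate(complex_ids) if cid != held_out]
        yield held_out, train_idx, test_idx

# Toy example: six (complex, compound) samples drawn from three complexes
complex_ids = ["A", "A", "B", "C", "C", "C"]
folds = list(loco_splits(complex_ids))
for held_out, train_idx, test_idx in folds:
    print(held_out, train_idx, test_idx)
# A [2, 3, 4, 5] [0, 1]
# B [0, 1, 3, 4, 5] [2]
# C [0, 1, 2] [3, 4, 5]
```

scikit-learn's `LeaveOneGroupOut` yields the same train/test partitions when the complex IDs are passed as `groups`.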

## 2. Setup and Installation {#setup}

### Environment Configuration

The pipeline is designed to run on a GPU. The code falls back to the CPU when CUDA is unavailable, but training is then much slower. When running on Google Colab:
1. Go to **Runtime → Change Runtime Type**
2. Set **Hardware accelerator** to **GPU**

### Required Libraries

- **BioPython**: For parsing PDB files and extracting protein structures
- **RDKit**: For processing molecular structures and generating fingerprints
- **PyTorch**: Deep learning framework for GNN implementation
- **scikit-learn**: For data preprocessing and evaluation metrics

In [None]:
# Mount Google Drive (for Colab)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install biopython  # For protein structure parsing
!pip install rdkit      # For molecular fingerprints

# Clone the repository containing data and features
!git clone https://github.com/adibayaseen/PPI-Inhibitors

## 3. Data Balancing Utilities {#data-balancing}

### The Imbalanced Data Problem

In drug discovery, we typically have many more non-inhibitors (negatives) than inhibitors (positives). This creates challenges:
- Models can achieve high accuracy by simply predicting everything as negative
- True positive examples get "drowned out" during training
- The model fails to learn the distinguishing features of inhibitors

### Solution: Balanced Sampling

This pipeline implements two strategies for handling imbalanced data:

1. **BalancedDataset + WeightedRandomSampler**: Assigns higher sampling probability to minority class
2. **BinaryBalancedSampler**: Ensures each batch has exactly 50% positive and 50% negative examples

### Class 1: BalancedDataset

**Purpose**: Creates a dataset with sample weights inversely proportional to class frequency.

**How It Works**:
1. Counts examples in each class using `np.bincount()`
2. Calculates weights as `1 / class_count`
3. Assigns each sample a weight based on its class

**Example**: 
- If you have 100 negatives and 10 positives:
  - Negative weight = 1/100 = 0.01
  - Positive weight = 1/10 = 0.1
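  - Each positive is 10x more likely to be drawn than each negative, so both classes carry the same total weight (100 × 0.01 = 10 × 0.1 = 1) and each class supplies about half of the draws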

**Note**: Uses stochastic sampling, so some examples might never be selected in an epoch.
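The weight arithmetic from the example can be checked directly in NumPy. This sketch uses the same toy counts (100 negatives, 10 positives) and the same `np.bincount` / inverse-count steps as `BalancedDataset`, then computes the probability that one weighted draw lands in each class.

```python
import numpy as np

# Toy labels matching the example: 100 negatives (0) and 10 positives (1)
labels = np.array([0] * 100 + [1] * 10)

class_counts = np.bincount(labels)        # [100, 10]
class_weights = 1.0 / class_counts        # [0.01, 0.1]
sample_weights = class_weights[labels]    # one weight per sample

# Probability that a single weighted draw lands in each class
total = sample_weights.sum()
p_negative = sample_weights[labels == 0].sum() / total
p_positive = sample_weights[labels == 1].sum() / total
print(round(p_negative, 6), round(p_positive, 6))  # 0.5 0.5
```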

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np

class BalancedDataset(Dataset):
    """
    A custom dataset class that creates a balanced dataset from imbalanced data.
    This dataset calculates sample weights inversely proportional to class frequencies,
    which can be used with a WeightedRandomSampler to achieve balanced batches.

    NOTE: As it involves stochastic sampling, there is a chance that a few training 
    examples are actually never selected.

    Attributes:
        data: The input data (list, NumPy array, or PyTorch tensor)
        labels: The labels corresponding to the data (1D array-like object)
        sample_weights: Weights for each sample, inversely proportional to class frequencies
    """
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

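        # np.bincount needs non-negative integer labels
        labels_arr = np.asarray(labels, dtype=np.int64)

        # Count the number of examples in each class
        class_counts = np.bincount(labels_arr)
        
        # Assign weight inversely proportional to class frequency
        # Rare classes get higher weights
        weights = 1. / torch.tensor(class_counts, dtype=torch.float)
        
        # Look up each sample's weight from its class; indexing with a LongTensor
        # works whether labels was given as a list, NumPy array, or tensor
        self.sample_weights = weights[torch.from_numpy(labels_arr)]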

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

### Function: create_balanced_loader

**Purpose**: Convenience function to create a DataLoader with automatic balancing.

**Parameters**:
- `data`: Input features
- `labels`: Binary labels (0 or 1)
- `batch_size`: Number of samples per batch (default: 32)

**Returns**: A DataLoader that yields balanced batches

**Usage**: Ideal for quick prototyping when you want automatic balancing without manual setup.

In [None]:
def create_balanced_loader(data, labels, batch_size=32):
    """
    Creates a DataLoader with balanced batches for a given dataset.
    This function is useful for training models on imbalanced datasets.

    Args:
        data: The input data (list, NumPy array, or PyTorch tensor)
        labels: The labels corresponding to the data (1D array-like object)
        batch_size: The size of each batch (default: 32)

    Returns:
        DataLoader: A PyTorch DataLoader that yields balanced batches

    Usage Example:
        >>> data = [features1, features2, ...]  # Your data features
        >>> labels = [label1, label2, ...]     # Your data labels
        >>> balanced_loader = create_balanced_loader(data, labels, batch_size=32)
        >>> for batch_data, batch_labels in balanced_loader:
        ...     pass  # Train your model using the balanced batches
    """
    dataset = BalancedDataset(data, labels)
    
    # WeightedRandomSampler will take care of the balancing
    # replacement=True allows samples to be selected multiple times
    sampler = WeightedRandomSampler(
        weights=dataset.sample_weights, 
        num_samples=len(dataset.sample_weights), 
        replacement=True
    )

    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader
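A small experiment with `WeightedRandomSampler` on synthetic labels illustrates two effects of this loader: draws come out roughly 50/50, and because sampling is with replacement, many majority-class examples are never drawn in a given epoch. The seed and class sizes below are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)

# Synthetic imbalanced labels: 1000 negatives, 100 positives
labels = torch.tensor([0] * 1000 + [1] * 100)
class_weights = 1.0 / torch.bincount(labels).float()

sampler = WeightedRandomSampler(
    weights=class_weights[labels],   # per-sample weights, as in BalancedDataset
    num_samples=len(labels),         # one "epoch" worth of draws
    replacement=True,
)
drawn = torch.tensor(list(sampler))  # indices drawn in one pass

pos_fraction = labels[drawn].float().mean().item()
unseen_negatives = 1000 - torch.unique(drawn[labels[drawn] == 0]).numel()
print(f"positive fraction: {pos_fraction:.2f}")
print(f"negatives never drawn this epoch: {unseen_negatives}")
```

With about 550 negative draws per epoch, each negative is missed with probability roughly (1 − 1/1000)^550 ≈ 0.58, so more than half of the negatives are typically skipped in any single epoch.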

### Class 2: BinaryBalancedSampler

**Purpose**: Ensures each batch has equal numbers of positive and negative examples (up to rounding when the counts do not divide evenly into batches).

**How It Works**:
1. Identifies majority and minority classes
2. Oversamples minority class to match majority class size
3. Uses StratifiedKFold to split into batches while maintaining 50-50 ratio

**Key Features**:
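- Guarantees a 50/50 class balance in every batch (exact when the class counts divide evenly into the batches, otherwise off by at most one example per class)
- Uses stratified splitting to maintain the ratio
- May yield more examples per pass than the dataset contains (due to oversampling)
- The minority oversampling is drawn once in `__init__`; every pass over the DataLoader reuses that index set and only reshuffles how it is split into batches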

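**Equivalent Epochs**: The number of samples drawn in one full pass over the DataLoader divided by the dataset size (`len(self.idx) / len(self.class_vector)`).
- If you have 100 positives and 1000 negatives, the positives are oversampled to 1000, so one pass draws 2000 samples and equivalent_epochs = 2000 / 1100 ≈ 1.82
- In that pass each negative is seen once, while each positive is drawn about 10 times on average (1000 / 100)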
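The quantities set up in `BinaryBalancedSampler.__init__` can be reproduced in a few lines of NumPy. This sketch follows the same construction (find the majority class, oversample the minority to match, count batches) but uses a seeded `np.random.default_rng` generator instead of `np.random.choice` so the result is repeatable. It also computes how often each minority example is drawn per pass, which is a different quantity from `equivalent_epochs`.

```python
import numpy as np

rng = np.random.default_rng(0)
y = np.array([1] * 100 + [0] * 1000)   # 100 positives, 1000 negatives

# Same construction as BinaryBalancedSampler.__init__
values, counts = np.unique(y, return_counts=True)
majority = values[np.argmax(counts)]
maj_idx = np.nonzero(y == majority)[0]
min_idx = np.nonzero(y != majority)[0]
min_idx_over = rng.choice(min_idx, size=len(maj_idx))   # with replacement
idx = np.concatenate([maj_idx, min_idx_over])

batch_size = 10
n_splits = int(np.ceil(len(idx) / batch_size))   # number of batches
equivalent_epochs = len(idx) / len(y)             # samples per pass / dataset size
draws_per_positive = len(min_idx_over) / len(min_idx)

print(n_splits, round(equivalent_epochs, 2), draws_per_positive)  # 200 1.82 10.0
```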

In [None]:
from torch.utils.data import Sampler
from sklearn.model_selection import StratifiedKFold

class BinaryBalancedSampler(Sampler):
    """
    A PyTorch Sampler that returns batches with an equal number of positive 
    and negative examples. The sampler oversamples from the minority class to 
    balance the majority class, ensuring that each batch contains 50% positive 
    and 50% negative examples.

    NOTE: A single iteration through the DataLoader yields more examples than
    the dataset contains, because the minority class is oversampled.

    Attributes:
        class_vector: List or numpy array of class labels
        batch_size: The size of each batch
        n_splits: The number of batches/splits in the dataset
        equivalent_epochs: Number of samples drawn in one complete iteration of
                          the DataLoader divided by the dataset size
                          (len(self.idx) / len(self.class_vector))
    """
    def __init__(self, class_vector, batch_size=10):
        self.batch_size = batch_size
        self.class_vector = class_vector
        
        YY = np.array(self.class_vector)
        
        # Find majority class (class with most examples)
        U, C = np.unique(YY, return_counts=True)
        M = U[np.argmax(C)]  # Majority class label
        
        # Get indices of majority and minority classes
        Midx = np.nonzero(YY == M)[0]  # Majority indices
        midx = np.nonzero(YY != M)[0]  # Minority indices
        
        # Oversample minority indices to match majority size
        midx_ = np.random.choice(midx, size=len(Midx))
        
        # Combine majority and oversampled minority
        self.YY = np.array(list(YY[Midx]) + list(YY[midx_]))
        self.idx = np.array(list(Midx) + list(midx_))
        
        # Calculate number of batches
        self.n_splits = int(np.ceil(len(self.idx) / self.batch_size))
        
        # Calculate equivalent epochs
        self.equivalent_epochs = len(self.idx) / len(self.class_vector)
        print(f'Equivalent epochs in one iteration of data loader: {self.equivalent_epochs}')

    def gen_sample_array(self):
        """
        Generates batch indices using StratifiedKFold to maintain class balance.
        Yields indices for each batch.
        """
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True)
        for tridx, ttidx in skf.split(self.idx, self.YY):
            yield np.array(self.idx[ttidx])

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return self.n_splits

### Class 3: CustomDataset

**Purpose**: Basic PyTorch Dataset wrapper for data and labels.

**Usage**: Simple container for pairing features with labels. Used with BinaryBalancedSampler.

In [None]:
class CustomDataset(Dataset):
    """
    Simple dataset wrapper that pairs data with labels.
    """
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

## 4. Protein Structure Processing {#protein-processing}

### Overview

To use proteins in a GNN, we need to convert 3D structures (PDB files) into graph representations. This involves:

1. **Node Features**: Information about each atom
2. **Edge Information**: Which atoms are connected (neighbors)

### The Three Core Functions

1. **atom1()**: Encodes atom types
2. **res1()**: Encodes residue types
3. **neigh1()**: Computes neighborhood structure

### Function 1: atom1(structure)

**Purpose**: One-hot encode atom types in the protein structure.

**How It Works**:
1. Defines vocabulary of atom types: C, CA, CB, CG, CH2, N, NH2, OG, OH, O1, O2, SE
2. Unknown atoms are mapped to "1" (catch-all category)
3. Uses sklearn's OneHotEncoder for encoding

**Input**: BioPython Structure object

**Output**: NumPy array of shape (N_atoms, 13) where:
- N_atoms = total number of atoms in structure
- 13 = number of atom types (12 known + 1 unknown)

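**Example**: The vocabulary is sorted, and '1' sorts before letters, so the category order is 1, C, CA, CB, and so on. A Cα (alpha carbon, `CA`) atom is therefore encoded as [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]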
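The category order is easy to verify in plain Python. scikit-learn's `OneHotEncoder` (with the default `categories='auto'`) stores the sorted unique values it was fitted on, so the columns follow `sorted()` order. The helper `one_hot_atom` below is a hypothetical re-implementation for illustration only. It also shows that names missing from the vocabulary, such as the backbone carbonyl oxygen `O`, fall into the `'1'` bucket.

```python
# Same vocabulary as atom1(); OneHotEncoder orders categories as sorted unique values
atom_vocab = sorted(['C', 'CA', 'CB', 'CG', 'CH2', 'N', 'NH2', 'OG', 'OH',
                     'O1', 'O2', 'SE', '1'])
print(atom_vocab)
# ['1', 'C', 'CA', 'CB', 'CG', 'CH2', 'N', 'NH2', 'O1', 'O2', 'OG', 'OH', 'SE']

def one_hot_atom(name):
    """Hypothetical helper: one-hot encode an atom name, unknown names -> '1'."""
    key = name if name in atom_vocab else '1'
    return [1 if v == key else 0 for v in atom_vocab]

print(one_hot_atom('CA'))  # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(one_hot_atom('O'))   # [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```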

In [None]:
from Bio.PDB import PDBParser
from sklearn.preprocessing import OneHotEncoder
import warnings

def atom1(structure):
    """
    One-hot encodes atom types in a protein structure.
    
    Args:
        structure: BioPython Structure object from PDB file
    
    Returns:
        atoms_onehot: NumPy array of shape (N_atoms, 13) with one-hot encoded atom types
    """
    # Define vocabulary of atom types
    atomslist = np.array(sorted(np.array([
        'C', 'CA', 'CB', 'CG', 'CH2', 'N', 'NH2', 'OG', 'OH', 'O1', 'O2', 'SE', '1'
    ]))).reshape(-1, 1)
    
    # Initialize and fit encoder
    enc = OneHotEncoder(handle_unknown='ignore')
    enc.fit(atomslist)
    
    # Extract atom names from structure
    atom_list = []
    for atom in structure.get_atoms():
        if atom.get_name() in atomslist:
            atom_list.append(atom.get_name())
        else:
            # Unknown atoms mapped to '1'
            atom_list.append("1")
    
    # One-hot encode
    atoms_onehot = enc.transform(np.array(atom_list).reshape(-1, 1)).toarray()
    
    return atoms_onehot

### Function 2: res1(structure)

**Purpose**: One-hot encode residue types for each atom in the protein structure.

**How It Works**:
1. Defines vocabulary of 20 standard amino acids plus unknown
2. For each atom, gets its parent residue
3. Encodes the residue type

**Input**: BioPython Structure object

**Output**: NumPy array of shape (N_atoms, 21) where:
- N_atoms = total number of atoms
- 21 = number of residue types (20 amino acids + 1 unknown)

**Why Encode at Atom Level?**: 
- The GNN operates on atoms as nodes
- Each atom needs to know what residue it belongs to
- This provides chemical context for the atom

**Example**: All atoms in an alanine residue will have the same residue encoding
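The atom-level expansion can be sketched without Biopython. The hypothetical helper below takes `(resname, n_atoms)` pairs and emits one residue row per atom, using the same sorted vocabulary as `res1()`. Because `structure.get_atoms()` also yields atoms of hetero residues present in the file (for example waters or bound ligands), those atoms become nodes too and are assigned the `'1'` category, as the toy water molecule shows.

```python
# Same residue vocabulary as res1() (sorted, with '1' as the unknown bucket)
res_vocab = sorted(['ALA', 'ARG', 'ASN', 'ASP', 'GLN', 'GLU', 'GLY', 'ILE', 'LEU',
                    'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
                    'CYS', 'HIS', '1'])

def residue_rows_per_atom(residues):
    """Hypothetical helper: residues is a list of (resname, n_atoms) pairs;
    returns one one-hot residue row per atom."""
    rows = []
    for resname, n_atoms in residues:
        key = resname if resname in res_vocab else '1'
        row = [1 if v == key else 0 for v in res_vocab]
        rows.extend([row] * n_atoms)   # every atom of the residue gets the same row
    return rows

# Toy structure: alanine (5 heavy atoms), glycine (4), and one water oxygen
rows = residue_rows_per_atom([('ALA', 5), ('GLY', 4), ('HOH', 1)])
print(len(rows))  # 10
```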

In [None]:
def res1(structure):
    """
    One-hot encodes residue information for each atom.
    Each atom is labeled with its parent residue type.
    
    Args:
        structure: BioPython Structure object from PDB file
    
    Returns:
        res_onehot: NumPy array of shape (N_atoms, 21) with one-hot encoded residue types
    """
    # Define vocabulary of residue types (20 amino acids + unknown)
    residuelist = np.array(sorted(np.array([
        'ALA', 'ARG', 'ASN', 'ASP', 'GLN', 'GLU', 'GLY', 'ILE', 'LEU', 'LYS', 
        'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'CYS', 'HIS', '1'
    ]))).reshape(-1, 1)
    
    # Initialize and fit encoder
    encr = OneHotEncoder(handle_unknown='ignore')
    encr.fit(residuelist)
    
    # Extract residue name for each atom
    residue_list = []
    for atom in structure.get_atoms():
        if atom.get_parent().get_resname() in residuelist:
            residue_list.append((atom.get_parent()).get_resname())
        else:
            # Unknown residues mapped to '1'
            residue_list.append("1")

    # One-hot encode
    res_onehot = encr.transform(np.array(residue_list).reshape(-1, 1)).toarray()

    return res_onehot

### Function 3: neigh1(structure)

**Purpose**: Compute neighborhood structure for the protein graph.

**Key Concepts**:
- **Spatial Neighbors**: Atoms within 6Å distance are considered neighbors
- **Two Types of Edges**:
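  1. **Same-residue neighbors**: Atoms within 6Å that belong to the same residue (intra-residue geometry, including but not limited to covalent bonds)
  2. **Different-residue neighbors**: Atoms within 6Å that belong to different residues (mostly non-covalent contacts, but also the peptide bond between consecutive residues)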
- **Limited to 10 neighbors**: Each type is capped at 10 to control graph density

**Why Distinguish Edge Types?**
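- Same-residue edges capture local, intra-residue structure (the 6Å cutoff includes covalent bonds but also non-bonded atoms of the same residue)
- Different-residue edges capture longer-range contacts between residues, including contacts across a protein-protein interface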
- The GNN learns different weight matrices for each type

**Algorithm Steps**:
1. Use BioPython's NeighborSearch to find all atom pairs within 6Å
2. Sort neighbors by distance (closest first)
3. For each atom, collect up to 10 same-residue neighbors
4. For each atom, collect up to 10 different-residue neighbors
5. Use -1 to indicate "no neighbor" if atom has < 10 neighbors

**Output**: Two NumPy arrays of shape (N_atoms, 10):
- `neigh_same_res`: Indices of same-residue neighbors
- `neigh_diff_res`: Indices of different-residue neighbors

**Example**: 
```
neigh_same_res[42] = [41, 43, 40, 44, -1, -1, -1, -1, -1, -1]
# Atom 42 has 4 same-residue neighbors (atoms 41, 43, 40, 44)
```
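The same neighbor-table logic can be expressed on raw coordinates with NumPy alone. This is a brute-force O(N²) sketch for small toy inputs, whereas `NeighborSearch` uses a KD-tree. The function name is hypothetical, residue identity is assumed to be a `(chain, residue_number)` tuple so that atoms in different chains are never treated as the same residue, and the `<=` cutoff comparison is also an assumption.

```python
import numpy as np

def build_neighbor_tables(coords, residue_ids, cutoff=6.0, k=10):
    """Hypothetical brute-force version of the neigh1() logic.

    coords: (N, 3) atom positions in Å
    residue_ids: length-N list of hashable residue keys, e.g. (chain, resnum)
    Returns (same, diff): (N, k) int arrays of neighbor indices, padded with -1.
    """
    coords = np.asarray(coords, dtype=float)
    n = len(coords)

    # All unordered pairs (i < j) within the cutoff
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    i_idx, j_idx = np.triu_indices(n, k=1)
    keep = dists[i_idx, j_idx] <= cutoff
    i_idx, j_idx = i_idx[keep], j_idx[keep]

    # Closest pairs first, so the k slots fill with the nearest neighbors
    order = np.argsort(dists[i_idx, j_idx], kind="stable")

    same = np.full((n, k), -1, dtype=np.int64)
    diff = np.full((n, k), -1, dtype=np.int64)
    same_count = np.zeros(n, dtype=np.int64)
    diff_count = np.zeros(n, dtype=np.int64)

    for a, b in zip(i_idx[order], j_idx[order]):
        if residue_ids[a] == residue_ids[b]:
            table, count = same, same_count
        else:
            table, count = diff, diff_count
        # Record the edge in both directions, capped at k per atom
        for src, dst in ((a, b), (b, a)):
            if count[src] < k:
                table[src, count[src]] = dst
                count[src] += 1
    return same, diff

# Toy chain: atoms 0 and 1 share a residue, atom 2 is in the next residue,
# atom 3 is more than 6 Å from everything and gets no neighbors
coords = [[0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0]]
residue_ids = [("A", 1), ("A", 1), ("A", 2), ("A", 3)]
same, diff = build_neighbor_tables(coords, residue_ids)
print(same[0, :2].tolist(), diff[2, :3].tolist())  # [1, -1] [1, 0, -1]
```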

In [None]:
from Bio.PDB.NeighborSearch import NeighborSearch

def neigh1(structure):
    """
    Calculates the neighbors of each atom with a 6Å distance cutoff.
    Distinguishes between same-residue and different-residue neighbors.
    Limits each atom to 10 neighbors of each type.
    
    Args:
        structure: BioPython Structure object from PDB file
    
    Returns:
        neigh_same_res: NumPy array of shape (N_atoms, 10) with indices of same-residue neighbors
        neigh_diff_res: NumPy array of shape (N_atoms, 10) with indices of different-residue neighbors
        
    Note: -1 indicates no neighbor in that position
    """
    # Get all atoms; NeighborSearch returns these same objects in its pairs
    atom_list = list(structure.get_atoms())
    total_atoms = len(atom_list)

    # Map each atom object to its row index. Using object identity (instead of
    # serial/residue numbers) keeps atoms from different chains distinct and
    # avoids an O(N) search for every pair
    atom_index = {id(atom): i for i, atom in enumerate(atom_list)}

    # Find all atom pairs within 6Å and sort them by distance (ascending);
    # Atom.__sub__ returns the Euclidean distance between two atoms
    neighbour_list = NeighborSearch(atom_list).search_all(6, level="A")
    neighbour_list = sorted(neighbour_list, key=lambda pair: pair[0] - pair[1])

    # Initialize neighbor arrays with -1 (no neighbor)
    neigh_same_res = np.full((total_atoms, 10), -1, dtype=np.int64)
    neigh_diff_res = np.full((total_atoms, 10), -1, dtype=np.int64)

    # Counters for number of neighbors added per atom
    same_flag = np.zeros(total_atoms, dtype=np.int64)
    diff_flag = np.zeros(total_atoms, dtype=np.int64)

    # Iterate through all neighbor pairs, closest first
    for source_atom, neigh_atom in neighbour_list:
        source_index = atom_index[id(source_atom)]
        neigh_index = atom_index[id(neigh_atom)]

        # Same residue = same Residue object (chain-aware), not merely the
        # same residue sequence number
        if source_atom.get_parent() is neigh_atom.get_parent():
            table, counts = neigh_same_res, same_flag
        else:
            table, counts = neigh_diff_res, diff_flag

        # Add the edge in both directions (limit to 10 per atom)
        for src, dst in ((source_index, neigh_index), (neigh_index, source_index)):
            if counts[src] < 10:
                table[src, counts[src]] = dst
                counts[src] += 1

    return neigh_same_res, neigh_diff_res

## 5. GNN Model Architecture {#gnn-architecture}

### Overview

The Graph Neural Network consists of:
1. **GNN_First_Layer**: Initial layer that processes raw atom and residue features
2. **GNN_Layer**: Subsequent layers that process and aggregate information
3. **Dense**: A 512 → 1 sigmoid layer; it is instantiated in the model but not called in `GNN.forward`
4. **GNN**: Complete model that chains these components

### Key Concept: Message Passing

Each GNN layer performs **message passing**:
1. Each node (atom) creates a "message" for its neighbors
2. Each node receives messages from its neighbors
3. Node updates its representation by combining its own features with neighbor messages
4. This allows information to flow through the graph

**Why 3 Layers?**
- Layer 1: Sees immediate neighbors (1-hop)
- Layer 2: Sees neighbors of neighbors (2-hop)
- Layer 3: Sees even further (3-hop)
- Each layer increases the "receptive field"
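The hop-by-hop growth of the receptive field can be made concrete with a reachability matrix on a toy path graph. Each round keeps a node's own signal (the self term, like `Wv` / `Wsv`) and adds information from its direct neighbors, so after L rounds a node has been influenced by every node within L hops. The graph here is a hypothetical stand-in for an atom graph.

```python
import numpy as np

# Toy path graph 0-1-2-3-4-5-6 (stand-in for a chain of atoms)
n = 7
adj = np.zeros((n, n), dtype=int)
for i in range(n - 1):
    adj[i, i + 1] = adj[i + 1, i] = 1

# reach[v, u] == 1 if information from node u can have reached node v
reach = np.eye(n, dtype=int)          # before any layer: only the node itself
receptive_field = {}
for layer in range(1, 4):
    # One round of message passing: keep own signal (self term) + add neighbors
    reach = ((reach + adj @ reach) > 0).astype(int)
    receptive_field[layer] = np.nonzero(reach[0])[0].tolist()

print(receptive_field)  # {1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3]}
```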

### GNN_First_Layer

**Purpose**: Initial layer that processes raw features and creates first hidden representation.

**Inputs**:
- `atoms`: One-hot encoded atom types (N × 13)
- `residues`: One-hot encoded residue types (N × 21)
- `same_neigh`: Same-residue neighbor indices (N × 10)
- `diff_neigh`: Different-residue neighbor indices (N × 10)

**Learnable Parameters**:
- `Wv`: Weight matrix for atom features (13 → 512)
- `Wr`: Weight matrix for residue features (21 → 512)
- `Wsr`: Weight matrix for same-residue neighbor messages (13 → 512)
- `Wdr`: Weight matrix for different-residue neighbor messages (13 → 512)

**Message Passing Formula**:
```
h_i = ReLU(atoms_i @ Wv + residues_i @ Wr + 
           mean(atoms_j @ Wsr for j in same_neighbors) +
           mean(atoms_k @ Wdr for k in diff_neighbors))
```

**Output**: Hidden representation (N × 512)
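The `mean(...)` terms rely on a gather, mask, and normalize pattern that handles the `-1` padding. The NumPy sketch below mirrors those steps from `GNN_First_Layer.forward` (and `GNN_Layer.forward`). An index of -1 actually fetches the last row, in NumPy and in PyTorch alike, which is why the mask is applied before summing. The toy signals stand in for rows of `atoms @ Wsr`, and the function name is hypothetical.

```python
import numpy as np

def masked_mean_neighbors(signals, neigh):
    """Hypothetical helper: mean of neighbor rows of `signals`, ignoring -1 padding."""
    mask = (neigh > -1)[:, :, None]            # (N, K, 1): True for real neighbors
    gathered = signals[neigh] * mask           # index -1 fetches the LAST row; mask zeroes it
    norm = np.maximum(mask.sum(axis=1), 1)     # (N, 1) neighbor counts, 0 replaced by 1
    return gathered.sum(axis=1) / norm         # (N, F)

# Three atoms with 2-dim transformed signals (e.g. rows of atoms @ Wsr)
signals = np.array([[1.0, 0.0],
                    [0.0, 2.0],
                    [4.0, 4.0]])
neigh = np.array([[1, 2, -1],     # atom 0 has neighbors 1 and 2
                  [0, -1, -1],    # atom 1 has neighbor 0
                  [-1, -1, -1]])  # atom 2 has no neighbors
print(masked_mean_neighbors(signals, neigh).tolist())
# [[2.0, 3.0], [1.0, 0.0], [0.0, 0.0]]
```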

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class GNN_First_Layer(nn.Module):
    """
    First layer of the GNN that processes raw atom and residue features.
    
    This layer:
    1. Transforms atom features through Wv
    2. Transforms residue features through Wr
    3. Aggregates same-residue neighbor information through Wsr
    4. Aggregates different-residue neighbor information through Wdr
    5. Combines all signals with ReLU activation
    """

    def __init__(self, filters, trainable=True, **kwargs):
        super(GNN_First_Layer, self).__init__()
        self.filters = filters
        self.trainable = trainable
        
        # Set device (GPU if available)
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        self.cuda_device = device
        
        # Initialize learnable weight matrices
        self.Wv = nn.Parameter(torch.randn(13, self.filters, device=self.cuda_device, requires_grad=True))
        self.Wr = nn.Parameter(torch.randn(21, self.filters, device=self.cuda_device, requires_grad=True))
        self.Wsr = nn.Parameter(torch.randn(13, self.filters, device=self.cuda_device, requires_grad=True))
        self.Wdr = nn.Parameter(torch.randn(13, self.filters, device=self.cuda_device, requires_grad=True))
        
        self.neighbours = 10

    def forward(self, x):
        atoms, residues, same_neigh, diff_neigh = x
        
        # Transform node's own features
        node_signals = atoms @ self.Wv
        residue_signals = residues @ self.Wr
        
        # Transform atom features for neighbor aggregation
        neigh_signals_same = atoms @ self.Wsr
        neigh_signals_diff = atoms @ self.Wdr
        
        # Create masks for valid neighbors (indices > -1)
        unsqueezed_same_neigh_indicator = (same_neigh > -1).unsqueeze(2)
        unsqueezed_diff_neigh_indicator = (diff_neigh > -1).unsqueeze(2)
        
        # Gather neighbor features and mask out invalid neighbors
        same_neigh_features = neigh_signals_same[same_neigh] * unsqueezed_same_neigh_indicator
        diff_neigh_features = neigh_signals_diff[diff_neigh] * unsqueezed_diff_neigh_indicator
        
        # Calculate normalization factors (number of valid neighbors)
        same_norm = torch.sum(same_neigh > -1, 1).unsqueeze(1).type(torch.float)
        diff_norm = torch.sum(diff_neigh > -1, 1).unsqueeze(1).type(torch.float)
        
        # Prevent division by zero
        same_norm[same_norm == 0] = 1
        diff_norm[diff_norm == 0] = 1
        
        # Aggregate neighbor signals (mean pooling)
        neigh_same_atoms_signal = (torch.sum(same_neigh_features, axis=1)) / same_norm
        neigh_diff_atoms_signal = (torch.sum(diff_neigh_features, axis=1)) / diff_norm

        # Combine all signals with ReLU activation
        final_res = torch.relu(
            node_signals + residue_signals + 
            neigh_same_atoms_signal + neigh_diff_atoms_signal
        )

        return final_res, same_neigh, diff_neigh

### GNN_Layer

**Purpose**: Subsequent GNN layers that continue message passing on hidden representations.

**Inputs**:
- `Z`: Hidden representation from previous layer (N × v_feats)
- `same_neigh`: Same-residue neighbor indices (N × 10)
- `diff_neigh`: Different-residue neighbor indices (N × 10)

**Learnable Parameters**:
- `Wsv`: Weight matrix for self features (v_feats → filters)
- `Wsr`: Weight matrix for same-residue neighbors (v_feats → filters)
- `Wdr`: Weight matrix for different-residue neighbors (v_feats → filters)

**Difference from First Layer**:
- No separate residue features (already incorporated)
- Operates on learned representations instead of raw features

**Message Passing Formula**:
```
h_i^(l+1) = ReLU(h_i^l @ Wsv + 
                 mean(h_j^l @ Wsr for j in same_neighbors) +
                 mean(h_k^l @ Wdr for k in diff_neighbors))
```

**Output**: Updated hidden representation (N × filters)
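Because each layer consists only of bias-free weight matrices, the model size follows directly from the weight shapes, given the 512 → 1024 → 512 configuration used in the `GNN` class. Assuming the classes as written, this total should match `sum(p.numel() for p in GNN().parameters())`.

```python
# Weight shapes taken from the layer definitions (no bias terms are used)
first_layer = (13 + 21 + 13 + 13) * 512   # Wv, Wr, Wsr, Wdr
layer2 = 3 * 512 * 1024                   # Wsv, Wsr, Wdr: 512 -> 1024
layer3 = 3 * 1024 * 512                   # Wsv, Wsr, Wdr: 1024 -> 512
dense = 512 * 1                           # registered even though forward() skips it

total = first_layer + layer2 + layer3 + dense
print(first_layer, layer2, layer3, dense, total)
# 30720 1572864 1572864 512 3176960
```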

In [None]:
class GNN_Layer(nn.Module):
    """
    Subsequent GNN layer that processes hidden representations.
    
    This layer:
    1. Transforms node's hidden features through Wsv
    2. Aggregates same-residue neighbor information through Wsr
    3. Aggregates different-residue neighbor information through Wdr
    4. Combines signals with ReLU activation
    """

    def __init__(self, filters, v_feats, trainable=True, **kwargs):
        super(GNN_Layer, self).__init__()
        self.v_feats = v_feats  # Input feature dimension
        self.filters = filters  # Output feature dimension
        self.trainable = trainable
        
        # Set device
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        self.cuda_device = device
        
        # Initialize learnable weight matrices
        self.Wsv = nn.Parameter(torch.randn(self.v_feats, self.filters, device=self.cuda_device, requires_grad=True))
        self.Wdr = nn.Parameter(torch.randn(self.v_feats, self.filters, device=self.cuda_device, requires_grad=True))
        self.Wsr = nn.Parameter(torch.randn(self.v_feats, self.filters, device=self.cuda_device, requires_grad=True))
        
        self.neighbours = 10

    def forward(self, x):
        Z, same_neigh, diff_neigh = x
        
        # Transform node's own features
        node_signals = Z @ self.Wsv
        
        # Transform features for neighbor aggregation
        neigh_signals_same = Z @ self.Wsr
        neigh_signals_diff = Z @ self.Wdr
        
        # Create masks for valid neighbors
        unsqueezed_same_neigh_indicator = (same_neigh > -1).unsqueeze(2)
        unsqueezed_diff_neigh_indicator = (diff_neigh > -1).unsqueeze(2)
        
        # Gather and mask neighbor features
        same_neigh_features = neigh_signals_same[same_neigh] * unsqueezed_same_neigh_indicator
        diff_neigh_features = neigh_signals_diff[diff_neigh] * unsqueezed_diff_neigh_indicator
        
        # Calculate normalization factors
        same_norm = torch.sum(same_neigh > -1, 1).unsqueeze(1).type(torch.float)
        diff_norm = torch.sum(diff_neigh > -1, 1).unsqueeze(1).type(torch.float)

        # Prevent division by zero
        same_norm[same_norm == 0] = 1
        diff_norm[diff_norm == 0] = 1
        
        # Aggregate neighbor signals (mean pooling)
        neigh_same_atoms_signal = (torch.sum(same_neigh_features, axis=1)) / same_norm
        neigh_diff_atoms_signal = (torch.sum(diff_neigh_features, axis=1)) / diff_norm
        
        # Combine all signals with ReLU activation
        final_res = torch.relu(
            node_signals + neigh_same_atoms_signal + neigh_diff_atoms_signal
        )

        return final_res, same_neigh, diff_neigh

### Dense Layer

**Purpose**: Simple fully-connected layer with sigmoid activation.

**Usage**: Maps 512 → 1. The `GNN` model instantiates it as `self.dense` but never calls it in `forward`, so it does not contribute to the protein embedding.

**Formula**: `y = sigmoid(x @ W)`

In [None]:
class Dense(nn.Module):
    """
    Simple dense (fully-connected) layer with sigmoid activation.
    """
    def __init__(self, in_dims, out_dims, trainable=True, **kwargs):
        super(Dense, self).__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        self.cuda_device = device

        self.W = nn.Parameter(torch.randn(self.in_dims, self.out_dims, device=self.cuda_device, requires_grad=True))

    def forward(self, x):
        Z = torch.sigmoid(torch.matmul(x, self.W))
        return Z

### Complete GNN Model

**Purpose**: Chains all GNN components into a complete protein encoder.

**Architecture**:
```
Input: (atoms, residues, same_neigh, diff_neigh)
  ↓
GNN_First_Layer: (13+21) → 512
  ↓
GNN_Layer: 512 → 1024
  ↓
GNN_Layer: 1024 → 512
  ↓
Global Sum Pooling: Aggregate all atom representations
  ↓
L2 Normalization: Normalize to unit vector
  ↓
Output: Protein embedding (1 × 512)
```

**Key Operations**:
1. **Global Sum Pooling**: `sum(all atom representations)` - creates protein-level representation
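2. **L2 Normalization**: Rescales the pooled vector to unit length. Because sum pooling grows with the number of atoms, this removes the dependence of the embedding's magnitude on protein size and keeps only its direction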

**Output**: Fixed-size vector (512-dim) representing the entire protein structure
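Both properties of the readout can be checked with a NumPy sketch of the pooling step: sum pooling ignores atom order, and the L2 step removes overall scale, so a feature set in which every atom appears twice maps to the same unit vector. `F.normalize` additionally clamps the norm at a small epsilon, which this sketch omits. The random features are a toy stand-in for the 512-dim ReLU outputs.

```python
import numpy as np

def pool_and_normalize(h):
    """Global sum pooling over atoms followed by L2 normalization (NumPy sketch)."""
    pooled = h.sum(axis=0, keepdims=True)                          # (1, d)
    return pooled / np.linalg.norm(pooled, axis=1, keepdims=True)  # unit length

rng = np.random.default_rng(0)
h = rng.random((50, 8))   # 50 atoms, 8-dim non-negative features (like ReLU outputs)

e = pool_and_normalize(h)
e_shuffled = pool_and_normalize(h[rng.permutation(50)])   # atoms in another order
e_doubled = pool_and_normalize(np.vstack([h, h]))         # every atom counted twice

print(np.allclose(e, e_shuffled), np.allclose(e, e_doubled))  # True True
```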

In [None]:
class GNN(torch.nn.Module):
    """
    Complete GNN model for encoding protein structures.
    
    Architecture:
    - Layer 1: Processes raw features (13+21) → 512
    - Layer 2: 512 → 1024 (expansion)
    - Layer 3: 1024 → 512 (compression)
    - Global sum pooling + L2 normalization
    
    Output: 512-dimensional protein embedding
    """
    def __init__(self):
        super(GNN, self).__init__()
        self.conv1 = GNN_First_Layer(filters=512)
        self.conv2 = GNN_Layer(v_feats=512, filters=1024)
        self.conv3 = GNN_Layer(v_feats=1024, filters=512)
        self.dense = Dense(in_dims=512, out_dims=1)
        
    def forward(self, x):
        # Pass through three GNN layers
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        
        # Extract features (ignore neighbor indices)
        x = x3[0]
        
        # Global sum pooling: aggregate all atom representations
        x = torch.sum(x, axis=0).view(1, -1)
        
        # L2 normalization: normalize to unit vector
        x = F.normalize(x)
        
        return x

    @staticmethod
    def processProtein(UniqueProtein, PdBloc):
        """
        Processes PDB files into graph representations.
        
        Args:
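            UniqueProtein: List of protein names; the '.pdb' extension is stripped
                           in place, so the caller's list is modified
            PdBloc: Path to directory containing PDB files; it is concatenated
                    directly with the file name, so it must end with a path
                    separator (e.g. '/')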
        
        Returns:
            PData_dict: Dictionary mapping protein names to graph data
                       [one_hot_atom, one_hot_res, neigh_same_res, neigh_diff_res]
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        PData_dict = {}
        
        for i in range(len(UniqueProtein)):
            UniqueProtein[i] = UniqueProtein[i].split('.pdb')[0]
            P1 = PdBloc + UniqueProtein[i] + '.pdb'
            
            # Parse PDB file
            parser = PDBParser()
            with warnings.catch_warnings(record=True) as w:
                structure = parser.get_structure("", P1)
            
            # Extract features
            one_hot_atom = atom1(structure)
            one_hot_res = res1(structure)
            neigh_same_res, neigh_diff_res = neigh1(structure)
            
            # Convert to tensors and move to device
            one_hot_atom = torch.tensor(one_hot_atom, dtype=torch.float32).to(device)
            one_hot_res = torch.tensor(one_hot_res, dtype=torch.float32).to(device)
            neigh_same_res = torch.tensor(neigh_same_res).to(device).long()
            neigh_diff_res = torch.tensor(neigh_diff_res).to(device).long()
            
            # Store as list
            GNNData = [one_hot_atom, one_hot_res, neigh_same_res, neigh_diff_res]
            PData_dict[UniqueProtein[i]] = GNNData
            
        return PData_dict

### IPPI_MLP_Net

**Purpose**: Multi-layer perceptron that combines all features for final prediction.

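**Inputs** (`forward` takes them in the order `PFeatures, LigandFeatures, ProteinInterfaceF`):
- `PFeatures`: GNN embedding of the protein (512-dim)
- `ProteinInterfaceF`: Hand-crafted interface features; their width must equal 2840 - 512 - the fingerprint length (280 with 2048-bit fingerprints)
- `LigandFeatures`: Compound fingerprints (typically 2048-dim)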

**Architecture**:
```
Concatenate: [GNN_features | Interface_features | Compound_features] → 2840
  ↓
FC1: 2840 → 1024 (tanh)
  ↓
FC2: 1024 → 512 (tanh)
  ↓
FC3: 512 → 100 (relu)
  ↓
FC4 (`fc6` in code): 100 → 1 (linear; raw logit for `BCEWithLogitsLoss`)
  ↓
Output: Logit score (apply a sigmoid to get the inhibitor probability)
```

**Why Concatenate?**
- GNN features: Learned structural patterns
- Interface features: Domain knowledge about binding sites
- Compound features: Chemical properties
- The MLP learns how to weight and combine these complementary sources of information
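The concatenated width has to match the 2840 inputs of `fc1`. A quick NumPy shape check, assuming 2048-bit compound fingerprints (the typical size given above), which leaves 280 columns for the interface features; the zero arrays are only stand-ins for real features.

```python
import numpy as np

GNN_DIM = 512                             # protein embedding from the GNN
FP_DIM = 2048                             # assumed fingerprint length
FC1_IN = 2840                             # input size of IPPI_MLP_Net.fc1
IFACE_DIM = FC1_IN - GNN_DIM - FP_DIM     # interface width implied by fc1

batch = 4
p_feats = np.zeros((batch, GNN_DIM))
iface = np.zeros((batch, IFACE_DIM))
lig = np.zeros((batch, FP_DIM))

# Same order as IPPI_MLP_Net.forward: [protein | interface | compound]
combined = np.hstack((np.hstack((p_feats, iface)), lig))
print(IFACE_DIM, combined.shape)          # 280 (4, 2840)
```

If a different fingerprint length is used, `fc1` has to be resized to match, otherwise the first linear layer raises a shape error.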

```python
import torch
import torch.nn as nn


class IPPI_MLP_Net(nn.Module):
    """
    Multi-layer perceptron for final inhibition prediction.
    Combines GNN protein features, interface features, and compound features.
    
    Input: Concatenated features (total 2840-dim)
    Output: Logit score for binary classification
    """
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2840, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 100)
        # "FC4" in the diagram; renaming it would change saved state_dict keys
        self.fc6 = nn.Linear(100, 1)

    def forward(self, PFeatures, LigandFeatures, ProteinInterfaceF):
        # Concatenate along the feature axis: [protein | interface | compound]
        P_all_Features = torch.hstack((PFeatures, ProteinInterfaceF))
        PC_Features = torch.hstack((P_all_Features, LigandFeatures))

        # Pass through MLP
        x = torch.tanh(self.fc1(PC_Features))
        x = torch.tanh(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc6(x)  # raw logits; torch.sigmoid(x) gives probabilities

        return x
```
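`fc6` returns a raw logit, so a sigmoid turns scores into probabilities and the loss is computed on the logits directly. The NumPy sketch below mirrors the formula documented for PyTorch's `BCEWithLogitsLoss`, including its `pos_weight` argument. Setting `pos_weight` to the negative-to-positive ratio is one common way to implement the class weighting described in the training section; it is an assumption here, not the notebook's confirmed setting.

```python
import numpy as np

def softplus(z):
    # log(1 + exp(z)) computed without overflow for large |z|
    return np.log1p(np.exp(-np.abs(z))) + np.maximum(z, 0.0)

def sigmoid(z):
    # sigmoid(z) = exp(log sigmoid(z)) = exp(-softplus(-z)), stable for any z
    return np.exp(-softplus(-z))

def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean binary cross-entropy on raw logits.

    pos_weight scales the loss of positive examples, as in
    torch.nn.BCEWithLogitsLoss(pos_weight=...).
    """
    log_p = -softplus(-logits)       # log sigmoid(x)
    log_not_p = -softplus(logits)    # log(1 - sigmoid(x))
    loss = -(pos_weight * targets * log_p + (1.0 - targets) * log_not_p)
    return float(loss.mean())

logits = np.array([2.0, -1.0, 0.0, -3.0])
labels = np.array([1.0, 0.0, 0.0, 0.0])        # 1 positive, 3 negatives

pos_weight = (1 - labels).sum() / labels.sum()  # negatives / positives = 3.0
print(sigmoid(logits).round(3))                 # inhibitor probabilities
print(bce_with_logits(logits, labels))
print(bce_with_logits(logits, labels, pos_weight=pos_weight))
```

Working on logits rather than on sigmoid outputs avoids `log(0)` when the network becomes very confident, which is why the model ends in a linear layer.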

## 6. Training Pipeline {#training}

### Training Strategy: Leave-One-Complex-Out Cross-Validation (LOCO-CV)

**Why LOCO-CV?**
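- With random splits, compounds measured against the same complex end up in both training and test folds, so the model can score well by memorizing complex-specific patterns (information leakage)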
- Real use case: predict for new protein complexes
- LOCO-CV tests generalization to unseen protein structures

**Process**:
1. Group all examples by protein complex
2. For each complex:
   - Train on all other complexes
   - Test on this complex
3. Aggregate results across all folds
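A minimal NumPy sketch of this fold structure, assuming each example carries a complex ID; `loco_splits` and `standardize` are hypothetical helpers, not functions from the original repository. The key detail is that the scaling statistics come from the training fold only.

```python
import numpy as np

def loco_splits(groups):
    """Yield (held_out, train_idx, test_idx) for leave-one-complex-out CV.

    groups: 1-D array with the complex ID of every example.
    """
    groups = np.asarray(groups)
    for held_out in np.unique(groups):
        test_mask = groups == held_out
        yield held_out, np.flatnonzero(~test_mask), np.flatnonzero(test_mask)

def standardize(train_x, test_x, eps=1e-8):
    # Fit mean/std on the training fold only, then apply them to both splits,
    # so no statistics from the held-out complex leak into training
    mu = train_x.mean(axis=0)
    sd = train_x.std(axis=0)
    return (train_x - mu) / (sd + eps), (test_x - mu) / (sd + eps)

# Toy data: illustrative complex IDs and random features
complex_ids = np.array(["2XA0", "2XA0", "3WN7", "3WN7", "3WN7", "1YCR"])
X = np.random.default_rng(0).normal(size=(6, 3))

for held_out, tr, te in loco_splits(complex_ids):
    x_tr, x_te = standardize(X[tr], X[te])
    print(held_out, len(tr), len(te))
```

scikit-learn's `LeaveOneGroupOut` produces the same splits when given the complex IDs as `groups`.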

### Key Components

1. **Data Loading**:
   - Protein structures (pre-processed as graphs)
   - Interface features (pre-computed)
   - Compound fingerprints (pre-computed)
   - Labels (inhibitor: 1, non-inhibitor: 0)

2. **Balanced Sampling**:
   - Uses BinaryBalancedSampler
   - Each batch has 50% positives, 50% negatives

3. **Class Weighting**:
   - Additional weighting in loss function
   - Accounts for different imbalance ratios across complexes

4. **Optimization**:
   - Adam optimizer (learning rate: 0.001)
   - BCE with Logits Loss
   - 5 epochs per fold

5. **Model Selection**:
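   - Tracks the best AUCROC on the test complex; because that same complex also selects the checkpoint, the reported scores are likely optimistic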
   - Saves best model per fold
   - Early stopping if no improvement
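`BinaryBalancedSampler` is defined earlier in the notebook; the generator below is an independent, hypothetical sketch of the same idea (`balanced_batches` is not the notebook's API). Each batch draws half its indices from the positives and half from the negatives, sampling with replacement, so the rare class is revisited often while the common class is subsampled.

```python
import numpy as np

def balanced_batches(labels, batch_size=32, n_batches=10, seed=0):
    """Yield index arrays with batch_size // 2 positives and batch_size // 2 negatives.

    Hypothetical helper: sampling is with replacement from each class.
    """
    labels = np.asarray(labels)
    pos = np.flatnonzero(labels == 1)
    neg = np.flatnonzero(labels == 0)
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one example of each class")
    rng = np.random.default_rng(seed)
    half = batch_size // 2
    for _ in range(n_batches):
        batch = np.concatenate([rng.choice(pos, half), rng.choice(neg, half)])
        rng.shuffle(batch)  # mix positives and negatives within the batch
        yield batch

labels = np.array([1] * 5 + [0] * 95)   # 5% positives
for batch in balanced_batches(labels, batch_size=8, n_batches=3):
    print(labels[batch].sum(), len(batch))   # prints "4 8" for every batch
```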

### Training Loop Structure

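```python
# Skeleton: complexes, train_labels and the data helpers come from earlier cells
for test_complex in complexes:
    # Split data (leave one complex out)
    # train_data = examples from every other complex
    # test_data  = examples from test_complex only

    # Standardize features: fit scalers on train_data only,
    # then transform both train_data and test_data

    # Initialize fresh models and optimizer for each fold
    GNN_model = GNN()
    IPPI_Net = IPPI_MLP_Net()
    optimizer = torch.optim.Adam(
        list(GNN_model.parameters()) + list(IPPI_Net.parameters()), lr=0.001
    )
    criterion = torch.nn.BCEWithLogitsLoss()  # plus the class weighting described above

    # Create balanced data loader
    loader = BinaryBalancedSampler(train_labels)

    best_aucroc = 0.0
    for epoch in range(5):
        for batch in loader:
            # Training step
            # 1. Pass each unique protein in the batch through GNN_model once
            #    (avoids redundant computation for repeated proteins)
            # 2. Gather protein, interface, and compound features for the batch
            # 3. logits = IPPI_Net(PFeatures, LigandFeatures, ProteinInterfaceF)
            # 4. loss = criterion(logits, batch_labels)
            # 5. optimizer.zero_grad(); loss.backward(); optimizer.step()

            # Validation step (every iteration)
            # 1. Score every example of test_complex
            # 2. Compute AUCROC
            # 3. If it beats best_aucroc, update best_aucroc and save both models
            ...

    # Load the best checkpoint and save its final scores and labels for this fold
```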

## 7. Evaluation and Results {#evaluation}

### Evaluation Metrics

Two primary metrics are used:

1. **AUCROC (Area Under Receiver Operating Characteristic)**
   - Measures ability to rank inhibitors higher than non-inhibitors
   - 0.5 corresponds to random ranking and 1.0 to perfect ranking; values below 0.5 are worse than random
   - Summarizes ranking quality across all decision thresholds

2. **AUCPR (Area Under Precision-Recall Curve)**
   - More informative for imbalanced datasets
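   - Focuses on performance on the positive class (inhibitors)
   - A random classifier scores roughly the fraction of positives, so the baseline differs from complex to complex
   - Closer to the screening use case, where the goal is to find the few true inhibitors among many candidates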
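A minimal NumPy sketch of both metrics, written out to make their definitions concrete. AUCROC is computed as the fraction of positive/negative pairs ranked correctly (ties count one half); AUCPR is approximated by average precision, and this simple version assumes no tied scores. In practice `sklearn.metrics.roc_auc_score` and `average_precision_score` are the usual choices; which AUCPR estimator the notebook uses is not stated here.

```python
import numpy as np

def auc_roc(scores, labels):
    """Probability that a random positive outranks a random negative (ties count 1/2)."""
    scores, labels = np.asarray(scores, float), np.asarray(labels)
    pos, neg = scores[labels == 1], scores[labels == 0]
    diff = pos[:, None] - neg[None, :]            # every positive-negative pair
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())

def average_precision(scores, labels):
    """AUCPR as average precision; assumes no tied scores for simplicity."""
    order = np.argsort(-np.asarray(scores, float))   # highest score first
    hits = np.asarray(labels)[order] == 1
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(precision_at_k[hits].mean())        # precision at each positive

scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
labels = [1,   0,   1,   0,   0,   0]
print(auc_roc(scores, labels))                      # 0.875
print(round(average_precision(scores, labels), 3))  # 0.833
```

Here 2 of 6 examples are positive, so a random ranking would give an average precision near 0.33, while its AUCROC would still be near 0.5.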

### Reported Performance

Per-complex AUCROC under LOCO-CV:

| Complex | AUCROC | Notes |
|---------|--------|-------|
| 2XA0 | 0.59 | Challenging case |
| 3WN7 | 0.92 | Strong performance |
| 3UVW | 0.85 | Good performance |
| 1YCR | 0.81 | Good performance |
| 4ESG | 0.93 | Strong performance |
| 3D9T | 0.88 | Strong performance |
| 2FLU | 0.95 | Excellent performance |
| 4QC3 | 0.93 | Strong performance |

- **Average AUCROC**: 0.86 ± 0.09
- **Average AUCPR**: 0.44 ± 0.20
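A quick arithmetic check of the eight per-complex values in the table, using only the standard library. Their mean reproduces the reported 0.86; their spread is about 0.11 (population) or 0.12 (sample standard deviation), so the reported ± 0.09 cannot be recovered from these eight rows alone and may have been computed over a different set of folds or runs.

```python
from statistics import mean, pstdev, stdev

# AUCROC values copied from the table above
aucroc = {
    "2XA0": 0.59, "3WN7": 0.92, "3UVW": 0.85, "1YCR": 0.81,
    "4ESG": 0.93, "3D9T": 0.88, "2FLU": 0.95, "4QC3": 0.93,
}
vals = list(aucroc.values())
print(f"mean = {mean(vals):.2f}")             # mean = 0.86
print(f"population sd = {pstdev(vals):.2f}")  # population sd = 0.11
print(f"sample sd = {stdev(vals):.2f}")       # sample sd = 0.12
```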

### Comparison with Baselines

The notebook includes comparison code for:
- **SVM baseline**: Uses hand-crafted features only
- **GearNet**: Alternative protein structure encoder
- **GNN pipeline** (this model): Combines GNN with interface features

In the notebook's comparison, the GNN-based pipeline outperforms the SVM baseline, which suggests that learned structural representations add information beyond the hand-crafted features.

## Summary

### Key Takeaways

1. **Problem**: Predict if small molecules can inhibit protein-protein interactions

2. **Solution**: Graph Neural Network on 3D protein structures
   - Proteins → graphs (atoms as nodes, spatial proximity as edges)
   - GNN learns structural patterns predictive of inhibition
   - Combines with compound fingerprints via MLP

3. **Technical Innovations**:
   - Two edge types (same/different residue) with separate weights
   - Balanced batch sampling for imbalanced data
   - LOCO-CV for realistic evaluation

4. **Performance**: 
   - AUCROC ~0.86 across diverse protein complexes
   - Outperforms the SVM baseline built on hand-crafted features

### Potential Improvements

1. **Architecture**:
   - Attention mechanisms for neighbor aggregation
   - Residual connections for deeper networks
   - Edge features (distance, bond types)

2. **Features**:
   - Physicochemical properties (charge, hydrophobicity)
   - Secondary structure information
   - Evolutionary conservation scores

3. **Training**:
   - Data augmentation (e.g., small coordinate perturbations; rigid rotations alone would not change the distance-based graph)
   - Transfer learning from large protein datasets
   - Ensemble methods

4. **Evaluation**:
   - External test set (temporally split)
   - Analysis of failure cases
   - Interpretability studies (which atoms matter?)

## References

**Original Repository**: https://github.com/adibayaseen/PPI-Inhibitors

**Key Libraries**:
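- Biopython: Cock et al., "Biopython: freely available Python tools for computational molecular biology and bioinformatics", Bioinformatics 2009
- RDKit: https://www.rdkit.org/
- PyTorch: Paszke et al., "PyTorch: An Imperative Style, High-Performance Deep Learning Library", NeurIPS 2019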

**Related Work on GNNs for Proteins**:
- Graph Neural Networks for protein structure (various papers)
- Message passing neural networks
- Geometric deep learning on molecular graphs