<a href="https://colab.research.google.com/github/adibayaseen/PPI-Inhibitors/blob/main/Comprehensive_GNN_PPI_Inhibitor_Prediction_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Comprehensive GNN-Based Pipeline for Predicting Small-Molecule Inhibition of Protein Complexes

## Overview

This notebook provides a **complete, step-by-step implementation** of the Graph Neural Network (GNN) based model for predicting whether a small molecule can act as an inhibitor of a specific protein complex.

### Research Paper Summary

**Title:** Predicting small-molecule inhibition of protein complexes

**Problem Statement:**
- Protein-Protein Interactions (PPIs) are crucial in biological processes and disease mechanisms
- Traditional methods for discovering PPI inhibitors are expensive and time-consuming
- **No existing method** takes both a protein complex AND a compound as inputs to predict targeted inhibition

**Solution:**
- First **targeted** machine learning predictor of small molecule inhibition of protein complexes
- Integrates:
  - 3D structure of protein complex (via Graph Neural Network)
  - Protein-protein binding interface features
  - Compound SMILES representation (via molecular fingerprints)

**Results:**
- Cross-validation AUC-ROC: **0.86**
- External test set 1 (recent publications): AUC-ROC **0.82**
- External test set 2 (SARS-CoV-2): AUC-ROC **0.78**
- Outperforms baseline SVM and GearNet embeddings
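
AUC-ROC is a ranking metric, not an accuracy: it is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one. The sketch below computes it by brute-force pair counting. The labels and scores are made-up toy values, not results from the paper, and the helper name `auc_roc` is hypothetical.

```python
def auc_roc(labels, scores):
    """AUC-ROC by pair counting: P(score_pos > score_neg), ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Toy data (hypothetical): 2 inhibitors, 4 non-inhibitors
toy_labels = [1, 1, 0, 0, 0, 0]
toy_scores = [0.9, 0.4, 0.5, 0.3, 0.2, 0.1]
toy_auc = auc_roc(toy_labels, toy_scores)
print(f"AUC-ROC = {toy_auc:.3f}")  # 7 of 8 positive/negative pairs are ordered correctly
```

In the notebook itself the same quantity comes from `sklearn.metrics.roc_auc_score`, which is imported in Section 2.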

### Notebook Structure

This notebook is organized into the following sections:

1. **Environment Setup**: Install required packages and dependencies
2. **Data Loading**: Load datasets (2P2I database, negative examples)
3. **Feature Extraction**: Extract features from protein complexes and compounds
4. **Data Loaders**: Create balanced data loaders for training
5. **Model Architecture**: Implement GNN and MLP models
6. **Training**: Leave-one-complex-out cross-validation
7. **Validation**: Evaluate model performance
8. **External Testing**: Test on independent datasets
9. **Baseline Comparison**: Compare with SVM and GearNet
10. **Results Visualization**: Generate ROC and PR curves
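
The leave-one-complex-out scheme in step 6 holds out every example of one complex per fold, so the model is always evaluated on a complex it has not seen during training. A minimal sketch of that split, assuming each example is a `(complex_id, compound_id)` pair; the IDs are made up and the function name `leave_one_complex_out` is hypothetical, not taken from the repository:

```python
def leave_one_complex_out(examples):
    """Yield (held_out_complex, train_idx, test_idx), one fold per complex."""
    complexes = sorted({cid for cid, _ in examples})
    for held_out in complexes:
        test_idx = [i for i, (cid, _) in enumerate(examples) if cid == held_out]
        train_idx = [i for i, (cid, _) in enumerate(examples) if cid != held_out]
        yield held_out, train_idx, test_idx

# Hypothetical (complex, compound) pairs
toy_examples = [("cplxA", "c1"), ("cplxA", "c2"), ("cplxB", "c1"),
                ("cplxC", "c3"), ("cplxC", "c4")]
folds = list(leave_one_complex_out(toy_examples))
for held_out, train_idx, test_idx in folds:
    print(held_out, "train:", train_idx, "test:", test_idx)
```

scikit-learn's `LeaveOneGroupOut` produces the same kind of split when given the complex IDs as the `groups` argument.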

---

**IMPORTANT NOTE:**
- Set Runtime → Change Runtime Type to **GPU** for optimal performance
- This notebook uses only the functions and approaches from the original research
- All code is heavily documented with explanations

---

## Section 1: Environment Setup and Package Installation

### Purpose
Install all required Python packages and dependencies for:
- **Structural biology**: BioPython (PDB file handling)
- **Cheminformatics**: RDKit (SMILES processing, molecular fingerprints)
- **Deep Learning**: PyTorch (neural network implementation)
- **Data Science**: NumPy, Pandas, Scikit-learn

### Key Packages
- `biopython`: Parse protein structures from PDB files
- `rdkit`: Generate molecular fingerprints from SMILES strings
- `torch`: Deep learning framework for GNN implementation
- `scikit-learn`: Data preprocessing and evaluation metrics

In [None]:
# Mount Google Drive to access pre-computed protein features
# The feature files are large and costly to recompute, so they are loaded from Drive
from google.colab import drive
drive.mount('/content/drive')

# Expected structure:
# /content/drive/MyDrive/GNN-PPI-Inhibitor/
#   ├── ProteinData_dict.pickle          # 2P2I protein GNN features
#   └── DBD5_ProteinData_dict.pickle    # DBD5 protein GNN features

In [None]:
# Clone the GitHub repository containing data and utility functions
# Any existing local copy is deleted first so the clone always starts fresh
!rm -rf PPI-Inhibitors
!git clone https://github.com/adibayaseen/PPI-Inhibitors

# Repository contains:
# - Data/: Datasets (complexes, inhibitors, SMILES)
# - Features/: Pre-computed features (interface, sequence, compound fingerprints)
# - code/: Original notebook implementations

In [None]:
# Install BioPython for protein structure parsing
# BioPython provides:
# - PDBParser: Parse PDB files into structured objects
# - NeighborSearch: Find neighboring atoms in 3D space
# - Structure navigation: Access chains, residues, atoms
!pip install biopython

In [None]:
# Install RDKit for cheminformatics operations
# RDKit provides:
# - SMILES parsing: Convert string representation to molecular objects
# - Fingerprint generation: Extended-Connectivity Fingerprints (ECFP)
# - Molecular descriptors: Properties and features
!pip install rdkit

---

## Section 2: Import Required Libraries

### Purpose
Import all necessary Python libraries and modules for the complete pipeline.

### Library Categories

1. **Deep Learning (PyTorch)**
   - `torch`: Core PyTorch functionality
   - `torch.nn`: Neural network layers and loss functions
   - `torch.optim`: Optimizers (Adam)
   - `torch.utils.data`: Dataset and DataLoader classes

2. **Structural Biology (BioPython)**
   - `Bio.PDB.PDBParser`: Parse PDB structure files
   - `Bio.PDB.NeighborSearch`: Find neighboring atoms within distance threshold

3. **Cheminformatics (RDKit)**
   - `rdkit.Chem`: SMILES parsing
   - Morgan Fingerprints: Molecular structure encoding

4. **Machine Learning (Scikit-learn)**
   - Preprocessing: StandardScaler, OneHotEncoder
   - Cross-validation: GroupKFold
   - Metrics: ROC-AUC, PR-AUC, precision, recall

5. **Data Manipulation**
   - NumPy: Numerical computations
   - Pandas: Data structures and analysis
   - Pickle: Serialization of Python objects

6. **Visualization**
   - Matplotlib: Plotting ROC and PR curves

7. **Utilities**
   - tqdm: Progress bars
   - warnings: Suppress BioPython warnings

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Comprehensive GNN-Based PPI Inhibitor Prediction Pipeline

This module implements the complete pipeline for predicting small-molecule
inhibition of protein complexes using Graph Neural Networks.

Author: Based on research by Yaseen et al. (2024)
Paper: Predicting small-molecule inhibition of protein complexes
"""

# ============================================================================
# DEEP LEARNING LIBRARIES (PyTorch)
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data.sampler import WeightedRandomSampler
from torch.autograd import Variable

# ============================================================================
# STRUCTURAL BIOLOGY LIBRARIES (BioPython)
# ============================================================================
from Bio.PDB import *
from Bio.PDB import PDBParser, NeighborSearch
import warnings  # Suppress PDB parsing warnings

# ============================================================================
# CHEMINFORMATICS LIBRARIES (RDKit)
# ============================================================================
from rdkit import Chem
from rdkit.Chem import AllChem

# ============================================================================
# MACHINE LEARNING LIBRARIES (Scikit-learn)
# ============================================================================
from sklearn.preprocessing import OneHotEncoder, StandardScaler, normalize
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.metrics import (
    roc_auc_score, roc_curve,
    precision_recall_curve, average_precision_score,
    precision_score, recall_score, auc
)

# ============================================================================
# DATA MANIPULATION LIBRARIES
# ============================================================================
import numpy as np
import pandas as pd
import pickle
import random

# ============================================================================
# VISUALIZATION LIBRARIES
# ============================================================================
import matplotlib.pyplot as plt

# ============================================================================
# UTILITY LIBRARIES
# ============================================================================
from tqdm import tqdm  # Progress bars
import glob
import os

# ============================================================================
# CUDA CONFIGURATION
# ============================================================================
# Check if GPU is available and configure device
USE_CUDA = torch.cuda.is_available()

def cuda(v):
    """
    Move tensor to GPU if CUDA is available.
    
    Args:
        v: Tensor or variable to move to GPU
    
    Returns:
        Tensor on GPU if available, otherwise on CPU
    """
    if USE_CUDA:
        return v.cuda()
    return v

def toTensor(v, dtype=torch.float, requires_grad=False):
    """
    Convert numpy array or list to PyTorch tensor and move to GPU.
    
    Args:
        v: Input data (numpy array, list, etc.)
        dtype: PyTorch data type (default: torch.float)
        requires_grad: Whether to track gradients (default: False)
    
    Returns:
        PyTorch tensor on appropriate device
    """
    # torch.autograd.Variable is deprecated; plain tensors track gradients directly
    return cuda(torch.tensor(v, dtype=dtype).requires_grad_(requires_grad))

def toNumpy(v):
    """
    Convert PyTorch tensor to NumPy array.
    Handles both CPU and GPU tensors.
    
    Args:
        v: PyTorch tensor
    
    Returns:
        NumPy array
    """
    if USE_CUDA:
        return v.detach().cpu().numpy()
    return v.detach().numpy()

# Print GPU information
if torch.cuda.is_available():
    print(f"✓ CUDA is available. Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("✗ CUDA is not available. Using CPU.")
    print("  Warning: Training will be significantly slower without GPU.")

print("\n" + "="*70)
print("All libraries imported successfully!")
print("="*70)

---

## Section 3: Custom Data Loaders and Samplers

### Purpose
Implement custom Dataset and Sampler classes to handle **class imbalance**.

### The Class Imbalance Problem

In PPI inhibitor prediction:
- **Positive examples** (inhibitors): ~714 examples
- **Negative examples** (non-inhibitors): ~10,413 examples
- **Ratio**: ~1:15 (highly imbalanced)

**Why this is a problem:**
- A model can reach roughly 94% accuracy (10,413 / 11,127) by always predicting "negative"
- Gradient updates dominated by negative examples
- Poor recall for positive (inhibitor) class
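
To make the problem concrete, the short calculation below scores a trivial classifier that labels every example as a non-inhibitor, using the approximate class counts quoted above. It is plain arithmetic, not part of the original pipeline.

```python
n_pos, n_neg = 714, 10413  # approximate counts quoted above

# A classifier that always predicts "negative" (non-inhibitor)
true_neg, false_neg = n_neg, n_pos
true_pos, false_pos = 0, 0

accuracy = (true_pos + true_neg) / (n_pos + n_neg)
recall = true_pos / (true_pos + false_neg)

print(f"accuracy = {accuracy:.3f}")  # high, although nothing was learned
print(f"recall   = {recall:.3f}")    # no inhibitor is ever found
```

This is why the notebook reports AUC-ROC and PR-AUC rather than accuracy, and why batches are rebalanced during training.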

### Solutions Implemented

**1. BalancedDataset with WeightedRandomSampler**
- Assigns higher sampling probability to minority class
- Uses weighted random sampling with replacement
- Creates approximately balanced batches over time
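
The effect of inverse-frequency weights can be seen without PyTorch. The sketch below uses the standard library's `random.choices` as a stand-in for `WeightedRandomSampler` (an assumption made to keep the example dependency-free) on toy labels with a 1:10 ratio.

```python
import random
from collections import Counter

random.seed(0)
labels = [1] * 100 + [0] * 1000          # toy data: 100 positives, 1000 negatives
class_counts = Counter(labels)

# Per-sample weight = 1 / (size of that sample's class)
sample_weights = [1.0 / class_counts[y] for y in labels]

# Draw one "epoch" of indices with replacement, proportional to the weights
drawn = random.choices(range(len(labels)), weights=sample_weights, k=len(labels))
drawn_labels = [labels[i] for i in drawn]

pos_fraction = sum(drawn_labels) / len(drawn_labels)
n_unique = len(set(drawn))
print(f"positive fraction: {pos_fraction:.2f}")
print(f"distinct examples seen: {n_unique} of {len(labels)}")
```

The positive fraction lands near 0.5, but because sampling is with replacement, a sizeable share of the negatives is never drawn in a given epoch.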

**2. BinaryBalancedSampler (Stratified Sampling)**
- Oversamples the minority class (with replacement) to match the majority class size
- Uses StratifiedKFold to split the balanced pool into batches
- Each batch is 50% positive and 50% negative, up to rounding when the pool does not divide evenly
- Every pass over the loader therefore sees both classes equally often
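
The oversample-then-stratify idea can be sketched with the standard library alone. Here each class is shuffled and dealt round-robin into the batches, which is a simplified stand-in for `StratifiedKFold` and not the same algorithm; the toy labels and variable names are assumptions.

```python
import math
import random

random.seed(0)
labels = [1] * 30 + [0] * 70             # toy data: 30 positives, 70 negatives
batch_size = 10

pos_idx = [i for i, y in enumerate(labels) if y == 1]
neg_idx = [i for i, y in enumerate(labels) if y == 0]
if len(neg_idx) >= len(pos_idx):
    majority, minority = neg_idx, pos_idx
else:
    majority, minority = pos_idx, neg_idx

# Oversample the minority class (with replacement) up to the majority size
oversampled = random.choices(minority, k=len(majority))

# Deal each class round-robin into the batches so every batch gets an equal share
n_batches = math.ceil(2 * len(majority) / batch_size)
batches = [[] for _ in range(n_batches)]
for pool in (majority, oversampled):
    shuffled = random.sample(pool, len(pool))
    for j, idx in enumerate(shuffled):
        batches[j % n_batches].append(idx)

for b in batches[:3]:
    n_pos = sum(labels[i] for i in b)
    print(f"batch of {len(b)}: {n_pos} positive, {len(b) - n_pos} negative")
```

With these sizes the pool of 140 indices divides evenly into 14 batches of 5 + 5. When it does not divide evenly, per-class counts can differ by one between batches.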

### When to Use Which
- **BalancedDataset**: When you want soft balancing with randomness
- **BinaryBalancedSampler**: When every batch must be 50/50 up to rounding (used in this work)

In [None]:
# ============================================================================
# CUSTOM DATASET AND SAMPLER CLASSES
# ============================================================================

class BalancedDataset(Dataset):
    """
    A custom dataset class that creates a balanced dataset from imbalanced data
    using weighted random sampling.
    
    This dataset calculates sample weights inversely proportional to class frequencies,
    which can be used with WeightedRandomSampler to achieve balanced batches.
    
    Mathematical Formulation:
    -------------------------
    For class c with n_c examples:
        weight_c = 1 / n_c
    
    For each sample i of class c:
        sample_weight_i = weight_c
    
    Example:
    --------
    If we have 100 positive and 1000 negative examples:
        - Positive weight = 1/100 = 0.01
        - Negative weight = 1/1000 = 0.001
    
    Each positive example is 10x more likely to be drawn than each negative
    example, so the two classes are sampled about equally often overall.
    
    Attributes:
    -----------
    data : array-like
        Input data (can be list, NumPy array, or PyTorch tensor)
    labels : array-like
        Labels corresponding to the data (1D array)
    sample_weights : torch.Tensor
        Weights for each sample (inversely proportional to class frequency)
    
    NOTE: This involves stochastic sampling, so some training examples 
    may never be selected in a given epoch.
    """
    
    def __init__(self, data, labels):
        """
        Initialize the balanced dataset.
        
        Args:
            data: Input features (examples)
            labels: Class labels (0 or 1 for binary classification)
        """
        self.data = data
        self.labels = labels
        
        # Count the number of examples in each class
        # np.bincount([0, 1, 1, 0, 1]) -> [2, 3]
        # Index 0 = count of class 0, Index 1 = count of class 1
        class_counts = np.bincount(self.labels)
        
        # Assign weight inversely proportional to class frequency
        # Minority class gets higher weight
        weights = 1. / torch.tensor(class_counts, dtype=torch.float)
        
        # Create a weight list for each sample based on its class
        # If sample i has label j, assign weight_j to sample i
        self.sample_weights = weights[labels]
    
    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        Get a single sample and its label.
        
        Args:
            idx: Index of the sample
        
        Returns:
            Tuple of (data, label) at index idx
        """
        return self.data[idx], self.labels[idx]


def create_balanced_loader(data, labels, batch_size=32):
    """
    Creates a DataLoader with balanced batches for imbalanced datasets.
    
    This function wraps BalancedDataset and WeightedRandomSampler to create
    a DataLoader that yields approximately balanced batches.
    
    How it works:
    -------------
    1. Create BalancedDataset (computes sample weights)
    2. Create WeightedRandomSampler with these weights
    3. Sampler draws samples with probability proportional to weights
    4. Minority class samples drawn more frequently
    5. Over many batches, classes become approximately balanced
    
    Args:
        data: Input features
        labels: Class labels (1D array)
        batch_size: Size of each batch (default: 32)
    
    Returns:
        DataLoader: PyTorch DataLoader yielding balanced batches
    
    Usage Example:
    --------------
    >>> data = [features1, features2, ...]  
    >>> labels = [0, 1, 1, 0, 1, ...]
    >>> balanced_loader = create_balanced_loader(data, labels, batch_size=32)
    >>> for batch_data, batch_labels in balanced_loader:
    >>>     # Train your model with approximately balanced batches
    >>>     pass
    """
    # Create the balanced dataset
    dataset = BalancedDataset(data, labels)
    
    # WeightedRandomSampler handles the actual balancing
    # - weights: probability of selecting each sample
    # - num_samples: how many samples to draw (typically len(dataset))
    # - replacement: True allows same sample multiple times (necessary for balancing)
    sampler = WeightedRandomSampler(
        weights=dataset.sample_weights,
        num_samples=len(dataset.sample_weights),
        replacement=True
    )
    
    # Create DataLoader with the balanced sampler
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader


class BinaryBalancedSampler(Sampler):
    """
    A PyTorch batch sampler whose batches are 50% positive and 50% negative (up to rounding).
    
    This sampler oversamples the minority class to match the majority class size,
    ensuring each batch contains equal numbers of positive and negative examples.
    
    Algorithm:
    ----------
    1. Identify majority class M and minority class m
    2. Count: |M| majority examples, |m| minority examples
    3. Oversample minority class: randomly select |M| examples from m (with replacement)
    4. Combine: [M examples] + [|M| oversampled m examples] = 2|M| total
    5. Use StratifiedKFold to create balanced batches from this balanced pool
    6. Each batch will have 50% from M and 50% from oversampled m
    
    Mathematical Properties:
    ------------------------
    - Original dataset size: N
    - After balancing: 2 * |M| examples per iteration
    - Equivalent epochs = (2 * |M|) / N
    
    Example:
    --------
    Original: 100 positive, 900 negative (1000 total)
    After balancing: 900 positive (drawn with replacement from the 100 originals), 900 negative (1800 total)
    Equivalent epochs = 1800 / 1000 = 1.8 epochs per iteration
    
    Attributes:
    -----------
    class_vector : array-like
        Class labels for all samples
    batch_size : int
        Size of each batch
    n_splits : int
        Number of batches in one iteration
    equivalent_epochs : float
        Size of the balanced pool divided by the original dataset size
    
    NOTE: This leads to more examples per epoch than the original dataset size.
    """
    
    def __init__(self, class_vector, batch_size=10):
        """
        Initialize the binary balanced sampler.
        
        Args:
            class_vector: Array of class labels (0 or 1)
            batch_size: Number of examples per batch
        """
        self.batch_size = batch_size
        self.class_vector = class_vector
        
        # Convert to numpy array for easier manipulation
        YY = np.array(self.class_vector)
        
        # Find unique classes and their counts
        # U: unique classes [0, 1]
        # C: counts [n_negative, n_positive]
        U, C = np.unique(YY, return_counts=True)
        
        # Find majority class (class with most examples)
        M = U[np.argmax(C)]  # Majority class label
        
        # Get indices of majority and minority classes
        Midx = np.nonzero(YY == M)[0]  # Indices where class == M
        midx = np.nonzero(YY != M)[0]  # Indices where class != M
        
        # Oversample minority class to match majority class size
        # np.random.choice samples WITH replacement by default
        midx_ = np.random.choice(midx, size=len(Midx))
        
        # Create balanced dataset:
        # - All majority class examples
        # - Oversampled minority class examples (same count as majority)
        self.YY = np.array(list(YY[Midx]) + list(YY[midx_]))
        self.idx = np.array(list(Midx) + list(midx_))
        
        # Calculate number of batches
        self.n_splits = int(np.ceil(len(self.idx) / self.batch_size))
        
        # Calculate equivalent epochs
        # This tells us how many times we see the original dataset
        self.equivalent_epochs = len(self.idx) / len(self.class_vector)
        
        print(f'Equivalent epochs in one iteration of data loader: {self.equivalent_epochs:.2f}')
        print(f'Original dataset size: {len(self.class_vector)}')
        print(f'Balanced dataset size: {len(self.idx)}')
        print(f'Number of batches: {self.n_splits}')
    
    def gen_sample_array(self):
        """
        Generate balanced batches using StratifiedKFold.
        
        StratifiedKFold ensures each fold (batch) has the same class distribution
        as the overall dataset. Since our balanced dataset is 50/50, each batch
        will also be approximately 50/50.
        
        Yields:
            numpy array: Indices for each balanced batch
        """
        # Use StratifiedKFold to maintain class balance in each split
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True)
        
        # For each split, yield the test indices
        # tridx: training indices (unused)
        # ttidx: test indices (used as batch)
        for tridx, ttidx in skf.split(self.idx, self.YY):
            yield np.array(self.idx[ttidx])
    
    def __iter__(self):
        """Return an iterator over batch indices."""
        return iter(self.gen_sample_array())
    
    def __len__(self):
        """Return the number of batches."""
        return self.n_splits


class CustomDataset(Dataset):
    """
    Simple custom dataset wrapper for PyTorch DataLoader.
    
    This is a minimal Dataset implementation that stores data and labels
    and returns them when indexed.
    
    Attributes:
        data: Input features or example identifiers
        labels: Corresponding class labels
    """
    
    def __init__(self, data, labels):
        """
        Initialize the dataset.
        
        Args:
            data: Input data
            labels: Class labels
        """
        self.data = data
        self.labels = labels
    
    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        Get a single sample.
        
        Args:
            idx: Index of the sample
        
        Returns:
            Tuple of (data, label) at index idx
        """
        return self.data[idx], self.labels[idx]


# ============================================================================
# TESTING THE SAMPLERS (Example Usage)
# ============================================================================
if __name__ == '__main__':
    print("\n" + "="*70)
    print("Testing BinaryBalancedSampler with Example Data")
    print("="*70)
    
    # Create example data: 100 samples, each drawn as class 0 with probability 0.7
    E = [(str(p_i), str(-1 * c_i)) for p_i, c_i in zip(range(100), range(100))]
    # np.random.randint has no `p` argument; np.random.choice accepts class probabilities
    Y = np.random.choice([0, 1], size=100, p=[0.7, 0.3])  # Imbalanced: ~70% class 0
    batch_size = 10
    
    print(f"\nOriginal class distribution:")
    print(f"  Class 0: {np.sum(Y == 0)} examples ({np.sum(Y == 0)/len(Y)*100:.1f}%)")
    print(f"  Class 1: {np.sum(Y == 1)} examples ({np.sum(Y == 1)/len(Y)*100:.1f}%)")
    
    # Test BinaryBalancedSampler
    print("\nCreating BinaryBalancedSampler...")
    dataset = CustomDataset(E, Y)
    batch_sampler = BinaryBalancedSampler(Y, batch_size)
    data_loader = DataLoader(dataset, batch_sampler=batch_sampler)
    
    print("\nFirst 3 batches from BinaryBalancedSampler:")
    for i, batch in enumerate(data_loader):
        if i >= 3:  # Only show first 3 batches
            break
        batch_data, batch_labels = batch
        n_class0 = torch.sum(batch_labels == 0).item()
        n_class1 = torch.sum(batch_labels == 1).item()
        print(f"  Batch {i+1}: Class 0: {n_class0}, Class 1: {n_class1} "
              f"(Balance: {n_class0/(n_class0+n_class1)*100:.1f}% / "
              f"{n_class1/(n_class0+n_class1)*100:.1f}%)")
    
    # Test create_balanced_loader
    print("\nTesting create_balanced_loader...")
    balanced_loader = create_balanced_loader(E, Y, batch_size)
    
    print("\nFirst 3 batches from WeightedRandomSampler:")
    for i, (batch_data, batch_labels) in enumerate(balanced_loader):
        if i >= 3:  # Only show first 3 batches
            break
        n_class0 = torch.sum(batch_labels == 0).item()
        n_class1 = torch.sum(batch_labels == 1).item()
        print(f"  Batch {i+1}: Class 0: {n_class0}, Class 1: {n_class1} "
              f"(Balance: {n_class0/(n_class0+n_class1)*100:.1f}% / "
              f"{n_class1/(n_class0+n_class1)*100:.1f}%)")
    
    print("\n" + "="*70)
    print("Custom Data Loaders and Samplers Defined Successfully!")
    print("="*70)

---

## Section 4: Protein Feature Extraction Functions

### Purpose
Extract structural and chemical features from protein complexes in PDB format.

### Overview of Protein Structure

A protein structure consists of a hierarchy:
```
Structure
  └── Model (usually just 1)
       └── Chain (e.g., A, B)
            └── Residue (amino acid)
                 └── Atom (C, N, O, etc.)
```

### Features Extracted

**1. Atom Features (`atom1`)**
- One-hot encoding of atom types
- 13 atom categories: C, CA, CB, CG, CH2, N, NH2, OG, OH, O1, O2, SE, Unknown
- Output: N × 13 matrix (N = number of atoms)
- Purpose: Represent chemical identity of each atom for GNN
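
A minimal sketch of this encoding, assuming the 13 columns follow the order listed above with "Unknown" last. The actual column order and any name normalization used in the repository are not shown here, so treat `ATOM_TYPES` and `one_hot_atoms` as hypothetical.

```python
ATOM_TYPES = ["C", "CA", "CB", "CG", "CH2", "N", "NH2", "OG", "OH", "O1", "O2", "SE"]
UNKNOWN = len(ATOM_TYPES)                 # index 12 is the "Unknown" bucket

def one_hot_atoms(atom_names):
    """Return an N x 13 one-hot matrix (list of lists) for PDB atom names."""
    index = {name: i for i, name in enumerate(ATOM_TYPES)}
    matrix = []
    for name in atom_names:
        row = [0] * (len(ATOM_TYPES) + 1)
        row[index.get(name, UNKNOWN)] = 1   # unlisted names fall into "Unknown"
        matrix.append(row)
    return matrix

names = ["N", "CA", "C", "CD1"]           # "CD1" is not in the list above
atom_matrix = one_hot_atoms(names)
for name, row in zip(names, atom_matrix):
    print(f"{name:>3}", row)
```

The residue features follow the same pattern with 21 columns, where each atom's row encodes the amino acid it belongs to.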

**2. Residue Features (`res1`)**
- One-hot encoding of amino acid types
- 21 categories: 20 standard amino acids + Unknown
- Output: N × 21 matrix (N = number of atoms)
- Each atom tagged with its parent residue type
- Purpose: Capture biochemical properties of amino acids

**3. Neighborhood Graph (`neigh1`)**
- Finds 10 nearest neighbors for each atom
- Separates neighbors by residue:
  - **Same residue**: Atoms within same amino acid (local structure)
  - **Different residue**: Atoms from other amino acids (inter-residue contacts)
- Distance threshold: 6 Angstroms
- Output: Two N × 10 neighbor-index matrices (same-residue and different-residue), one row per atom
- Purpose: Define graph structure for message passing in GNN
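
The neighborhood construction can be sketched with brute-force distances. The notebook imports `Bio.PDB.NeighborSearch` for this kind of query; the version below avoids that dependency and rests on two assumptions that are not confirmed by the text above: up to `k` neighbors are kept separately for each of the two matrices, and short rows are padded with `-1`.

```python
import math

def neighbor_lists(coords, residue_ids, k=10, cutoff=6.0):
    """For each atom, return (same_residue, other_residue) neighbor index rows.

    Each row holds up to k indices of atoms within `cutoff` angstroms,
    nearest first, padded with -1 so that every row has length k.
    """
    same, other = [], []
    for i, ci in enumerate(coords):
        # (distance, index) for every other atom inside the cutoff, nearest first
        near = sorted(
            (math.dist(ci, cj), j)
            for j, cj in enumerate(coords)
            if j != i and math.dist(ci, cj) <= cutoff
        )
        s = [j for _, j in near if residue_ids[j] == residue_ids[i]][:k]
        o = [j for _, j in near if residue_ids[j] != residue_ids[i]][:k]
        same.append(s + [-1] * (k - len(s)))
        other.append(o + [-1] * (k - len(o)))
    return same, other

# Toy coordinates (angstroms): two atoms in residue 1, two in residue 2
coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (4.0, 0.0, 0.0), (20.0, 0.0, 0.0)]
residue_ids = [1, 1, 2, 2]
same_res, other_res = neighbor_lists(coords, residue_ids, k=3)
for i in range(len(coords)):
    print(i, "same:", same_res[i], "other:", other_res[i])
```

The brute-force loop is quadratic in the number of atoms, which is fine for a toy example; a spatial index such as `NeighborSearch` is the practical choice for full complexes.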

### Why These Features?

The GNN learns from **local atomic environment**:
- Atom type: What is this atom?
- Residue type: What amino acid does it belong to?
- Neighbors: What atoms are nearby?

This captures the 3D structure and chemistry of the protein complex.