In [None]:
!pip install torch-geometric biopython fair-esm tqdm # install this running baseline GCN

In [None]:
# Directory holding the downloaded STRING 9606 input files; the input paths below are joined onto it.
BASE_DIR = '/content/drive/MyDrive/02703/9606' # replace with your own

Before running any of the cells, download 9606.protein.physical.links.v12.0.txt, 9606.protein.enrichment.terms.v12.0.txt, 9606.protein.sequences.v12.0.fa.gz and 9606.protein.aliases.v12.0.txt into your [BASE_DIR]. The first two files are used to construct the protein–protein interaction network (PPIN); the last two are used to extract ESM embeddings and to download 3D structures. Note that the sequence file is read in its gzipped (`.fa.gz`) form.

To install the packages needed for the baseline GCN, run:

pip install torch-geometric biopython fair-esm tqdm

The packages needed for GVP (mainly DGL) are installed in a later section.



# Construct PPIN

In [None]:
import os

# STRING inputs for the PPIN: physical interaction links and per-protein annotation terms
PPI_LINKS_PATH = os.path.join(BASE_DIR, '9606.protein.physical.links.v12.0.txt')
GO_TERMS_PATH = os.path.join(BASE_DIR, '9606.protein.enrichment.terms.v12.0.txt')

In [None]:
import torch

# Run on the GPU when one is available; models and Data objects below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Seed torch-driven randomness (weight initialisation, dropout, loader shuffling) so
# training runs are repeatable; the data splits already use random_state=42.
SEED = 42
torch.manual_seed(SEED)

In [None]:
def load_ppi_links(fpath):
    """Read a STRING ``protein.physical.links`` file.

    The file is whitespace-separated with one header line followed by
    ``protein1 protein2 combined_score`` records; the score is ignored.

    Args:
        fpath: Path to the links file.

    Returns:
        proteins: Sorted list of every protein ID appearing in at least one link.
            Sorting makes the order (and therefore ``protein_to_idx`` and the
            train/val/test split built on it) identical across kernel restarts;
            ``list(set_of_str)`` order changes between Python processes because
            string hashing is randomised.
        edge_list: List of ``(protein1, protein2)`` tuples in file order.

    Raises:
        ValueError: if a non-blank record does not have exactly three fields.
    """
    protein_set = set()
    edge_list = []
    with open(fpath, 'r') as file:
        next(file, None)  # Skip header (tolerates an empty file)
        for line in file:
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. trailing newline at end of file
            p1, p2, _score = fields
            edge_list.append((p1, p2))
            protein_set.add(p1)
            protein_set.add(p2)
    return sorted(protein_set), edge_list

# Parse the STRING physical-interaction network into a protein list and (protein1, protein2) edges.
proteins, edge_list = load_ppi_links(PPI_LINKS_PATH)

# Load per-protein GO annotations
def load_go_terms(fpath, target_category="Molecular Function (Gene Ontology)"):
    """Map each STRING protein ID to its annotation terms in one category.

    The file is tab-separated, with a header line and the columns
    ``string_protein_id, category, term, description``.

    Args:
        fpath: Path to the STRING ``protein.enrichment.terms`` file.
        target_category: Only rows whose category equals this string are kept.

    Returns:
        Dict mapping protein ID -> list of term IDs, in file order. Proteins with
        no term in ``target_category`` are absent.
    """
    go_term_dict = {}
    with open(fpath, 'r') as file:
        next(file, None)  # Skip header
        for line in file:
            # Strip only the line terminator: a full strip() would also remove the
            # tab before an empty description and the row would then be dropped.
            # maxsplit=3 keeps any tab inside the description in the last field.
            parts = line.rstrip('\r\n').split('\t', 3)
            if len(parts) != 4:
                continue  # blank or malformed line
            string_id, category, term, _desc = parts
            if category == target_category:
                go_term_dict.setdefault(string_id, []).append(term)
    return go_term_dict

# protein ID -> list of GO Molecular Function term IDs (default category of load_go_terms).
go_terms = load_go_terms(GO_TERMS_PATH)


# Pre-processing

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

# Keep only proteins that have at least one GO label; labels, edges and every
# feature matrix below are indexed in this order.
proteins_with_labels = [protein for protein in proteins if protein in go_terms]
filtered_labels = [go_terms[protein] for protein in proteins_with_labels]

# Multi-hot label matrix: one row per protein, one column per GO term (filtered_mlb.classes_)
filtered_mlb = MultiLabelBinarizer()
filtered_y = filtered_mlb.fit_transform(filtered_labels)
data_y = torch.tensor(filtered_y, dtype=torch.float)

# Map proteins to indices and keep only edges whose two endpoints both have labels
protein_to_idx = {protein: i for i, protein in enumerate(proteins_with_labels)}
filtered_edge_list = [(protein_to_idx[p1], protein_to_idx[p2])
                      for p1, p2 in edge_list if p1 in protein_to_idx and p2 in protein_to_idx]
# view(-1, 2) keeps the [2, num_edges] layout even when no edge survives the filter;
# torch.tensor([]) alone would be 1-D and break GCNConv / dgl.graph.
filtered_edge_index = torch.tensor(filtered_edge_list, dtype=torch.long).view(-1, 2).t().contiguous()

In [None]:
from sklearn.model_selection import train_test_split

# Train/Val/Test split: hold out 20% for test, then 25% of the remainder for
# validation, giving 60% train / 20% val / 20% test overall.
num_proteins = len(proteins_with_labels)
all_node_ids = range(num_proteins)
train_indices, test_indices = train_test_split(all_node_ids, test_size=0.2, random_state=42)
train_indices, val_indices = train_test_split(train_indices, test_size=0.25, random_state=42)

# Boolean node masks consumed by the PyG Data objects / NeighborLoaders
train_mask, val_mask, test_mask = (torch.zeros(num_proteins, dtype=torch.bool) for _ in range(3))
for split_mask, split_ids in ((train_mask, train_indices), (val_mask, val_indices), (test_mask, test_indices)):
    split_mask[split_ids] = True

# Baseline GCN

## GCN

In [None]:
from torch_geometric.data import Data
from torch.optim import AdamW
import torch
from torch_geometric.nn import GCNConv
from sklearn.preprocessing import MultiLabelBinarizer
import torch.nn as nn
import torch_geometric.nn as pyg_nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import matplotlib.pyplot as plt

In [None]:
# Identity-matrix features: every protein gets its own one-hot indicator row
num_unique_proteins = len(set(proteins_with_labels))
one_hot_features = torch.eye(num_unique_proteins)

# Featureless baseline: one constant-zero feature per node, so whatever the GCN
# learns must come from the graph structure (raise feature_dim if needed)
feature_dim = 1
minimal_features = torch.zeros(len(proteins_with_labels), feature_dim)

In [None]:
# Wrap the minimal (all-zero) node features, filtered edges, GO label matrix and
# split masks in one PyG Data object on `device`.
minimal_features_data = Data(
    x=minimal_features,
    edge_index=filtered_edge_index,
    y=data_y,
    train_mask=train_mask,
    val_mask=val_mask,
    test_mask=test_mask,
).to(device)


In [None]:
class GCNModel(nn.Module):
    """Two-layer GCN that emits one logit per GO term for every node.

    GCNConv(input_dim -> hidden_dim) -> ReLU -> Dropout(p=0.5)
    -> GCNConv(hidden_dim -> output_dim). There is no final activation; the
    logits are consumed by ``BCEWithLogitsLoss`` / ``sigmoid`` downstream.

    ``conv1`` / ``conv2`` are the parameter-name prefixes inside saved
    state_dicts, so keep these attribute names when editing.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, edge_index):
        hidden = self.dropout(torch.relu(self.conv1(x, edge_index)))
        return self.conv2(hidden, edge_index)

In [None]:
# Define training function
def train_minibatch_model(model, train_loader, val_loader, optimizer, criterion, epochs=100):
    """Train ``model`` with mini-batches and record per-epoch mean losses.

    Every batch must expose ``x``, ``edge_index``, ``y`` and ``train_mask`` /
    ``val_mask`` (e.g. a torch_geometric ``NeighborLoader`` batch).

    The loss is computed on each batch's seed nodes only. A NeighborLoader batch
    also contains sampled neighbours, and its seed nodes are the first
    ``batch.batch_size`` rows. Masking with ``train_mask`` / ``val_mask`` would
    also score neighbours that happen to be in the same split, with truncated
    neighbourhoods and repeatedly across batches. Batches without a
    ``batch_size`` attribute fall back to the split mask.

    Args:
        model: Module called as ``model(x, edge_index)`` returning logits.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, targets)``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        (train_losses, val_losses): lists with the mean batch loss of each epoch.
    """
    # Send batches to wherever the model's parameters live instead of a global device.
    model_device = next(model.parameters()).device
    train_losses = []
    val_losses = []

    def seed_outputs(batch, mask_name):
        """Logits and targets for the rows this batch is scored on."""
        out = model(batch.x.to(model_device), batch.edge_index.to(model_device))
        seeds = getattr(batch, "batch_size", None)
        if seeds is None:
            rows = getattr(batch, mask_name)
            return out[rows.to(out.device)], batch.y[rows].to(model_device)
        return out[:seeds], batch.y[:seeds].to(model_device)

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            logits, targets = seed_outputs(batch, "train_mask")
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_losses.append(train_loss / len(train_loader))

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                logits, targets = seed_outputs(batch, "val_mask")
                val_loss += criterion(logits, targets).item()

        val_losses.append(val_loss / len(val_loader))

        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

    return train_losses, val_losses

# Define evaluation function
def evaluate_minibatch_model(model, test_loader):
    """Print and return thresholded test-set metrics for a mini-batch model.

    Predictions are ``sigmoid(logits) > 0.5``. Only each batch's seed nodes are
    scored. For a NeighborLoader batch these are the first ``batch.batch_size``
    rows, so every test protein is counted exactly once. Masking with
    ``test_mask`` would also score test proteins that were only sampled as
    neighbours of another seed, duplicating them. Batches without
    ``batch_size`` fall back to ``test_mask``.

    Args:
        model: Module called as ``model(x, edge_index)`` returning logits.
        test_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``test_mask``.

    Returns:
        (accuracy, f1, precision, recall): element-wise accuracy over the label
        matrix, and micro-averaged F1 / precision / recall.
    """
    model_device = next(model.parameters()).device
    model.eval()
    predictions = []
    labels = []

    with torch.no_grad():
        for batch in test_loader:
            out = model(batch.x.to(model_device), batch.edge_index.to(model_device))
            seeds = getattr(batch, "batch_size", None)
            if seeds is None:
                logits, targets = out[batch.test_mask.to(out.device)], batch.y[batch.test_mask]
            else:
                logits, targets = out[:seeds], batch.y[:seeds]
            predictions.append((torch.sigmoid(logits) > 0.5).cpu())
            labels.append(targets.cpu())

    predictions = torch.cat(predictions, dim=0)
    labels = torch.cat(labels, dim=0)

    # Compute metrics
    accuracy = (predictions == labels).sum().item() / labels.numel()
    f1 = f1_score(labels, predictions, average="micro")
    precision = precision_score(labels, predictions, average="micro")
    recall = recall_score(labels, predictions, average="micro")
    print(f"Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
    return accuracy, f1, precision, recall


In [None]:
# import NeighborLoader
from torch_geometric.loader import NeighborLoader

def make_neighbor_loader(data, split_mask, shuffle):
    """Two-hop NeighborLoader (10 sampled neighbours per hop, 128 seed nodes per
    batch) whose seed nodes are the nodes selected by ``split_mask``."""
    return NeighborLoader(
        data,
        num_neighbors=[10, 10],
        input_nodes=split_mask,
        batch_size=128,
        shuffle=shuffle,
    )


train_loader = make_neighbor_loader(minimal_features_data, minimal_features_data.train_mask, shuffle=True)
val_loader = make_neighbor_loader(minimal_features_data, minimal_features_data.val_mask, shuffle=False)
test_loader = make_neighbor_loader(minimal_features_data, minimal_features_data.test_mask, shuffle=False)

In [None]:
# Model sizes shared by the GCN / GAT cells below
hidden_dim = 8
output_dim = data_y.shape[1]  # Number of unique GO terms

In [None]:
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html



In [None]:
# min feature Model
min_feat_model = GCNModel(input_dim=minimal_features_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
optimizer = AdamW(min_feat_model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Train and evaluate the min-feat encoding model
print("Training min-feat Encoding Model")
min_feat_train_losses, min_feat_val_losses = train_minibatch_model(min_feat_model, train_loader, val_loader, optimizer, criterion, epochs=150)

print("\nFinal Evaluation on Test Set (min-feat Encoding Model)")
evaluate_minibatch_model(min_feat_model, test_loader)


In [None]:
from sklearn.metrics import precision_recall_curve, auc
import numpy as np

import numpy as np
from sklearn.metrics import precision_recall_curve, auc

def compute_metrics(predictions, labels, thresholds=np.linspace(0, 1, 101)):
    """
    Protein-centric max-F1 and term-centric AUPR for multi-label GO predictions.

    predictions: numpy array of shape (num_proteins, num_terms), predicted probabilities
    labels: numpy array of shape (num_proteins, num_terms), 0/1 ground truth
    thresholds: array-like, thresholds scanned when searching each protein's best F1

    Returns:
        fmax: mean over proteins of each protein's best F1 across `thresholds`.
            Each protein picks its own threshold, which is more optimistic than the
            CAFA Fmax (a single threshold shared by all proteins).
        aupr: mean area under the precision-recall curve over GO terms that have at
            least one positive label (nan if no term has one). A term without
            positives has no meaningful PR curve and is left out of the mean.
        f1_max_scores: list of per-protein best F1 values
        aupr_scores: list of per-term AUPR values, nan for terms without positives
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    num_proteins, num_terms = labels.shape

    # Fmax: score all proteins at once for each threshold (same per-protein
    # precision / recall / F1 arithmetic as a per-protein loop).
    positives = labels == 1
    negatives = labels == 0
    best_f1 = np.zeros(num_proteins)
    for t in thresholds:
        predicted = predictions >= t
        tp = np.sum(predicted & positives, axis=1)
        fp = np.sum(predicted & negatives, axis=1)
        fn = np.sum(~predicted & positives, axis=1)
        # 0/0 cases are replaced by 0 via np.where; silence their warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            precision = np.where(tp + fp > 0, tp / (tp + fp), 0.0)
            recall = np.where(tp + fn > 0, tp / (tp + fn), 0.0)
            f1 = np.where(precision + recall > 0, 2 * precision * recall / (precision + recall), 0.0)
        best_f1 = np.maximum(best_f1, f1)

    f1_max_scores = best_f1.tolist()
    fmax = float(np.mean(best_f1)) if num_proteins else float("nan")

    # AUPR per GO term, skipping terms with no positive label in `labels`
    aupr_scores = []
    for j in range(num_terms):
        if not positives[:, j].any():
            aupr_scores.append(float("nan"))
            continue
        precision, recall, _ = precision_recall_curve(labels[:, j], predictions[:, j])
        aupr_scores.append(auc(recall, precision))

    scored = [score for score in aupr_scores if not np.isnan(score)]
    aupr = float(np.mean(scored)) if scored else float("nan")

    return fmax, aupr, f1_max_scores, aupr_scores


def evaluate_minibatch_model_with_metrics(model, test_loader):
    """Evaluate on the test split: thresholded micro metrics plus Fmax / AUPR.

    Only each batch's seed nodes are scored. For a NeighborLoader batch these are
    the first ``batch.batch_size`` rows, so every test protein appears once in
    the prediction matrix. Masking with ``test_mask`` would also include test
    proteins that were only sampled as neighbours. Batches without
    ``batch_size`` fall back to ``test_mask``.

    Args:
        model: Module called as ``model(x, edge_index)`` returning logits.
        test_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``test_mask``.

    Returns:
        (accuracy, f1, precision, recall, fmax, aupr); see compute_metrics for
        the definitions of fmax and aupr.
    """
    model_device = next(model.parameters()).device
    model.eval()
    predictions = []
    labels = []

    with torch.no_grad():
        for batch in test_loader:
            out = model(batch.x.to(model_device), batch.edge_index.to(model_device))
            seeds = getattr(batch, "batch_size", None)
            if seeds is None:
                logits, targets = out[batch.test_mask.to(out.device)], batch.y[batch.test_mask]
            else:
                logits, targets = out[:seeds], batch.y[:seeds]
            predictions.append(torch.sigmoid(logits).cpu())
            labels.append(targets.cpu())

    # Combine predictions and labels across all batches
    predictions = torch.cat(predictions, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    # round() maps p <= 0.5 to 0 (0.5 rounds half-to-even), i.e. a "> 0.5" cut
    hard_predictions = predictions.round()
    accuracy = (hard_predictions == labels).sum() / labels.size
    f1 = f1_score(labels, hard_predictions, average="micro")
    precision = precision_score(labels, hard_predictions, average="micro")
    recall = recall_score(labels, hard_predictions, average="micro")

    # Compute Fmax and AUPR
    fmax, aupr, f1_max_scores, aupr_scores = compute_metrics(predictions, labels)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Fmax: {fmax:.4f}")
    print(f"AUPR: {aupr:.4f}")
    return accuracy, f1, precision, recall, fmax, aupr

# Optional: to evaluate saved weights instead of the model trained above, uncomment:
# min_feat_model = GCNModel(input_dim=minimal_features_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
# min_feat_model.load_state_dict(torch.load('/content/drive/MyDrive/02703/9606/weights/gcn_model.pth'))

# evaluate on test set
# NOTE(review): `test_loader` must still be the minimal-feature loader here; a later ESM
# cell rebinds the global `test_loader` to an ESM loader.
print("\nFinal Evaluation on Test Set (min-feat Encoding Model)")
evaluate_minibatch_model_with_metrics(min_feat_model, test_loader)

## GAT

In [None]:
from torch_geometric.nn import GATConv

class GATModel(nn.Module):
    """Two-layer graph attention network emitting one logit per GO term.

    Layer 1: GATConv with ``heads`` attention heads whose outputs are
    concatenated (width ``hidden_dim * heads``), followed by ReLU and dropout.
    Layer 2: single-head GATConv (``concat=False``) mapping to ``output_dim``
    logits. ``dropout`` is passed to both GATConv layers and also used for the
    feature dropout between them.

    ``gat1`` / ``gat2`` are the parameter-name prefixes in saved state_dicts.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=8, dropout=0.5):
        super().__init__()
        self.gat1 = GATConv(input_dim, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim * heads, output_dim, heads=1, concat=False, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        hidden = torch.relu(self.gat1(x, edge_index))
        return self.gat2(self.dropout(hidden), edge_index)

In [None]:
# GAT on one-hot (identity) node features, trained with the same mini-batch
# NeighborLoader pipeline as the GCN baseline. The Data object and loaders are
# built here because nothing earlier wraps `one_hot_features`.
onehot_data = Data(x=one_hot_features, edge_index=filtered_edge_index, y=data_y,
                   train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)
onehot_loaders = {
    split: NeighborLoader(onehot_data, num_neighbors=[10, 10],
                          input_nodes=getattr(onehot_data, f"{split}_mask"),
                          batch_size=128, shuffle=(split == "train"))
    for split in ("train", "val", "test")
}

onehot_gat = GATModel(input_dim=onehot_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
optimizer_onehot_gat = AdamW(onehot_gat.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Train and evaluate the one-hot encoding model
print("Training One-Hot Encoding Model")
onehot_gat_train_losses, onehot_gat_val_losses = train_minibatch_model(
    onehot_gat, onehot_loaders["train"], onehot_loaders["val"], optimizer_onehot_gat, criterion, epochs=150)
print("\nFinal Evaluation on Test Set (One-Hot Encoding Model)")
evaluate_minibatch_model(onehot_gat, onehot_loaders["test"])

In [None]:
# GAT on ESM embeddings. The feature matrix, Data object and loaders are built
# here from the saved embedding file, so this cell does not depend on the ESM
# section further down having been run first.
import gzip

esm_embeddings = torch.load(f"{BASE_DIR}/protein_esm_embeddings.pt", weights_only=True)
# Row i of the saved tensor is the i-th FASTA record (the extraction loop iterates
# over the parsed FASTA IDs), not protein_to_idx order, so recover the record order
# from the FASTA headers (record ID = first word after '>').
with gzip.open(os.path.join(BASE_DIR, '9606.protein.sequences.v12.0.fa.gz'), 'rt') as fasta:
    esm_row_ids = list(dict.fromkeys(line[1:].split()[0] for line in fasta if line.startswith('>')))
assert len(esm_row_ids) == esm_embeddings.shape[0], "ESM embedding rows do not match the FASTA records"
esm_row_of = {protein: row for row, protein in enumerate(esm_row_ids)}
filtered_esm_embeddings = torch.stack([esm_embeddings[esm_row_of[protein]]
                                       for protein in proteins_with_labels])
esm_data = Data(x=filtered_esm_embeddings, edge_index=filtered_edge_index, y=data_y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)
esm_loaders = {
    split: NeighborLoader(esm_data, num_neighbors=[10, 10],
                          input_nodes=getattr(esm_data, f"{split}_mask"),
                          batch_size=128, shuffle=(split == "train"))
    for split in ("train", "val", "test")
}

esm_gat = GATModel(input_dim=esm_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
optimizer_esm_gat = AdamW(esm_gat.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Train and evaluate the ESM encoding model
print("Training ESM Encoding Model")
esm_gat_train_losses, esm_gat_val_losses = train_minibatch_model(
    esm_gat, esm_loaders["train"], esm_loaders["val"], optimizer_esm_gat, criterion, epochs=150)
print("\nFinal Evaluation on Test Set (ESM Encoding Model)")
evaluate_minibatch_model(esm_gat, esm_loaders["test"])

In [None]:
# Train/val loss curves of every GCN/GAT variant. A curve whose loss variable has
# not been created in this kernel (its training cell has not run) is skipped and
# listed, instead of the whole cell failing with NameError.
loss_curves = [
    ("onehot_train_losses", "GCN (onehot) Train Loss"),
    ("onehot_val_losses", "GCN (onehot) Val Loss"),
    ("esm_train_losses", "GCN (onehot + ESM) Train Loss"),
    ("esm_val_losses", "GCN (onehot + ESM) Val Loss"),
    ("onehot_gat_train_losses", "GAT (onehot) Train Loss"),
    ("onehot_gat_val_losses", "GAT (onehot) Val Loss"),
    ("esm_gat_train_losses", "GAT (onehot + ESM) Train Loss"),
    ("esm_gat_val_losses", "GAT (onehot + ESM) Val Loss"),
]

fig, ax = plt.subplots(figsize=(12, 6))
not_plotted = []
for var_name, curve_label in loss_curves:
    curve = globals().get(var_name)
    if curve is None:
        not_plotted.append(var_name)
    else:
        ax.plot(curve, label=curve_label)
if not_plotted:
    print("Not plotted (variables not defined yet):", ", ".join(not_plotted))

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Multi-label classification loss")
ax.legend()
plt.show()

# Extract ESM embeddings

In [None]:
import gzip
import esm
from Bio import SeqIO
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
import torch.nn as nn
import torch_geometric.nn as pyg_nn
import torch.nn.functional as F

# BASE_DIR comes from the configuration cell at the top of the notebook; it is not
# re-assigned here so a user-supplied path is not silently overridden.
SEQUENCE_FILE = os.path.join(BASE_DIR, '9606.protein.sequences.v12.0.fa.gz')

# Parse sequences from .fa.gz file using STRING IDs as keys
def parse_sequences(sequence_file):
    """Read a gzipped FASTA file into a ``{record.id: sequence}`` dict.

    The record ID is used directly as the key (the STRING protein ID).
    Dict order follows the file; a repeated ID keeps its first position
    and its last sequence.
    """
    with gzip.open(sequence_file, "rt") as handle:
        return {record.id: str(record.seq) for record in SeqIO.parse(handle, "fasta")}

# Load STRING sequences directly from the .fa.gz file
string_to_sequence = parse_sequences(SEQUENCE_FILE)
print("Sample STRING ID to sequence mapping:", list(string_to_sequence.items())[:5])

# Load pre-trained ESM-1b on `device` (from the setup cell) instead of a hard-coded "cuda"
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
model = model.to(device)
batch_converter = alphabet.get_batch_converter()
model.eval()

# Generate ESM Embeddings: one row per FASTA record, in FASTA order, so row i of the
# saved tensor belongs to list(string_to_sequence)[i] (not to protein_to_idx).
protein_embeddings = []
max_seq_len = 1022  # Truncate so sequence + BOS/EOS tokens fit the model's 1024 positions
for protein, sequence in tqdm(list(string_to_sequence.items()), desc="Generating ESM embeddings"):
    if sequence:
        truncated_sequence = sequence[:max_seq_len]
        _, _, batch_tokens = batch_converter([(protein, truncated_sequence)])

        with torch.no_grad():
            results = model(batch_tokens.to(device), repr_layers=[33], return_contacts=False)
        token_representations = results["representations"][33]
        # Token 0 is the BOS token and the token after the last residue is EOS;
        # average over residue tokens only, as in the fair-esm README.
        embedding = token_representations[0, 1:len(truncated_sequence) + 1].mean(dim=0).cpu()
        protein_embeddings.append(embedding)
    else:
        protein_embeddings.append(torch.zeros(model.args.embed_dim))

x = torch.stack(protein_embeddings)
print("ESM Embeddings generated successfully.")

# Save embeddings under BASE_DIR to avoid regenerating them
torch.save(x, os.path.join(BASE_DIR, "protein_esm_embeddings.pt"))
print("ESM embeddings saved successfully.")

# ESM

In [None]:
# ESM feature matrix aligned to proteins_with_labels, wrapped in a Data object.
import gzip

esm_embeddings = torch.load(f"{BASE_DIR}/protein_esm_embeddings.pt", weights_only=True)
# Row i of the saved tensor is the i-th FASTA record (the extraction loop iterates over
# the parsed FASTA IDs), not protein_to_idx order; recover that order from the headers.
with gzip.open(os.path.join(BASE_DIR, '9606.protein.sequences.v12.0.fa.gz'), 'rt') as fasta:
    esm_row_ids = list(dict.fromkeys(line[1:].split()[0] for line in fasta if line.startswith('>')))
assert len(esm_row_ids) == esm_embeddings.shape[0], "ESM embedding rows do not match the FASTA records"
esm_row_of = {protein: row for row, protein in enumerate(esm_row_ids)}
filtered_esm_embeddings = torch.stack([esm_embeddings[esm_row_of[protein]]
                                       for protein in proteins_with_labels])

esm_data = Data(x=filtered_esm_embeddings, edge_index=filtered_edge_index, y=data_y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)

In [None]:
# Loaders over esm_data: the earlier train_loader / val_loader yield minimal_features_data
# batches whose 1-dim features do not fit this model's ESM-sized input layer.
esm_train_loader = NeighborLoader(esm_data, num_neighbors=[10, 10], input_nodes=esm_data.train_mask,
                                  batch_size=128, shuffle=True)
esm_val_loader = NeighborLoader(esm_data, num_neighbors=[10, 10], input_nodes=esm_data.val_mask,
                                batch_size=128, shuffle=False)

# create esm model
esm_feat_model = GCNModel(input_dim=esm_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
optimizer = AdamW(esm_feat_model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

esm_feat_train_losses, esm_feat_val_losses = train_minibatch_model(
    esm_feat_model, esm_train_loader, esm_val_loader, optimizer, criterion, epochs=150)

In [None]:
# Evaluate saved ESM-GCN weights on the test split
import gzip

esm_embeddings = torch.load(f"{BASE_DIR}/protein_esm_embeddings.pt", weights_only=True)
# Row i of the saved tensor is the i-th FASTA record (the extraction loop iterates over
# the parsed FASTA IDs), not protein_to_idx order; recover that order from the headers.
with gzip.open(os.path.join(BASE_DIR, '9606.protein.sequences.v12.0.fa.gz'), 'rt') as fasta:
    esm_row_ids = list(dict.fromkeys(line[1:].split()[0] for line in fasta if line.startswith('>')))
assert len(esm_row_ids) == esm_embeddings.shape[0], "ESM embedding rows do not match the FASTA records"
esm_row_of = {protein: row for row, protein in enumerate(esm_row_ids)}
filtered_esm_embeddings = torch.stack([esm_embeddings[esm_row_of[protein]]
                                       for protein in proteins_with_labels])
esm_data = Data(x=filtered_esm_embeddings, edge_index=filtered_edge_index, y=data_y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)

# Own name so the minimal-feature `test_loader` used by earlier cells is not overwritten
esm_test_loader = NeighborLoader(
    esm_data,
    num_neighbors=[10, 10],
    input_nodes=esm_data.test_mask,
    batch_size=128,
    shuffle=False
)

esm_model = GCNModel(input_dim=esm_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
# map_location lets GPU-saved weights load on a CPU runtime
esm_model.load_state_dict(torch.load(os.path.join(BASE_DIR, 'weights', 'gcn_esm_model.pth'),
                                     map_location=device, weights_only=True))

# evaluate on test set
print("\nFinal Evaluation on Test Set (ESM Encoding Model)")
evaluate_minibatch_model_with_metrics(esm_model, esm_test_loader)

In [None]:
# ESM Model: GCN trained on ESM embeddings aligned to proteins_with_labels
import gzip

esm_embeddings = torch.load(f"{BASE_DIR}/protein_esm_embeddings.pt", weights_only=True)
# Row i of the saved tensor is the i-th FASTA record (the extraction loop iterates over
# the parsed FASTA IDs), not protein_to_idx order; recover that order from the headers.
with gzip.open(os.path.join(BASE_DIR, '9606.protein.sequences.v12.0.fa.gz'), 'rt') as fasta:
    esm_row_ids = list(dict.fromkeys(line[1:].split()[0] for line in fasta if line.startswith('>')))
assert len(esm_row_ids) == esm_embeddings.shape[0], "ESM embedding rows do not match the FASTA records"
esm_row_of = {protein: row for row, protein in enumerate(esm_row_ids)}
filtered_esm_embeddings = torch.stack([esm_embeddings[esm_row_of[protein]]
                                       for protein in proteins_with_labels])
esm_data = Data(x=filtered_esm_embeddings, edge_index=filtered_edge_index, y=data_y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)

# Loaders over esm_data: the earlier train_loader / val_loader sample the 1-dim minimal
# features, which do not fit this model's ESM-sized input layer.
esm_train_loader = NeighborLoader(esm_data, num_neighbors=[10, 10], input_nodes=esm_data.train_mask,
                                  batch_size=128, shuffle=True)
esm_val_loader = NeighborLoader(esm_data, num_neighbors=[10, 10], input_nodes=esm_data.val_mask,
                                batch_size=128, shuffle=False)

esm_model = GCNModel(input_dim=esm_data.num_features, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
optimizer = AdamW(esm_model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Train the ESM-feature GCN
print("Training esm Encoding Model")
esm_train_losses, esm_val_losses = train_minibatch_model(
    esm_model, esm_train_loader, esm_val_loader, optimizer, criterion, epochs=150)

In [None]:
# GCN loss curves, one-hot vs ESM features. Curves whose variables are not defined
# in this kernel (their training cell has not run) are skipped and listed instead
# of the cell failing with NameError.
gcn_curves = [
    ("onehot_train_losses", "One-Hot Training Loss", "blue"),
    ("onehot_val_losses", "One-Hot Validation Loss", "cyan"),
    ("esm_train_losses", "ESM Training Loss", "red"),
    ("esm_val_losses", "ESM Validation Loss", "orange"),
]

fig, ax = plt.subplots(figsize=(12, 6))
not_plotted = []
for var_name, curve_label, curve_color in gcn_curves:
    curve = globals().get(var_name)
    if curve is None:
        not_plotted.append(var_name)
    else:
        ax.plot(curve, label=curve_label, color=curve_color)
if not_plotted:
    print("Not plotted (variables not defined yet):", ", ".join(not_plotted))

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("GCN using one-hot VS ESM embedding")
ax.legend()
plt.show()

# GVP

## Extract PDB files from AlphaFold using UniProt aliases

In [None]:
import os
import requests
from tqdm import tqdm

def extract_aliases(base_dir):
    """
    Extracts UniProt aliases from the STRING alias file under the base directory.

    The file is ``<base_dir>/<taxon>.protein.aliases.v12.0.txt``, where ``<taxon>``
    is the last path component of ``base_dir`` (e.g. ``9606``). Using only the last
    component lets ``base_dir`` be either a bare ``"9606"`` or a full path such as
    ``/content/drive/MyDrive/02703/9606``.

    :param base_dir: Directory containing input files, including the aliases file.
    :return: A dictionary where keys are STRING protein IDs and values are lists of
        UniProt accessions (source ``UniProt_AC``), in file order.
    :raises ValueError: if a non-blank line does not have exactly three tab-separated fields.
    """
    taxon = os.path.basename(os.path.normpath(base_dir))
    alias_file = os.path.join(base_dir, f'{taxon}.protein.aliases.v12.0.txt')
    alias_dict = {}

    with open(alias_file, 'r') as file:
        next(file, None)  # Skip the header
        for line in file:
            if not line.strip():
                continue  # blank line, e.g. at end of file
            string_id, alias, source = line.strip().split('\t')
            if source == 'UniProt_AC':  # Filter for UniProt aliases
                alias_dict.setdefault(string_id, []).append(alias)

    print(f"Extracted {len(alias_dict)} protein aliases from {alias_file}.")
    return alias_dict

def download_structures(alias_dict, base_dir, skip_existing=True):
    """
    Downloads PDB structure files for proteins using UniProt aliases from AlphaFold.
    Saves the files in a 'structures' subdirectory under the specified base directory.

    Aliases are tried in order; the first one that returns HTTP 200 is saved as
    ``<base_dir>/structures/<STRING id>.pdb``.

    :param alias_dict: Dictionary where keys are STRING protein IDs and values are lists of UniProt aliases.
    :param base_dir: Directory containing input files. The 'structures' directory will be created here.
    :param skip_existing: If True (default), proteins whose .pdb file already exists are not
        requested again, so an interrupted run can be resumed cheaply.
    :return: List of STRING protein IDs for which no structure could be downloaded.
    """
    # Define the structures directory under the base directory
    structures_dir = os.path.join(base_dir, "structures")
    os.makedirs(structures_dir, exist_ok=True)  # Ensure the structures directory exists

    total_proteins = len(alias_dict)
    print(f"Starting download for {total_proteins} proteins into '{structures_dir}'...")

    failed = []
    # One Session reuses HTTP connections across the many requests below.
    with requests.Session() as session:
        for protein, aliases in tqdm(alias_dict.items(), desc="Downloading structures"):
            file_path = os.path.join(structures_dir, f'{protein}.pdb')
            if skip_existing and os.path.isfile(file_path):
                continue

            downloaded = False
            for alias in aliases:
                url = f'https://alphafold.ebi.ac.uk/files/AF-{alias}-F1-model_v4.pdb'
                try:
                    response = session.get(url, timeout=10)
                except requests.RequestException as e:
                    print(f"Error downloading {protein} with alias {alias}: {e}")
                    continue
                if response.status_code == 200:
                    # Write under a temporary name first so an interrupted write never
                    # leaves a truncated .pdb that skip_existing would treat as complete.
                    partial_path = file_path + '.part'
                    with open(partial_path, 'wb') as file:
                        file.write(response.content)
                    os.replace(partial_path, file_path)
                    downloaded = True
                    break  # Exit loop after successful download

            if not downloaded:
                failed.append(protein)
                print(f"Failed to download structure for protein {protein}.")

    print(f"Download completed. Structures saved in '{structures_dir}'.")
    return failed

# NOTE(review): "9606" is relative to the current working directory, so aliases are read from
# ./9606 and structures are written to ./9606/structures, while the PDB-parsing cell below reads
# from "/content/drive/MyDrive/02703/9606/structures/structures" — confirm the intended location.
base_dir = "9606"
alias_dict = extract_aliases(base_dir)
download_structures(alias_dict, base_dir)


## GVP + GNN

## Install the correct version of DGL

In [None]:
import torch
# Environment check: the DGL wheel installed below is built for torch 2.4 / CUDA 12.1.
print(torch.__version__)  # Expect 2.4.0 after the reinstall cells below (torch==2.4.0, cu121)
print(torch.cuda.is_available())  # Should return True on a GPU runtime
print(torch.version.cuda)  # Expect '12.1' for the cu121 wheels installed below


In [None]:
pip uninstall -y torch torchvision torchaudio

In [None]:
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121


In [None]:
pip install torchdata==0.8.0

In [None]:
pip install  dgl -f https://data.dgl.ai/wheels/torch-2.4/cu121/repo.html

In [None]:
# # install dgl for cpu
# !pip uninstall -y torch torchvision torchaudio torchdata dgl
# !pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cpu
# !pip uninstall -y torchdata
# !pip install torchdata==0.6.1  # Replace with the compatible version for PyTorch 2.5.1 if different
# !pip install dgl


In [None]:
import dgl
import torch

# Check DGL version
print("DGL version:", dgl.__version__)

# Reports whether PyTorch (DGL's backend here) can see a GPU
print("Is CUDA available in DGL:", torch.cuda.is_available())

# Smoke test: build a tiny graph and move it to the GPU when one exists, falling
# back to the CPU so the cell also runs on CPU-only runtimes.
check_device = 'cuda' if torch.cuda.is_available() else 'cpu'
g = dgl.graph(([0, 1], [1, 2])).to(check_device)
print("Graph device:", g.device)


## Construct the graph with 3D structure data using DGL

In [None]:
import os
# Select DGL's PyTorch backend.
# NOTE(review): DGL reads DGLBACKEND when it is first imported; if `dgl` was already imported
# earlier in this kernel (e.g. by the check cell above), the import below is a cached no-op and
# the setting only applies after a kernel restart.
os.environ["DGLBACKEND"] = "pytorch"
import dgl

In [None]:
import os
import torch
from tqdm import tqdm

# Define amino acids for one-hot encoding (three-letter PDB residue names;
# list position = one-hot column)
amino_acids = [
    "ALA", "ARG", "ASN", "ASP", "CYS",
    "GLU", "GLN", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO",
    "SER", "THR", "TRP", "TYR", "VAL"
]
aa_to_index = {aa: idx for idx, aa in enumerate(amino_acids)}

def parse_pdb_residue_level(pdb_file_path):
    """
    Parse residue-level features from a PDB file.

    Each residue, keyed by (chain ID, residue number, insertion code), becomes one
    entry: the centroid of its ATOM/HETATM coordinates plus a one-hot encoding of its
    residue name. Residues whose name is not one of the 20 in `amino_acids`
    (e.g. water, ligands, modified residues) are dropped.

    Args:
        pdb_file_path: Path to the PDB file.

    Returns:
        coordinates: List of residue centroids (float tensors of shape (3,)), in file order.
        scalar_features: List of one-hot residue-type tensors of shape (20,), aligned with coordinates.
        missing: True if the file does not exist, could not be read, or contains no
            standard residue; callers can then skip it instead of stacking empty lists.
    """
    residues = {}
    coordinates = []
    scalar_features = []
    missing = False

    if not os.path.isfile(pdb_file_path):
        return coordinates, scalar_features, True

    try:
        with open(pdb_file_path, 'r') as file:
            for line in file:
                if not line.startswith(("ATOM", "HETATM")):
                    continue
                try:
                    # Chain ID, residue number and insertion code (column 27); the
                    # insertion code keeps e.g. residues 52 and 52A apart.
                    residue_id = (line[21], int(line[22:26].strip()), line[26:27].strip())
                    x = float(line[30:38].strip())
                    y = float(line[38:46].strip())
                    z = float(line[46:54].strip())
                except ValueError:
                    continue  # skip records with unparsable numeric fields

                aa = line[17:20].strip()  # residue name
                residue = residues.setdefault(residue_id, {"coords": [], "aa": aa})
                residue["coords"].append([x, y, z])

        # Compute residue centroids and one-hot encode residue types
        for residue_data in residues.values():
            if residue_data["aa"] in aa_to_index:
                centroid = torch.tensor(residue_data["coords"], dtype=torch.float).mean(dim=0)
                one_hot = torch.zeros(len(amino_acids))
                one_hot[aa_to_index[residue_data["aa"]]] = 1.0
                coordinates.append(centroid)
                scalar_features.append(one_hot)

    except Exception as e:
        print(f"Error reading PDB file {pdb_file_path}: {e}")
        missing = True

    # A readable file with no standard residue is reported as missing too.
    if not coordinates:
        missing = True

    return coordinates, scalar_features, missing

def prepare_features_with_coords(proteins, pdb_dir):
    """
    Prepare residue-level features for a list of proteins.

    Reads ``<pdb_dir>/<protein>.pdb`` for each protein with parse_pdb_residue_level.

    Args:
        proteins: List of protein IDs.
        pdb_dir: Directory containing PDB files named ``<protein>.pdb``.

    Returns:
        scalar_features: Dict protein -> stacked per-residue scalar features (one row per residue).
        vector_features: Dict protein -> differences between consecutive residue centroids
            (one row fewer than residues).
        coord_features: Dict protein -> stacked residue centroids.
        missing_proteins: List of proteins whose structure is missing, unreadable or has no residues.
    """
    scalar_features = {}
    vector_features = {}
    coord_features = {}
    missing_proteins = []

    for protein in tqdm(proteins, desc="Processing proteins"):
        pdb_file = os.path.join(pdb_dir, f"{protein}.pdb")
        coords, scalar_feats, missing = parse_pdb_residue_level(pdb_file)
        # An existing file with no parsed residues gives empty lists and torch.stack([])
        # would raise, so treat it like a missing structure.
        if missing or not coords:
            missing_proteins.append(protein)
            continue
        coords_tensor = torch.stack(coords)
        coord_features[protein] = coords_tensor  # All residue centroids
        scalar_features[protein] = torch.stack(scalar_feats)  # All residue scalar features
        vector_features[protein] = coords_tensor[1:] - coords_tensor[:-1]  # Residue-level vector differences

    print(f"Missing structures for {len(missing_proteins)} proteins.")
    return scalar_features, vector_features, coord_features, missing_proteins


# Prepare features with coordinates for every labelled protein
# NOTE(review): this path is hard-coded (not built from BASE_DIR) and nests structures/structures,
# whereas download_structures() writes to os.path.join(base_dir, "structures") — confirm the path.
scalar_feats_dict, vector_feats_dict, coord_feats_dict, missing_proteins = prepare_features_with_coords(
    proteins_with_labels, "/content/drive/MyDrive/02703/9606/structures/structures"
    )


In [None]:
def align_features_to_graph(graph_nodes, scalar_feats_dict, vector_feats_dict, coord_feats_dict, max_scalar_len=None, max_vector_len=None):
    """
    Build fixed-size per-node feature tensors in graph-node order.

    Every protein's scalar and vector features are truncated or zero-padded along
    the residue axis to a common length, and its residue centroids are averaged to
    one coordinate. Proteins absent from ``scalar_feats_dict`` get all-zero
    placeholders (with a printed notice).

    Args:
        graph_nodes: Protein identifiers in graph-node order.
        scalar_feats_dict: Protein -> per-residue scalar features.
        vector_feats_dict: Protein -> per-residue vector features.
        coord_feats_dict: Protein -> residue centroid coordinates.
        max_scalar_len: Target residue count for scalar features; defaults to the longest entry.
        max_vector_len: Target residue count for vector features; defaults to the longest entry.

    Returns:
        aligned_scalar_feats: [num_graph_nodes, max_scalar_len, 20]
        aligned_vector_feats: [num_graph_nodes, max_vector_len, 3]
        aligned_coord_feats: [num_graph_nodes, 3]
        max_scalar_len, max_vector_len: The lengths actually used.
    """
    def fit_rows(features, length):
        """Truncate or zero-pad ``features`` along dim 0 to exactly ``length`` rows."""
        num_rows = features.size(0)
        if num_rows > length:
            return features[:length]
        if num_rows < length:
            padding = torch.zeros((length - num_rows, features.size(1)))
            return torch.cat([features, padding], dim=0)
        return features

    if max_scalar_len is None:
        max_scalar_len = max((feats.size(0) for feats in scalar_feats_dict.values()), default=0)
        print(f"Computed max_scalar_len: {max_scalar_len}")

    if max_vector_len is None:
        max_vector_len = max((feats.size(0) for feats in vector_feats_dict.values()), default=0)
        print(f"Computed max_vector_len: {max_vector_len}")

    scalar_rows, vector_rows, coord_rows = [], [], []
    for protein in graph_nodes:
        if protein not in scalar_feats_dict:
            print(f"Protein: {protein} missing in features. Appending zeros.")
            scalar_rows.append(torch.zeros((max_scalar_len, 20)))
            vector_rows.append(torch.zeros((max_vector_len, 3)))
            coord_rows.append(torch.zeros(3))
            continue
        scalar_rows.append(fit_rows(scalar_feats_dict[protein], max_scalar_len))
        vector_rows.append(fit_rows(vector_feats_dict[protein], max_vector_len))
        coord_rows.append(coord_feats_dict[protein].mean(dim=0))  # single aggregated coordinate

    return (
        torch.stack(scalar_rows),
        torch.stack(vector_rows),
        torch.stack(coord_rows),
        max_scalar_len,
        max_vector_len,
    )

# No max_scalar_len / max_vector_len is passed, so both default to the longest protein in
# scalar_feats_dict / vector_feats_dict. Pass smaller values to truncate long proteins and
# bound the memory of the padded tensors.
aligned_scalar_feats, aligned_vector_feats, aligned_coord_feats, max_scalar_len, max_vector_len = align_features_to_graph(
    graph_nodes=proteins_with_labels,
    scalar_feats_dict=scalar_feats_dict,
    vector_feats_dict=vector_feats_dict,
    coord_feats_dict=coord_feats_dict
)

In [None]:
# Persist the aligned feature tensors so later sessions can skip PDB parsing
features_dir = f"{BASE_DIR}/saved_features"
os.makedirs(features_dir, exist_ok=True)

for feature_tensor, file_name in (
    (aligned_scalar_feats, "aligned_scalar_feats.pt"),
    (aligned_vector_feats, "aligned_vector_feats.pt"),
    (aligned_coord_feats, "aligned_coord_feats.pt"),
):
    torch.save(feature_tensor, os.path.join(features_dir, file_name))

print(f"Features saved to {BASE_DIR}/saved_features")

In [None]:
# Reload the aligned feature tensors written by the save cell
features_dir = f"{BASE_DIR}/saved_features"
aligned_scalar_feats, aligned_vector_feats, aligned_coord_feats = (
    torch.load(os.path.join(features_dir, file_name))
    for file_name in ("aligned_scalar_feats.pt", "aligned_vector_feats.pt", "aligned_coord_feats.pt")
)

print("Features loaded successfully!")

In [None]:
# Create the graph. num_nodes is passed explicitly: otherwise DGL infers the node count
# from the largest node ID in the edge list, so labelled proteins with no surviving edges
# at the end of the index range would be dropped and the ndata assignments below would
# fail with a size mismatch.
graph = dgl.graph((filtered_edge_index[0], filtered_edge_index[1]), num_nodes=len(proteins_with_labels))

# assign features to the graph
# Amino-acid composition per protein. aligned_scalar_feats is zero-padded along the residue
# axis, so .mean(dim=1) would divide by the padded length and shrink short proteins' values;
# divide by the number of real one-hot residue rows instead (all-zero rows -> zeros).
residue_counts = aligned_scalar_feats.sum(dim=(1, 2)).clamp(min=1).unsqueeze(1)
graph.ndata['scalar_feats'] = aligned_scalar_feats.sum(dim=1) / residue_counts
graph.ndata['vector_feats'] = aligned_vector_feats
graph.ndata['coords'] = aligned_coord_feats
graph.ndata['esm'] = filtered_esm_embeddings

## GVP module

In [None]:
import torch
from torch import nn, einsum
import dgl
import dgl.function as fn
from typing import List, Tuple, Union, Dict
import math

# helper functions
def exists(val):
    """Return True when `val` is anything other than None (falsy values like 0 count as existing)."""
    return not (val is None)

def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of tensor clamped above a minimum value `eps`.

    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    return torch.sqrt(out) if sqrt else out

# the classes GVP, GVPDropout, and GVPLayerNorm are taken from lucidrains' geometric-vector-perceptron repository
# https://github.com/lucidrains/geometric-vector-perceptron/tree/main
# some adaptations have been made to these classes to make them more consistent with the original GVP paper/implementation
# specifically, using _norm_no_nan instead of torch's built-in norm function, and the weight initialization scheme for Wh and Wu

def _rbf(D, D_min=0., D_max=20., D_count=16):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    device = D.device
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF

class GVP(nn.Module):
    """Geometric Vector Perceptron layer.

    Maps a ``(feats, vectors)`` tuple to a new ``(feats, vectors)`` tuple:

    * ``feats``:   (batch, dim_feats_in)      -> (batch, dim_feats_out)
    * ``vectors``: (batch, dim_vectors_in, 3) -> (batch, dim_vectors_out, 3)

    Vector channels are only combined linearly along the channel axis (``Wh``,
    optional ``Wcp``, then ``Wu``); the xyz axis is never mixed. The norms of
    the hidden vectors are concatenated onto the scalar input to produce the
    scalar output, and the output vectors are scaled by a gate.

    Args:
        dim_vectors_in / dim_vectors_out: number of input / output vector channels.
        dim_feats_in / dim_feats_out: number of input / output scalar features.
        n_cp_feats: if > 0, that many cross products (of pairs of linear
            combinations of the input vectors) are appended to the hidden vectors.
        hidden_vectors: hidden vector width; defaults to max(dim_vectors_in, dim_vectors_out).
        feats_activation: module applied after the scalar linear layer.
        vectors_activation: module applied to the gate before it scales the output vectors.
        vector_gating: if True, gates come from a linear map of the output scalars;
            otherwise each output vector channel is gated by its own norm.
        xavier_init: Xavier-initialise the scalar->gate layer (only with vector_gating).
    """
    def __init__(
        self,
        dim_vectors_in,
        dim_vectors_out,
        dim_feats_in,
        dim_feats_out,
        n_cp_feats = 0, # number of cross-product features added to hidden vector features
        hidden_vectors = None,
        feats_activation = nn.SiLU(),
        vectors_activation = nn.Sigmoid(),
        vector_gating = True,
        xavier_init = False
    ):
        super().__init__()
        self.dim_vectors_in = dim_vectors_in
        self.dim_feats_in = dim_feats_in

        self.n_cp_feats = n_cp_feats

        self.dim_vectors_out = dim_vectors_out
        dim_h = max(dim_vectors_in, dim_vectors_out) if hidden_vectors is None else hidden_vectors

        # create Wh matrix: dim_vectors_in -> dim_h vector channels
        wh_k = 1/math.sqrt(dim_vectors_in)
        self.Wh = torch.zeros(dim_vectors_in, dim_h, dtype=torch.float32).uniform_(-wh_k, wh_k)
        self.Wh = nn.Parameter(self.Wh)

        # create Wcp matrix if we are using cross-product features
        # (produces 2 * n_cp_feats vectors that are paired up and crossed in forward)
        if n_cp_feats > 0:
            wcp_k = 1/math.sqrt(dim_vectors_in)
            self.Wcp = torch.zeros(dim_vectors_in, n_cp_feats*2, dtype=torch.float32).uniform_(-wcp_k, wcp_k)
            self.Wcp = nn.Parameter(self.Wcp)

        # create Wu matrix
        if n_cp_feats > 0: # the number of vector features going into Wu is increased by n_cp_feats if we are using cross-product features
            wu_in_dim = dim_h + n_cp_feats
        else:
            wu_in_dim = dim_h
        wu_k = 1/math.sqrt(wu_in_dim)
        self.Wu = torch.zeros(wu_in_dim, dim_vectors_out, dtype=torch.float32).uniform_(-wu_k, wu_k)
        self.Wu = nn.Parameter(self.Wu)

        self.vectors_activation = vectors_activation

        # scalar path input = original scalars ++ norms of all hidden vectors
        self.to_feats_out = nn.Sequential(
            nn.Linear(dim_h + n_cp_feats + dim_feats_in, dim_feats_out),
            feats_activation
        )

        # branching logic to use old GVP, or GVP with vector gating
        if vector_gating:
            self.scalar_to_vector_gates = nn.Linear(dim_feats_out, dim_vectors_out)
            if xavier_init:
                nn.init.xavier_uniform_(self.scalar_to_vector_gates.weight, gain=1)
                nn.init.constant_(self.scalar_to_vector_gates.bias, 0)
        else:
            self.scalar_to_vector_gates = None

    def forward(self, data):
        """Apply the layer to ``data = (feats, vectors)``; returns ``(feats_out, vectors_out)``.

        Raises AssertionError if the input widths do not match the constructor arguments.
        """
        feats, vectors = data

        b, n, _, v, c  = *feats.shape, *vectors.shape

        # feats has shape (batch_size, n_feats)
        # vectors has shape (batch_size, n_vectors, 3)

        assert c == 3 and v == self.dim_vectors_in, 'vectors have wrong dimensions'
        assert n == self.dim_feats_in, 'scalar features have wrong dimensions'

        Vh = einsum('b v c, v h -> b h c', vectors, self.Wh) # has shape (batch_size, dim_h, 3)

        # if we are including cross-product features, compute them here
        if self.n_cp_feats > 0:
            # convert dim_vectors_in vectors to n_cp_feats*2 vectors
            Vcp = einsum('b v c, v p -> b p c', vectors, self.Wcp) # has shape (batch_size, n_cp_feats*2, 3)
            # split the n_cp_feats*2 vectors into two sets of n_cp_feats vectors
            cp_src, cp_dst = torch.split(Vcp, self.n_cp_feats, dim=1) # each has shape (batch_size, n_cp_feats, 3)
            # take the cross product of the two sets of vectors
            cp = torch.linalg.cross(cp_src, cp_dst, dim=-1) # has shape (batch_size, n_cp_feats, 3)

            # add the cross product features to the hidden vector features
            Vh = torch.cat((Vh, cp), dim=1) # has shape (batch_size, dim_h + n_cp_feats, 3)

        Vu = einsum('b h c, h u -> b u c', Vh, self.Wu) # has shape (batch_size, dim_vectors_out, 3)

        # per-channel norms of the hidden vectors: (batch_size, dim_h [+ n_cp_feats])
        sh = _norm_no_nan(Vh)

        s = torch.cat((feats, sh), dim = 1)

        feats_out = self.to_feats_out(s)

        if exists(self.scalar_to_vector_gates):
            gating = self.scalar_to_vector_gates(feats_out)
            gating = gating.unsqueeze(dim = -1)  # (batch_size, dim_vectors_out, 1)
        else:
            # Norm gating (original GVP): one gate per output vector channel. keepdims=True
            # keeps the gate at (batch_size, dim_vectors_out, 1) so it broadcasts over the xyz
            # axis of Vu; a (batch_size, dim_vectors_out) gate would not broadcast against Vu.
            gating = _norm_no_nan(Vu, keepdims=True)

        vectors_out = self.vectors_activation(gating) * Vu

        return (feats_out, vectors_out)

class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        '''
        device = self.dummy_param.device
        if not self.training:
            return x
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x

class GVPDropout(nn.Module):
    """Dropout for a (scalar features, vector features) pair.

    Scalars use ordinary element-wise nn.Dropout; vectors use _VDropout, which
    drops whole vector channels. Both use the same rate.
    """
    def __init__(self, rate):
        super().__init__()
        self.vector_dropout = _VDropout(rate)
        self.feat_dropout = nn.Dropout(rate)

    def forward(self, feats, vectors):
        dropped_feats = self.feat_dropout(feats)
        dropped_vectors = self.vector_dropout(vectors)
        return dropped_feats, dropped_vectors


class GVPLayerNorm(nn.Module):
    """LayerNorm for a (scalar features, vector features) pair.

    Scalars go through a learnable nn.LayerNorm of width `feats_h_size`.
    Vectors are divided by the root-mean-square (over the channel axis, dim=-2)
    of their per-channel L2 norms; this part has no learnable parameters.
    """
    def __init__(self, feats_h_size, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.feat_norm = nn.LayerNorm(feats_h_size)

    def forward(self, feats, vectors):
        scaled_feats = self.feat_norm(feats)

        # squared norm of every vector channel, kept as (..., channels, 1)
        sq_norms = _norm_no_nan(vectors, axis=-1, keepdims=True, sqrt=False)
        rms = torch.sqrt(sq_norms.mean(dim=-2, keepdim=True) + self.eps) + self.eps
        return scaled_feats, vectors / rms



class GVPConv(nn.Module):

    """GVP graph convolution on a homogenous graph.

    One forward call performs:

    1. Messages: for each edge (u -> v), the scalars
       [h_u, RBF(d_uv), edge feats (if edge_feat_size > 0), h_v (if use_dst_feats)]
       and the vectors [x_diff_uv, V_u, V_v (if use_dst_feats)] go through
       ``n_message_gvps`` stacked GVPs. When x_diff/d are not supplied,
       x_diff_uv is the unit vector along x_u - x_v and d_uv its length.
    2. Aggregation: messages are reduced onto destination nodes with ``fn.sum``
       (``fn.mean`` when ``message_norm == 'mean'``) and, if ``message_norm`` is a
       number, divided by it. The result is dropped out, added to the node
       features and normalised with GVPLayerNorm.
    3. Update: ``n_update_gvps`` stacked GVPs produce a residual that is dropped
       out, added, and normalised with a second GVPLayerNorm.

    Args:
        scalar_size: width of node scalar features (input and output).
        vector_size: number of node vector channels (input and output).
        n_cp_feats: cross-product features per GVP (see GVP).
        scalar_activation / vector_activation: module *classes*; a fresh instance
            is created for every GVP.
        n_message_gvps / n_update_gvps: depth of the message / update GVP stacks.
        use_dst_feats: also feed destination-node features into the messages.
        rbf_dmax / rbf_dim: upper bound and size of the RBF distance embedding.
        edge_feat_size: width of optional edge features (``edge_feats`` in forward).
        coords_range: accepted but not used by this class.
        message_norm: 'sum', 'mean', or a number that summed messages are divided by.
        dropout: rate for the aggregated messages and for the update residuals.

    Raises:
        ValueError: ``message_norm`` is a string other than 'mean' / 'sum'.
        AssertionError: ``message_norm`` is neither a string nor a number.
    """

    def __init__(self, scalar_size: int = 128, vector_size: int = 16, n_cp_feats: int = 0,
                  scalar_activation=nn.SiLU, vector_activation=nn.Sigmoid,
                  n_message_gvps: int = 1, n_update_gvps: int = 1,
                  use_dst_feats: bool = False, rbf_dmax: float = 20, rbf_dim: int = 16,
                  edge_feat_size: int = 0, coords_range=10, message_norm: Union[float, str] = 10, dropout: float = 0.0,):

        super().__init__()

        self.scalar_size = scalar_size
        self.vector_size = vector_size
        self.n_cp_feats = n_cp_feats
        self.scalar_activation = scalar_activation
        self.vector_activation = vector_activation
        self.n_message_gvps = n_message_gvps
        self.n_update_gvps = n_update_gvps
        self.edge_feat_size = edge_feat_size
        self.use_dst_feats = use_dst_feats
        self.rbf_dmax = rbf_dmax
        self.rbf_dim = rbf_dim
        self.dropout_rate = dropout
        self.message_norm = message_norm

        # create message passing function
        message_gvps = []
        for i in range(n_message_gvps):

            dim_vectors_in = vector_size
            dim_feats_in = scalar_size

            # on the first layer, there is an extra edge vector for the displacement vector between the two node positions
            # and extra scalars for the RBF distance embedding and any edge features
            if i == 0:
                dim_vectors_in += 1
                dim_feats_in += rbf_dim + edge_feat_size

            # if this is the first layer and we are using destination node features to compute messages, add them to the input dimensions
            if use_dst_feats and i == 0:
                dim_vectors_in += vector_size
                dim_feats_in += scalar_size

            message_gvps.append(
                GVP(dim_vectors_in=dim_vectors_in,
                    dim_vectors_out=vector_size,
                    n_cp_feats=n_cp_feats,
                    dim_feats_in=dim_feats_in,
                    dim_feats_out=scalar_size,
                    feats_activation=scalar_activation(),
                    vectors_activation=vector_activation(),
                    vector_gating=True)
            )
        self.edge_message = nn.Sequential(*message_gvps)

        # create update function
        update_gvps = []
        for i in range(n_update_gvps):
            update_gvps.append(
                GVP(dim_vectors_in=vector_size,
                    dim_vectors_out=vector_size,
                    n_cp_feats=n_cp_feats,
                    dim_feats_in=scalar_size,
                    dim_feats_out=scalar_size,
                    feats_activation=scalar_activation(),
                    vectors_activation=vector_activation(),
                    vector_gating=True)
            )
        self.node_update = nn.Sequential(*update_gvps)

        self.dropout = GVPDropout(self.dropout_rate)
        self.message_layer_norm = GVPLayerNorm(self.scalar_size)
        self.update_layer_norm = GVPLayerNorm(self.scalar_size)

        # Strings and numbers are validated in separate branches so that the valid
        # strings 'mean' / 'sum' are accepted rather than falling through to the
        # numeric assertion.
        if isinstance(self.message_norm, str):
            if self.message_norm not in ['mean', 'sum']:
                raise ValueError(f"message_norm must be either 'mean', 'sum', or a number, got {self.message_norm}")
        else:
            assert isinstance(self.message_norm, (float, int)), "message_norm must be either 'mean', 'sum', or a number"

        if self.message_norm == 'mean':
            self.agg_func = fn.mean
        else:
            self.agg_func = fn.sum

    def forward(self, g: dgl.DGLGraph,
                scalar_feats: torch.Tensor,
                coord_feats: torch.Tensor,
                vec_feats: torch.Tensor,
                edge_feats: torch.Tensor = None,
                x_diff: torch.Tensor = None,
                d: torch.Tensor = None):
        """Run one message-passing + update step.

        Args:
            g: graph to convolve over; its node/edge data are only modified inside
                a local scope.
            scalar_feats: (n_nodes, scalar_size) node scalars.
            coord_feats: (n_nodes, 3) node positions.
            vec_feats: (n_nodes, vector_size, 3) node vector features.
            edge_feats: (n_edges, edge_feat_size); required when edge_feat_size > 0.
            x_diff, d: optional precomputed per-edge displacement vectors and RBF
                distance embeddings. They are used only if BOTH are given, and are
                used as-is (x_diff is not normalised here).

        Returns:
            (scalar_feat, vec_feat): updated node scalars and vectors, same shapes as the inputs.
        """
        with g.local_scope():

            g.ndata['h'] = scalar_feats
            g.ndata['x'] = coord_feats
            g.ndata['v'] = vec_feats

            if x_diff is not None and d is not None:
                g.edata['x_diff'] = x_diff
                g.edata['d'] = d

            # edge feature
            if self.edge_feat_size > 0:
                assert edge_feats is not None, "Edge features must be provided."
                g.edata["a"] = edge_feats

            # normalize x_diff and compute rbf embedding of edge distance
            if 'x_diff' not in g.edata:
                # get vectors between node positions (source minus destination)
                g.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
                dij = _norm_no_nan(g.edata['x_diff'], keepdims=True) + 1e-8
                g.edata['x_diff'] = g.edata['x_diff'] / dij
                g.edata['d'] = _rbf(dij.squeeze(1), D_max=self.rbf_dmax, D_count=self.rbf_dim)

            # compute messages on every edge
            g.apply_edges(self.message)

            # aggregate messages from every edge
            g.update_all(fn.copy_e("scalar_msg", "m"), self.agg_func("m", "scalar_msg"))
            g.update_all(fn.copy_e("vec_msg", "m"), self.agg_func("m", "vec_msg"))

            # 'mean' / 'sum' are already handled by the reducer; a number rescales the sum
            if isinstance(self.message_norm, str):
                z = 1
            else:
                z = self.message_norm

            scalar_msg = g.ndata["scalar_msg"] / z
            vec_msg = g.ndata["vec_msg"] / z

            # dropout scalar and vector messages
            scalar_msg, vec_msg = self.dropout(scalar_msg, vec_msg)

            # update scalar and vector features, apply layernorm
            scalar_feat = g.ndata['h'] + scalar_msg
            vec_feat = g.ndata['v'] + vec_msg
            scalar_feat, vec_feat = self.message_layer_norm(scalar_feat, vec_feat)

            # apply node update function, apply dropout to residuals, apply layernorm
            scalar_residual, vec_residual = self.node_update((scalar_feat, vec_feat))
            scalar_residual, vec_residual = self.dropout(scalar_residual, vec_residual)
            scalar_feat = scalar_feat + scalar_residual
            vec_feat = vec_feat + vec_residual
            scalar_feat, vec_feat = self.update_layer_norm(scalar_feat, vec_feat)

        return scalar_feat, vec_feat

    def message(self, edges):
        """Edge UDF: build per-edge GVP inputs and return {'scalar_msg', 'vec_msg'}.

        The concatenation order here must match the input widths computed for the
        first message GVP in __init__.
        """

        # concatenate x_diff and v on every edge to produce vector features
        vec_feats = [ edges.data["x_diff"].unsqueeze(1), edges.src["v"] ]
        if self.use_dst_feats:
            vec_feats.append(edges.dst["v"])
        vec_feats = torch.cat(vec_feats, dim=1)

        # create scalar features
        scalar_feats = [ edges.src['h'], edges.data['d'] ]
        if self.edge_feat_size > 0:
            scalar_feats.append(edges.data['a'])

        if self.use_dst_feats:
            scalar_feats.append(edges.dst['h'])

        scalar_feats = torch.cat(scalar_feats, dim=1)

        scalar_message, vector_message = self.edge_message((scalar_feats, vec_feats))

        return {"scalar_msg": scalar_message, "vec_msg": vector_message}

## define our own GVP-GCN model

In [None]:
class GVP_GCN(nn.Module):
    """One GVPConv layer followed by a linear head on the node scalar features.

    NOTE(review): this class is defined again in the "Define the model" cell further down;
    that later definition is the one in effect when the model is trained.
    `scalar_hidden_dim` / `vector_hidden_dim` are accepted but unused, and `self.dropout`
    is created but never applied in forward().
    """
    def __init__(self, scalar_input_dim, vector_input_dim, scalar_hidden_dim, vector_hidden_dim, output_dim):
        super().__init__()
        self.gvp_conv = GVPConv(
            scalar_size=scalar_input_dim, vector_size=vector_input_dim,
            scalar_activation=nn.ReLU, vector_activation=nn.ReLU,
            n_message_gvps=2, n_update_gvps=2,
            rbf_dmax=20, rbf_dim=8,  # RBF distance embedding range / size
            dropout=0.1,
        )
        self.fc = nn.Linear(scalar_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, scalar_feats, coords, vector_feats):
        # Side effect: the inputs are also written onto graph.ndata under these keys.
        for key, value in (('scalar_feats', scalar_feats),
                           ('vector_feats', vector_feats),
                           ('coords', coords)):
            graph.ndata[key] = value

        node_scalars, _ = self.gvp_conv(graph, scalar_feats, coords, vector_feats)
        return self.fc(node_scalars)



## mini-batch training

In [None]:
# Node-ID tensors for the train/validation splits (assumes train_mask / val_mask are 1-D
# masks over graph nodes). .flatten() keeps the result 1-D even when a split contains a
# single node; .squeeze() would return a 0-d tensor there, which len()/DataLoader reject.
train_nids = torch.nonzero(train_mask, as_tuple=False).flatten()
val_nids = torch.nonzero(val_mask, as_tuple=False).flatten()

# Move the full graph and the label matrix to the training device once, up front.
graph = graph.to(device)
data_y = data_y.to(device)


In [None]:
# Define the model
class GVP_GCN(nn.Module):
    """One GVPConv layer followed by a linear head on the node scalar features.

    forward(graph, scalar_feats, coords, vector_feats) returns (num_nodes, output_dim)
    logits. `scalar_hidden_dim` / `vector_hidden_dim` are accepted for signature
    compatibility but unused: GVPConv keeps the input widths.

    Dropout (p=0.5) is applied to the GVP scalar output right before the head, as
    GVP_ESM_GCN does. Otherwise the declared `self.dropout` module would have no effect.
    """
    def __init__(self, scalar_input_dim, vector_input_dim, scalar_hidden_dim, vector_hidden_dim, output_dim):
        super(GVP_GCN, self).__init__()
        self.gvp_conv = GVPConv(
            scalar_size=scalar_input_dim,
            vector_size=vector_input_dim,
            scalar_activation=nn.ReLU,
            vector_activation=nn.ReLU,
            n_message_gvps=2,
            n_update_gvps=2,
            rbf_dmax=20,
            rbf_dim=8,
            dropout=0.1
        )
        self.fc = nn.Linear(scalar_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, scalar_feats, coords, vector_feats):
        scalar_feats, vector_feats = self.gvp_conv(graph, scalar_feats, coords, vector_feats)
        # regularise the head's input; a no-op in eval mode
        out = self.fc(self.dropout(scalar_feats))
        return out

In [None]:
import torch
import torch.nn as nn
import dgl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Custom Dataset for Node IDs
class NodeDataset(Dataset):
    """Map-style Dataset over a sequence (e.g. a 1-D tensor) of node IDs; item i is node_ids[i]."""

    def __init__(self, node_ids):
        self.node_ids = node_ids

    def __getitem__(self, idx):
        return self.node_ids[idx]

    def __len__(self):
        return len(self.node_ids)

# Define training function
def train_epoch(model, data_loader, graph, data_y, criterion, optimizer):
    """Run one optimisation pass over `data_loader` and return the mean per-node loss.

    For every batch of node IDs, the subgraph of `graph` induced by those nodes is built
    with dgl.node_subgraph, so edges to nodes outside the batch are not seen. The model is
    run on that subgraph's node features, and `criterion(logits, labels)` is minimised,
    where labels are the rows of `data_y` for the original node IDs. The batch loss is
    weighted by the batch size when averaging.

    Tensors are moved to the notebook-level `device`.
    """
    model.train()
    loss_sum = 0
    seen = 0

    for node_ids in tqdm(data_loader, desc="Training"):
        node_ids = node_ids.to(graph.device)
        sub_g = dgl.node_subgraph(graph, node_ids).to(device)
        node_data = sub_g.ndata

        scalar_in = node_data['scalar_feats'].to(device)
        coords_in = node_data['coords'].to(device)
        vectors_in = node_data['vector_feats'].to(device)
        # dgl.NID maps subgraph nodes back to their IDs in the full graph
        targets = data_y[node_data[dgl.NID].to('cpu')].to(device)

        logits = model(sub_g, scalar_in, coords_in, vectors_in)
        batch_loss = criterion(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = len(node_ids)
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size

    return loss_sum / seen

@torch.no_grad()
def evaluate(model, data_loader, graph, data_y, criterion):
    """Return the mean per-node loss of `model` over `data_loader` (eval mode, no gradients).

    Batches are processed exactly as in train_epoch: each batch of node IDs becomes the
    induced subgraph of `graph` (dgl.node_subgraph), and each batch loss is weighted by
    the number of nodes in the batch. Tensors are moved to the notebook-level `device`.
    """
    model.eval()
    weighted_loss = 0
    n_nodes = 0

    for node_ids in tqdm(data_loader, desc="Evaluation"):
        node_ids = node_ids.to(graph.device)
        sub_g = dgl.node_subgraph(graph, node_ids).to(device)
        node_data = sub_g.ndata

        targets = data_y[node_data[dgl.NID].to('cpu')].to(device)
        logits = model(
            sub_g,
            node_data['scalar_feats'].to(device),
            node_data['coords'].to(device),
            node_data['vector_feats'].to(device),
        )
        batch_loss = criterion(logits, targets)

        batch_size = len(node_ids)
        weighted_loss += batch_loss.item() * batch_size
        n_nodes += batch_size

    return weighted_loss / n_nodes

# Create datasets and data loaders
# Node-ID tensors for each split (assumes train_mask / val_mask are 1-D masks over graph
# nodes). .flatten() keeps them 1-D even if a split has a single node; .squeeze() would
# produce a 0-d tensor there, which NodeDataset's len() cannot handle.
train_nids = torch.nonzero(train_mask, as_tuple=False).flatten()
val_nids = torch.nonzero(val_mask, as_tuple=False).flatten()

train_dataset = NodeDataset(train_nids)
val_dataset = NodeDataset(val_nids)

# Each batch is up to 1024 node IDs; training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

# Model and optimizer setup: widths are read off the tensors attached to `graph` and data_y
scalar_input_dim = graph.ndata['scalar_feats'].shape[1]
vector_input_dim = graph.ndata['vector_feats'].shape[1]
output_dim = data_y.shape[1]

In [None]:
model = GVP_GCN(scalar_input_dim, vector_input_dim, 8, 4, output_dim).to(device)

# Calculate pos_weight per GO term from the training labels.
# train_mask is a 1-D node mask (its nonzero() indices are used as node IDs above), so
# summing the mask itself gives a single scalar (the number of training nodes), and every
# term would get the same weight. Count positives per label column over the training rows
# of data_y instead. Assumes data_y is a (num_nodes, num_terms) 0/1 label matrix.
train_labels = data_y[train_mask.to(data_y.device).bool()].float()
num_samples = train_labels.shape[0]  # Number of training samples
num_positive_per_class = train_labels.sum(dim=0)  # Number of positive samples per class

# pos_weight = num_samples / (2 * num_positives); terms with no positive training example
# get weight 0.0 instead of dividing by zero.
pos_weight = torch.zeros_like(num_positive_per_class)
has_positives = num_positive_per_class > 0
pos_weight[has_positives] = num_samples / (2 * num_positive_per_class[has_positives])

# Pass pos_weight to BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, graph, data_y, criterion, optimizer)
    val_loss = evaluate(model, val_loader, graph, data_y, criterion)
    print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve, auc

def compute_metrics(predictions, labels, thresholds=np.linspace(0, 1, 101)):
    """
    Protein-centric max-F1 and term-centric AUPR for multi-label predictions.

    predictions: numpy array of shape (num_proteins, num_terms), predicted probabilities
    labels: numpy array of shape (num_proteins, num_terms); entries equal to 1 count as
        positives, entries equal to 0 as negatives
    thresholds: array-like, list of thresholds to evaluate for Fmax

    Note: each protein picks its own best threshold and the per-protein maxima are
    averaged. This is more optimistic than the CAFA-style Fmax, which uses a single
    shared threshold for all proteins.

    Returns:
        fmax: Average maximum F1 score across proteins
        aupr: Average area under the precision-recall curve over the GO terms that have
            at least one positive label (nan if no term has one)
        f1_max_scores: numpy array of shape (num_proteins,), each protein's max F1
        aupr_scores: numpy array of shape (num_terms,); nan for terms without any
            positive label, whose precision-recall curve is undefined
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    num_proteins, num_terms = labels.shape

    is_pos = labels == 1
    is_neg = labels == 0

    # Fmax: score every protein at once for each threshold. Precision, recall and F1
    # are 0 whenever their denominator is 0.
    f1_max_scores = np.zeros(num_proteins)
    for t in thresholds:
        predicted = predictions >= t
        tp = np.sum(predicted & is_pos, axis=1)
        fp = np.sum(predicted & is_neg, axis=1)
        fn = np.sum(~predicted & is_pos, axis=1)

        n_pred = tp + fp
        n_true = tp + fn
        precision = np.divide(tp, n_pred, out=np.zeros(num_proteins), where=n_pred > 0)
        recall = np.divide(tp, n_true, out=np.zeros(num_proteins), where=n_true > 0)
        pr_sum = precision + recall
        f1 = np.divide(2 * precision * recall, pr_sum, out=np.zeros(num_proteins), where=pr_sum > 0)
        np.maximum(f1_max_scores, f1, out=f1_max_scores)

    fmax = np.mean(f1_max_scores)

    # AUPR per term. A term with no positive label has no meaningful PR curve (depending on
    # the sklearn version this yields NaN or a degenerate area), so it is marked NaN and
    # left out of the mean instead of poisoning or skewing it.
    aupr_scores = np.full(num_terms, np.nan)
    for j in range(num_terms):
        if not is_pos[:, j].any():
            continue
        precision, recall, _ = precision_recall_curve(labels[:, j], predictions[:, j])
        aupr_scores[j] = auc(recall, precision)

    evaluated = ~np.isnan(aupr_scores)
    aupr = np.mean(aupr_scores[evaluated]) if evaluated.any() else np.nan

    return fmax, aupr, f1_max_scores, aupr_scores

In [None]:
@torch.no_grad()
def evaluate_on_test_set(model, test_loader, graph, data_y):
    """
    Score the model on every test batch and report Fmax / AUPR via compute_metrics.

    Each batch of node IDs becomes the induced subgraph of `graph` (dgl.node_subgraph).
    The model's outputs on it are passed through a sigmoid, and the matching rows of
    `data_y` are gathered as ground truth. All batches are stacked before scoring,
    and the two metrics are printed.

    Parameters:
        model: Trained GVP-GCN model, called as model(subgraph, scalar_feats, coords, vector_feats).
        test_loader: DataLoader yielding tensors of node IDs.
        graph: Original DGL graph carrying 'scalar_feats', 'coords' and 'vector_feats' node data.
        data_y: Ground-truth labels for all nodes, indexed by node ID.

    Returns:
        fmax, aupr, f1_max_scores, aupr_scores: as returned by compute_metrics.
    """
    model.eval()
    prob_batches = []
    label_batches = []

    for node_ids in tqdm(test_loader, desc="Evaluating Test Set"):
        dev = graph.device
        sub_g = dgl.node_subgraph(graph, node_ids.to(dev)).to(dev)
        node_data = sub_g.ndata

        logits = model(
            sub_g,
            node_data['scalar_feats'].to(dev),
            node_data['coords'].to(dev),
            node_data['vector_feats'].to(dev),
        )
        prob_batches.append(torch.sigmoid(logits).cpu().numpy())
        label_batches.append(data_y[node_data[dgl.NID].to('cpu')].to('cpu').numpy())

    fmax, aupr, f1_max_scores, aupr_scores = compute_metrics(
        np.vstack(prob_batches), np.vstack(label_batches)
    )

    print(f"Test Set Evaluation:")
    print(f"Fmax: {fmax:.4f}, AUPR: {aupr:.4f}")
    return fmax, aupr, f1_max_scores, aupr_scores

In [None]:
# Evaluate on test set
# Get test node IDs as a 1-D tensor (assumes test_mask is a 1-D mask over graph nodes).
# .flatten() rather than .squeeze(), so a single-node split still gives a 1-D tensor that
# NodeDataset / DataLoader can take len() of.
test_nids = torch.nonzero(test_mask, as_tuple=False).flatten()

# Create a dataset and data loader for the test set
test_dataset = NodeDataset(test_nids)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)  # No need to shuffle the test set

fmax, aupr, f1_max_scores, aupr_scores = evaluate_on_test_set(
    model, test_loader, graph, data_y
)

## GVP + ESM


In [None]:
class GVP_ESM_GCN(nn.Module):
    """One GVPConv layer over structural node features, fused with ESM embeddings in the head.

    forward(graph, scalar_feats, coords, vector_feats, esm_embeddings) runs the GVPConv,
    concatenates its scalar output with `esm_embeddings` along dim 1, applies dropout
    (p=0.5) and a linear layer producing `output_dim` outputs. The GVPConv output vectors
    are not used by the head.

    `scalar_hidden_dim` and `vector_hidden_dim` are accepted but unused.
    """
    def __init__(self, scalar_input_dim, vector_input_dim, scalar_hidden_dim, vector_hidden_dim, esm_embedding_dim, output_dim):
        super().__init__()
        # single GVP message-passing layer; it keeps the input widths
        self.gvp_conv = GVPConv(
            scalar_size=scalar_input_dim,
            vector_size=vector_input_dim,
            scalar_activation=nn.ReLU,
            vector_activation=nn.ReLU,
            n_message_gvps=2,
            n_update_gvps=2,
            rbf_dmax=20,
            rbf_dim=8,
            dropout=0.1,
        )
        self.dropout = nn.Dropout(0.5)
        # head input = GVP scalar output ++ ESM embedding
        self.fc = nn.Linear(scalar_input_dim + esm_embedding_dim, output_dim)

    def forward(self, graph, scalar_feats, coords, vector_feats, esm_embeddings):
        gvp_scalars, _ = self.gvp_conv(graph, scalar_feats, coords, vector_feats)
        fused = torch.cat([gvp_scalars, esm_embeddings], dim=1)
        return self.fc(self.dropout(fused))


In [None]:
# load ESM embeddings
# NOTE(review): depending on the PyTorch version's `weights_only` default, torch.load may
# unpickle arbitrary objects — only load an embeddings file you generated yourself.
esm_embeddings = torch.load(f"{BASE_DIR}/protein_esm_embeddings.pt")

# One embedding row per entry of proteins_with_labels, in that order (the same list passed
# as graph_nodes when the structure features were aligned), looked up via protein_to_idx.
esm_rows = [esm_embeddings[protein_to_idx[name]] for name in proteins_with_labels]
filtered_esm_embeddings = torch.stack(esm_rows)

In [None]:
# Define training function
def train_with_esm_epoch(model, data_loader, graph, data_y, criterion, optimizer):
    """One optimisation pass for ESM-augmented models; returns the mean per-node loss.

    Same batching as train_epoch: each batch of node IDs becomes the induced subgraph of
    `graph` (dgl.node_subgraph). The model additionally receives the subgraph's 'esm'
    node data. Batch losses are weighted by batch size. Tensors are moved to the
    notebook-level `device`.
    """
    model.train()
    loss_sum = 0
    seen = 0

    for node_ids in tqdm(data_loader, desc="Training"):
        node_ids = node_ids.to(graph.device)
        sub_g = dgl.node_subgraph(graph, node_ids).to(device)
        node_data = sub_g.ndata

        inputs = (
            node_data['scalar_feats'].to(device),
            node_data['coords'].to(device),
            node_data['vector_feats'].to(device),
            node_data['esm'].to(device),
        )
        # dgl.NID maps subgraph nodes back to their IDs in the full graph
        targets = data_y[node_data[dgl.NID].to('cpu')].to(device)

        logits = model(sub_g, *inputs)
        batch_loss = criterion(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = len(node_ids)
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size

    return loss_sum / seen

@torch.no_grad()
def evaluate_with_esm(model, data_loader, graph, data_y, criterion):
    """Mean per-node loss of an ESM-augmented model over `data_loader` (eval mode, no gradients).

    Batches are handled as in train_with_esm_epoch (induced subgraph per batch, 'esm'
    node data passed to the model, losses weighted by batch size).
    """
    model.eval()
    weighted_loss = 0
    n_nodes = 0

    for node_ids in tqdm(data_loader, desc="Evaluation"):
        node_ids = node_ids.to(graph.device)
        sub_g = dgl.node_subgraph(graph, node_ids).to(device)
        node_data = sub_g.ndata

        inputs = (
            node_data['scalar_feats'].to(device),
            node_data['coords'].to(device),
            node_data['vector_feats'].to(device),
            node_data['esm'].to(device),
        )
        targets = data_y[node_data[dgl.NID].to('cpu')].to(device)

        batch_loss = criterion(model(sub_g, *inputs), targets)

        batch_size = len(node_ids)
        weighted_loss += batch_loss.item() * batch_size
        n_nodes += batch_size

    return weighted_loss / n_nodes

In [None]:
# Define the model
# Widths come from the data: node feature widths from `graph`, the ESM width from the
# stacked embedding matrix, and the number of outputs from the label columns of data_y.
scalar_input_dim, vector_input_dim = (
    graph.ndata['scalar_feats'].shape[1],
    graph.ndata['vector_feats'].shape[1],
)
esm_embedding_dim = filtered_esm_embeddings.shape[1]
output_dim = data_y.shape[1]

# the hidden-dim arguments are accepted but not used by GVP_ESM_GCN
model = GVP_ESM_GCN(
    scalar_input_dim, vector_input_dim,
    scalar_hidden_dim=8, vector_hidden_dim=4,
    esm_embedding_dim=esm_embedding_dim, output_dim=output_dim,
).to(device)

In [None]:
# Calculate pos_weight per GO term from the training labels.
# train_mask is a 1-D node mask (its nonzero() indices are used as node IDs above), so
# summing the mask itself gives a single scalar (the number of training nodes), and every
# term would get the same weight. Count positives per label column over the training rows
# of data_y instead. Assumes data_y is a (num_nodes, num_terms) 0/1 label matrix.
train_labels = data_y[train_mask.to(data_y.device).bool()].float()
num_samples = train_labels.shape[0]  # Number of training samples
num_positive_per_class = train_labels.sum(dim=0)  # Number of positive samples per class

# pos_weight = num_samples / (2 * num_positives); terms with no positive training example
# get weight 0.0 instead of dividing by zero.
pos_weight = torch.zeros_like(num_positive_per_class)
has_positives = num_positive_per_class > 0
pos_weight[has_positives] = num_samples / (2 * num_positive_per_class[has_positives])

# Pass pos_weight to BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training loop
# save train and val loss
train_losses = []
val_losses = []
epochs = 50
for epoch in range(epochs):
    train_loss = train_with_esm_epoch(model, train_loader, graph, data_y, criterion, optimizer)
    val_loss = evaluate_with_esm(model, val_loader, graph, data_y, criterion)
    print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    train_losses.append(train_loss)
    val_losses.append(val_loss)

In [None]:
@torch.no_grad()
def evaluate_on_test_set_with_esm(model, test_loader, graph, data_y):
    """
    Evaluate the model on the test set using Fmax and AUPR metrics.

    Subgraphs are sampled on whatever device ``graph`` lives on. Each sampled
    subgraph and its node features are then moved to the device holding the
    model's parameters before the forward pass. This lets the full graph stay
    on the CPU while the model runs on a GPU.

    Parameters:
        model: Trained GVP-GCN model.
        test_loader: DataLoader for the test set (yields tensors of node IDs).
        graph: Original DGL graph.
        data_y: Ground-truth labels for all nodes.

    Returns:
        fmax: Average maximum F1 score across proteins.
        aupr: Average area under precision-recall curve across GO terms.
        f1_max_scores: numpy array of Fmax values.
        aupr_scores: numpy array of AUPR values.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    # Inputs must live on the same device as the model's weights.
    model_device = next(model.parameters()).device
    predictions = []
    labels = []

    for batch in tqdm(test_loader, desc="Evaluating Test Set"):
        # Node IDs must be on the graph's device for subgraph extraction
        batch = batch.to(graph.device)
        mini_batch_graph = dgl.node_subgraph(graph, batch)

        # Original (parent-graph) IDs of the subgraph's nodes, used for label lookup
        node_ids = mini_batch_graph.ndata[dgl.NID].to('cpu')

        # Move the subgraph and its node features to the model's device
        mini_batch_graph = mini_batch_graph.to(model_device)
        scalar_feats = mini_batch_graph.ndata['scalar_feats'].to(model_device)
        coords = mini_batch_graph.ndata['coords'].to(model_device)
        vector_feats = mini_batch_graph.ndata['vector_feats'].to(model_device)
        esm_feats = mini_batch_graph.ndata['esm'].to(model_device)

        # Predict per-label probabilities
        logits = model(mini_batch_graph, scalar_feats, coords, vector_feats, esm_feats)
        outputs = torch.sigmoid(logits).cpu().numpy()

        # Extract labels for the mini-batch
        batch_labels = data_y[node_ids].to('cpu').numpy()

        predictions.append(outputs)
        labels.append(batch_labels)

    if not predictions:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Concatenate all predictions and labels
    predictions = np.vstack(predictions)
    labels = np.vstack(labels)

    # Compute metrics
    fmax, aupr, f1_max_scores, aupr_scores = compute_metrics(predictions, labels)

    print(f"Test Set Evaluation:")
    print(f"Fmax: {fmax:.4f}, AUPR: {aupr:.4f}")
    return fmax, aupr, f1_max_scores, aupr_scores

In [None]:
# Evaluation placement: the model and the label matrix go to the compute device,
# while the full graph is kept on the CPU (subgraphs are sampled from it there).
model, graph, data_y = model.to(device), graph.to('cpu'), data_y.to(device)

In [None]:
# Evaluate on test set
# Get test node IDs. flatten() (instead of squeeze()) always yields a 1-D tensor,
# even when exactly one node is selected; squeeze() would give a 0-d tensor.
# NOTE(review): assumes test_mask is a 1-D boolean mask over graph nodes — confirm.
test_nids = torch.nonzero(test_mask, as_tuple=False).flatten()

# Create a dataset and data loader for the test set
test_dataset = NodeDataset(test_nids)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)  # No need to shuffle the test set

fmax, aupr, f1_max_scores, aupr_scores = evaluate_on_test_set_with_esm(
    model, test_loader, graph, data_y
)