Known protein language / structure models (e.g. ProtBERT, ESMFold)

Use an existing protein language model to predict per-residue properties such as binding pockets.

In [None]:
>protein_id
SEQUENCE
000000000011111000000000000000000000000000000000000000000000000000000000000

In [None]:
from transformers import PreTrainedTokenizerFast

# One-letter codes of the 20 residues the model should know about; every
# letter is registered as its own token below.
amino_acids = "ACDEFGHIKLMNPQRSTVWY"

# NOTE(review): PreTrainedTokenizerFast is normally built from a
# `tokenizer_object` or a `tokenizer_file`; confirm that constructing it with
# `tokenizer_file=None` works on the installed transformers version.
tokenizer = PreTrainedTokenizerFast(tokenizer_file=None)
tokenizer.add_tokens([*amino_acids])

In [None]:
from transformers import GPT2LMHeadModel, GPT2Config

# GPT-2 style network whose vocabulary size follows the tokenizer: 6 layers,
# 4 attention heads, 256-dimensional embeddings. The weights come from the
# config alone, so nothing pretrained is loaded here.
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_embd=256,
    n_layer=6,
    n_head=4,
)
model = GPT2LMHeadModel(config)

In [None]:
import torch.nn as nn

class PocketPredictor(nn.Module):
    """Per-position binary classifier stacked on a GPT-style backbone.

    The backbone turns token ids into one hidden vector per position; a small
    MLP head then maps every hidden vector to a probability in [0, 1].

    Args:
        gpt_model: Callable backbone (e.g. a Hugging Face GPT-2 model). It is
            called as ``gpt_model(input_ids, output_hidden_states=True)`` and
            must return an object exposing either ``last_hidden_state`` or
            ``hidden_states``.
        hidden_dim: Size of the backbone's hidden vectors (the input width of
            the classification head).
    """

    def __init__(self, gpt_model, hidden_dim):
        super().__init__()
        self.gpt = gpt_model
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids):
        """Return one probability per input position.

        Args:
            input_ids: Token ids accepted by the backbone.

        Returns:
            Tensor shaped like the backbone's hidden states with the last
            dimension replaced by 1 (sigmoid output of the head).
        """
        outputs = self.gpt(input_ids, output_hidden_states=True)

        # Base models expose `last_hidden_state` directly. LM-head wrappers
        # (such as GPT2LMHeadModel) only return logits plus the optional
        # `hidden_states` tuple, whose final entry is the last layer's output.
        hidden = getattr(outputs, "last_hidden_state", None)
        if hidden is None:
            hidden = outputs.hidden_states[-1]

        return self.head(hidden)

In [None]:
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

# Inputs come from earlier cells: y_true holds binary labels (0/1) and
# y_pred_probs the predicted probabilities. Probabilities strictly above the
# threshold are turned into hard 0/1 calls.
threshold = 0.5
y_pred = (y_pred_probs > threshold).astype(int)

# ROC-AUC is scored on the raw probabilities; the remaining metrics are
# scored on the thresholded calls.
roc_auc = roc_auc_score(y_true, y_pred_probs)
precision, recall, f1 = (
    score(y_true, y_pred)
    for score in (precision_score, recall_score, f1_score)
)

In [None]:
# Positions labelled 1 in the ground truth, and positions whose predicted
# probability is strictly above the threshold.
true_pocket = {i for i, label in enumerate(y_true) if label == 1}
predicted_pocket = {i for i, prob in enumerate(y_pred_probs) if prob > threshold}
shared = true_pocket & predicted_pocket

# coverage: fraction of true pocket positions that were predicted.
# overlap:  fraction of predicted positions that are truly pocket.
# An empty denominator yields 0.0 instead of ZeroDivisionError, matching the
# guards used by evaluate_pocket_predictions.
coverage = len(shared) / len(true_pocket) if true_pocket else 0.0
overlap = len(shared) / len(predicted_pocket) if predicted_pocket else 0.0

In [None]:
# PyMOL commands (not Python): highlight the predicted pocket on a structure.
# NOTE(review): 1XYZ looks like a placeholder ID; async=0 presumably makes the
# fetch blocking so the following commands see the loaded object - confirm.
fetch 1XYZ, async=0
# Create a named selection holding the pocket residues.
select predicted_pocket, resi 45+67+89  # replace with predicted residue indices
# Colour that selection red and render it as a surface.
color red, predicted_pocket
show surface, predicted_pocket

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

def evaluate_pocket_predictions(y_true, y_pred_probs, threshold=0.5):
    """
    Evaluate binding pocket predictions.

    Args:
        y_true (array-like): Binary ground truth labels (0 = non-pocket, 1 = pocket)
        y_pred_probs (array-like): Predicted probabilities for each residue
        threshold (float): Classification threshold; a residue is called
            "pocket" when its probability is strictly greater than this value

    Returns:
        dict: Evaluation metrics keyed by "ROC-AUC", "Precision", "Recall",
        "F1", "Pocket Coverage" and "Pocket Overlap". "ROC-AUC" is NaN when
        ``y_true`` contains fewer than two distinct classes, because the
        score is undefined in that case.

    Raises:
        ValueError: If ``y_true`` and ``y_pred_probs`` differ in length.
    """
    # Accept lists/tuples as well as arrays, and flatten column vectors so
    # that the comparisons below are element-wise.
    y_true = np.asarray(y_true).ravel()
    y_pred_probs = np.asarray(y_pred_probs, dtype=float).ravel()
    if y_true.shape != y_pred_probs.shape:
        raise ValueError(
            "y_true and y_pred_probs must have the same number of elements, "
            f"got {y_true.size} and {y_pred_probs.size}"
        )

    y_pred = (y_pred_probs > threshold).astype(int)

    # Residue-level metrics
    if np.unique(y_true).size < 2:
        # ROC-AUC needs both classes; without this check an input with no
        # pocket residues could never reach the empty-set guards below.
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(y_true, y_pred_probs)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pocket-level coverage
    true_pocket = set(np.flatnonzero(y_true == 1))
    predicted_pocket = set(np.flatnonzero(y_pred == 1))
    shared = true_pocket & predicted_pocket
    coverage = len(shared) / len(true_pocket) if true_pocket else 0.0
    overlap = len(shared) / len(predicted_pocket) if predicted_pocket else 0.0

    return {
        "ROC-AUC": roc_auc,
        "Precision": precision,
        "Recall": recall,
        "F1": f1,
        "Pocket Coverage": coverage,
        "Pocket Overlap": overlap,
    }

In [None]:
# Toy example: nine residues, three of which (indices 2, 3 and 6) are pocket.
y_true = np.array([0, 0, 1, 1, 0, 0, 1, 0, 0])
y_pred_probs = np.array([0.1, 0.2, 0.8, 0.7, 0.3, 0.2, 0.9, 0.1, 0.05])

# Score the toy predictions and print each metric to four decimal places.
metrics = evaluate_pocket_predictions(y_true, y_pred_probs)
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
import nglview as nv
# Open the local structure file in an interactive nglview widget.
view = nv.show_file("1XYZ.pdb")

# Add a red space-filling representation for the example selection
# (presumably residue numbers 45, 67 and 89 - replace with predictions).
view.add_representation(
    "spacefill",
    color="red",
    selection="45 or 67 or 89",
)

# A bare trailing expression lets the notebook display the widget.
view