Known protein modeler (ProtBERT, ESMfold)

Use known LLM to determine binding pockets etc.

After parsing the input data should look like this, with a fasta protein ID, the protein sequence, and the 0/1 binary determining binding pocket residues

>protein_id
SEQUENCE
000000000011111000000000000000000000000000000000000000000000000000000000000

In [12]:
%%time

from Bio import SeqIO
import gzip

# Load protein sequences
def load_fasta_sequences(fasta_path):
    sequences = {}
    with gzip.open(fasta_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            sequences[record.id] = str(record.seq)
    return sequences

# Parse BioLiP.txt
def parse_biolip_annotations(biolip_path):
    binding_map = {}  # protein_id -> set of binding residue indices
    with gzip.open(biolip_path, "rt") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 9:
                continue
            pdb_id = parts[0]
            chain_id = parts[1]
            binding_residues = parts[8]  # e.g., "A:45,A:67,A:89"
            key = f"{pdb_id}_{chain_id}"
            indices = set()
            for res in binding_residues.split(","):
                if ":" in res:
                    _, idx = res.split(":")
                    if idx.isdigit():
                        indices.add(int(idx))
            binding_map[key] = indices
    return binding_map

# Generate labeled output
def generate_labeled_sequences(sequences, binding_map, output_path):
    with open(output_path, "w") as out:
        for protein_id, seq in sequences.items():
            if protein_id not in binding_map:
                continue
            labels = ["0"] * len(seq)
            for idx in binding_map[protein_id]:
                if 0 <= idx < len(seq):
                    labels[idx] = "1"
            out.write(f">{protein_id}\n{seq}\n{''.join(labels)}\n")

# Paths to your files
fasta_path = "protein.fasta.gz"
biolip_path = "BioLiP.txt.gz"
output_path = "biolip_labeled.txt"

# Run preprocessing
sequences = load_fasta_sequences(fasta_path)
binding_map = parse_biolip_annotations(biolip_path)
generate_labeled_sequences(sequences, binding_map, output_path)

CPU times: total: 4.61 s
Wall time: 4.69 s


In [13]:
%%time

from transformers import GPT2Tokenizer

# Load pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

# Add amino acid tokens (if needed)
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
tokenizer.add_tokens(amino_acids)

for aa in amino_acids:
    print(f"{aa}: {tokenizer.convert_tokens_to_ids(aa)}")

A: 32
C: 34
D: 35
E: 36
F: 37
G: 38
H: 39
I: 40
K: 42
L: 43
M: 44
N: 45
P: 47
Q: 48
R: 49
S: 50
T: 51
V: 53
W: 54
Y: 56
CPU times: total: 297 ms
Wall time: 531 ms


In [14]:
%%time

from transformers import GPT2LMHeadModel, GPT2Config

config = GPT2Config(vocab_size=len(tokenizer), n_embd=256, n_layer=6, n_head=4)
model = GPT2LMHeadModel(config)

print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 256)
    (wpe): Embedding(1024, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=768, nx=256)
          (c_proj): Conv1D(nf=256, nx=256)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=1024, nx=256)
          (c_proj): Conv1D(nf=256, nx=1024)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=50257, bias=False)
)
Total parameters: 17,867,008
CPU times: total: 5.31 s
Wa

In [None]:
%%time
import torch.nn as nn

class PocketPredictor(nn.Module):
    def __init__(self, gpt_model, hidden_dim):
        super().__init__()
        self.gpt = gpt_model
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids):
        hidden = self.gpt.transformer(input_ids).last_hidden_state  # Extract the tensor
        return self.head(hidden)

CPU times: total: 0 ns
Wall time: 0 ns


In [24]:
sequence = "SKMSDVKCTSVVLLSVLQQLRVESSSKLWAQCVQLHNDILLAKDTTEAFEKMVSLLSVLLSMQGAVDINKLCEEMLDNRATLQ"
input_ids = tokenizer.encode(sequence, return_tensors="pt")  # shape: [1, seq_len]

print("Input shape:", input_ids.shape)

decoded = tokenizer.decode(input_ids[0])
print("Decoded sequence:", decoded)
print("Token IDs:", input_ids[0].tolist())


Input shape: torch.Size([1, 83])
Decoded sequence: S K M S D V K C T S V V L L S V L Q Q L R V E S S S K L W A Q C V Q L H N D I L L A K D T T E A F E K M V S L L S V L L S M Q G A V D I N K L C E E M L D N R A T L Q
Token IDs: [50, 42, 44, 50, 35, 53, 42, 34, 51, 50, 53, 53, 43, 43, 50, 53, 43, 48, 48, 43, 49, 53, 36, 50, 50, 50, 42, 43, 54, 32, 48, 34, 53, 48, 43, 39, 45, 35, 40, 43, 43, 32, 42, 35, 51, 51, 36, 32, 37, 36, 42, 44, 53, 50, 43, 43, 50, 53, 43, 43, 50, 44, 48, 38, 32, 53, 35, 40, 45, 42, 43, 34, 36, 36, 44, 43, 35, 45, 49, 32, 51, 43, 48]


In [25]:
%%time

import torch

pocket_model = PocketPredictor(model, hidden_dim=256)

with torch.no_grad():
    output = pocket_model(input_ids)  # shape: [1, 83, 1]
    y_pred_probs = output.squeeze(0).squeeze(-1).numpy()  # shape: [83]

CPU times: total: 406 ms
Wall time: 19 ms


TypeError: linear(): argument 'input' (position 1) must be Tensor, not BaseModelOutputWithPastAndCrossAttentions

In [None]:
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

# y_true: binary labels (0/1), y_pred_probs: predicted probabilities
threshold = 0.5
y_pred = (y_pred_probs > threshold).astype(int)

roc_auc = roc_auc_score(y_true, y_pred_probs)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

In [None]:
true_pocket = set(i for i, label in enumerate(y_true) if label == 1)
predicted_pocket = set(i for i, prob in enumerate(y_pred_probs) if prob > threshold)

coverage = len(true_pocket & predicted_pocket) / len(true_pocket)
overlap = len(true_pocket & predicted_pocket) / len(predicted_pocket)

In [None]:
fetch 1XYZ, async=0
select predicted_pocket, resi 45+67+89  # replace with predicted residue indices
color red, predicted_pocket
show surface, predicted_pocket

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

def evaluate_pocket_predictions(y_true, y_pred_probs, threshold=0.5):
    """
    Evaluate binding pocket predictions.
    
    Args:
        y_true (np.array): Binary ground truth labels (0 = non-pocket, 1 = pocket)
        y_pred_probs (np.array): Predicted probabilities for each residue
        threshold (float): Classification threshold for binary decision

    Returns:
        dict: Evaluation metrics
    """
    y_pred = (y_pred_probs > threshold).astype(int)

    # Residue-level metrics
    roc_auc = roc_auc_score(y_true, y_pred_probs)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pocket-level coverage
    true_pocket = set(np.where(y_true == 1)[0])
    predicted_pocket = set(np.where(y_pred == 1)[0])
    coverage = len(true_pocket & predicted_pocket) / len(true_pocket) if true_pocket else 0.0
    overlap = len(true_pocket & predicted_pocket) / len(predicted_pocket) if predicted_pocket else 0.0

    return {
        "ROC-AUC": roc_auc,
        "Precision": precision,
        "Recall": recall,
        "F1": f1,
        "Pocket Coverage": coverage,
        "Pocket Overlap": overlap
    }

In [None]:
# Simulated predictions
y_true = np.array([0, 0, 1, 1, 0, 0, 1, 0, 0])
y_pred_probs = np.array([0.1, 0.2, 0.8, 0.7, 0.3, 0.2, 0.9, 0.1, 0.05])

metrics = evaluate_pocket_predictions(y_true, y_pred_probs)
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")

In [None]:
import nglview as nv
view = nv.show_file("1XYZ.pdb")
view.add_representation("spacefill", selection="45 or 67 or 89", color="red")
view