Known protein modeler (ProtBERT, ESMfold)

In [None]:
fdjglkjdfhg

Use known LLM to determine binding pockets etc.

After parsing the input data should look like this, with a fasta protein ID, the protein sequence, and the 0/1 binary determining binding pocket residues

>protein_id
SEQUENCE
000000000011111000000000000000000000000000000000000000000000000000000000000

In [81]:
%%time

from Bio import SeqIO
import gzip

# Load protein sequences
def load_fasta_sequences(fasta_path):
    sequences = {}
    with gzip.open(fasta_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            raw_id = record.id.split()[0]  # remove any trailing description
            if len(raw_id) >= 5:
                pdb_id = raw_id[:4]
                chain_id = raw_id[4:]
                key = f"{pdb_id}_{chain_id}".lower()  # match BioLiP format
                sequences[key] = str(record.seq)
    return sequences

# Parse BioLiP.txt
def parse_biolip_annotations(biolip_path):
    binding_map = {}  # protein_id -> set of binding residue indices
    with gzip.open(biolip_path, "rt") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 9:
                continue
            pdb_id = parts[0]
            chain_id = parts[1]
            binding_residues = parts[8]  # e.g., "A:45,A:67,A:89"
            key = f"{pdb_id}_{chain_id}"
            indices = set()
            for res in binding_residues.split(","):
                if ":" in res:
                    _, idx = res.split(":")
                    if idx.isdigit():
                        indices.add(int(idx))
            binding_map[key] = indices
    return binding_map

# Generate labeled output
def generate_labeled_sequences(sequences, binding_map, output_path):
    with open(output_path, "w") as out:
        for protein_id, seq in sequences.items():
            if protein_id not in binding_map:
                continue
            labels = ["0"] * len(seq)
            for idx in binding_map[protein_id]:
                if 0 <= idx < len(seq):
                    labels[idx] = "1"
            out.write(f">{protein_id}\n{seq}\n{''.join(labels)}\n")

# Paths to your files
fasta_path = "protein.fasta.gz"
biolip_path = "BioLiP.txt.gz"
output_path = "biolip_labeled.txt"

# Run preprocessing
sequences = load_fasta_sequences(fasta_path)
binding_map = parse_biolip_annotations(biolip_path)
generate_labeled_sequences(sequences, binding_map, output_path)



CPU times: total: 5.08 s
Wall time: 5.28 s


In [82]:
%%time

from transformers import GPT2Tokenizer

# Load pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

# Add amino acid tokens (if needed)
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
tokenizer.add_tokens(amino_acids)

for aa in amino_acids:
    print(f"{aa}: {tokenizer.convert_tokens_to_ids(aa)}")

A: 32
C: 34
D: 35
E: 36
F: 37
G: 38
H: 39
I: 40
K: 42
L: 43
M: 44
N: 45
P: 47
Q: 48
R: 49
S: 50
T: 51
V: 53
W: 54
Y: 56
CPU times: total: 312 ms
Wall time: 550 ms


In [83]:
%%time

from transformers import GPT2LMHeadModel, GPT2Config

config = GPT2Config(vocab_size=len(tokenizer), n_embd=256, n_layer=6, n_head=4)
model = GPT2LMHeadModel(config)

print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 256)
    (wpe): Embedding(1024, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=768, nx=256)
          (c_proj): Conv1D(nf=256, nx=256)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=1024, nx=256)
          (c_proj): Conv1D(nf=256, nx=1024)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=50257, bias=False)
)
Total parameters: 17,867,008
CPU times: total: 5.62 s
Wa

In [84]:
%%time
import torch.nn as nn

class PocketPredictor(nn.Module):
    def __init__(self, gpt_model, hidden_dim):
        super().__init__()
        self.gpt = gpt_model
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids):
        hidden = self.gpt.transformer(input_ids).last_hidden_state  # Extract the tensor
        return self.head(hidden)

CPU times: total: 0 ns
Wall time: 0 ns


In [97]:
sequence = "RRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDEAERDEYELLCPDNTRKPVDKFKDCHLARVPSHAVVARSVNGKEDAIWNLLRQAQEKFGKDKSPKFQLFGSPSGQKDLLFKDSAIGFSRVPPRIDSGLYLGSGYFTAIQNLRKSEEEVAARRARVVWCAVGEQELRKCNQWSGLSEGSVTCSSASTTEDCIALVLKGEADAMSLDEGYVYTAGKCGLVPVLAENYKSQQSSDPDPNCVDRPVEGYLAVAVVRRSDTSLTWNSVKGKKSCHTAVDRTAGWNIPMGLLFNQTGSCKFDEYFSQSCAPGSDPRSNLCALCIGDEQGENKCVPNSNERYYGYTGAFRCLAENAGDVAFVKDVTVLQNTDGNNNEAWAKDLKLADFALLCLDGKRKPVTEARSCHLAMAPNHAVVSRMDKVERLKQVLLHQQAKFGRNGSDCPDKFCLFQSETKNLLFNDNTECLARLHGKTTYEKYLGPQYVAGITNLKKCSTSPLLEACEFLRK"
input_ids = tokenizer.encode(sequence, return_tensors="pt")  # shape: [1, seq_len]

print("Input shape:", input_ids.shape)

decoded = tokenizer.decode(input_ids[0])
print("Decoded sequence:", decoded)
print("Token IDs:", input_ids[0].tolist())


Input shape: torch.Size([1, 691])
Decoded sequence: R R R R S V Q W C A V S Q P E A T K C F Q W Q R N M R K V R G P P V S C I K R D S P I Q C I Q A I A E N R A D A V T L D G G F I Y E A G L A P Y K L R P V A A E V Y G T E R Q P R T H Y Y A V A V V K K G G S F Q L N E L Q G L K S C H T G L R R T A G W N V P I G T L R P F L N W T G P P E P I E A A V A R F F S A S C V P G A D K G Q F P N L C R L C A G T G E N K C A F S S Q E P Y F S Y S G A F K C L R D G A G D V A F I R E S T V F E D L S D E A E R D E Y E L L C P D N T R K P V D K F K D C H L A R V P S H A V V A R S V N G K E D A I W N L L R Q A Q E K F G K D K S P K F Q L F G S P S G Q K D L L F K D S A I G F S R V P P R I D S G L Y L G S G Y F T A I Q N L R K S E E E V A A R R A R V V W C A V G E Q E L R K C N Q W S G L S E G S V T C S S A S T T E D C I A L V L K G E A D A M S L D E G Y V Y T A G K C G L V P V L A E N Y K S Q Q S S D P D P N C V D R P V E G Y L A V A V V R R S D T S L T W N S V K G K K S C H T A V D R T A G W N I P M G 

In [98]:
from Bio import pairwise2
from Bio.Seq import Seq

query_seq = sequence
best_match = None
best_score = -1

for pid, target_seq in sequences.items():
    alignments = pairwise2.align.globalxx(query_seq, target_seq, one_alignment_only=True)
    score = alignments[0].score
    if score > best_score:
        best_score = score
        best_match = pid

print("Best match:", best_match)
print("Score:", best_score)

Best match: 1cb6_a
Score: 689.0


In [99]:
from Bio import pairwise2
from Bio.Seq import Seq
import numpy as np

def transfer_binding_labels(query_seq, template_seq, binding_indices):
    # Run global alignment
    alignment = pairwise2.align.globalxx(query_seq, template_seq, one_alignment_only=True)[0]
    aligned_query = alignment.seqA
    aligned_template = alignment.seqB

    # Map binding labels from template to query
    y_true = []
    template_pos = -1  # position in template_seq
    for q_char, t_char in zip(aligned_query, aligned_template):
        if t_char != "-":
            template_pos += 1
        if q_char == "-":
            continue  # skip gaps in query
        if template_pos in binding_indices:
            y_true.append(1)
        else:
            y_true.append(0)

    return np.array(y_true)

In [100]:
print("Keys in binding_map:", list(binding_map.keys())[:10])
query_seq = Seq(sequence)
template_seq = Seq(sequences["5a22_a"])
binding_indices = binding_map["5a22_A"]

y_true = transfer_binding_labels(query_seq, template_seq, binding_indices)
print("y_true shape:", y_true.shape)  # should be (83,)
alignment = pairwise2.align.globalxx(query_seq, template_seq, one_alignment_only=True)[0]

print("Aligned query:   ", alignment.seqA)
print("Aligned template:", alignment.seqB)
print("Alignment score: ", alignment.score)


Keys in binding_map: ['101m_A', '102m_A', '103m_A', '104m_A', '105m_A', '106m_A', '107m_A', '108m_A', '109m_A', '10gs_A']
y_true shape: (691,)
Aligned query:    RRR--------------------R---SVQ------WCA-------V-----S-QPEA----TKCF-Q---WQRNMR----------------------K-------V-----RGP----PVSCI---K--R--DSP--IQ----CIQA----------I--A--E----NRAD-A------V------T---------LDGG--FIY-EAGL-APY-K-L-----RP-----VAAE--VY--GTER-QPRTHYYAVAVVKKGG-S--------F--Q----L-N---------EL-QG-----L-KS-----C-------------------H------T----GL----R--R----------TA--------G----WNV--P-I----G--------T------------L-----RPF--L----N----WT---G---P---P------E---PIE-AA-VAR--F------------F--------SA---SC----------V-------PG-----------------A----------D-KGQFPNLCR----LCA--GT--G-EN---KC-A--F-S--SQ----EPYFS----YS-----------G---AF-------KC--LRD---GA-GDV-----AFIR--------E---------ST---VFED------L---SDEA--ERD--EY--EL--LC------PD-----N-T-------R---------------KP--------VD-----KF------KDCH-LAR-----VP---------SHA--VV----ARS----V-NGK-ED----AIWN---

In [101]:
%%time

import torch
import numpy as np

pocket_model = PocketPredictor(model, hidden_dim=256)

with torch.no_grad():
    output = pocket_model(input_ids)  # shape: [1, 83, 1]
    y_pred_probs = output.squeeze(0).squeeze(-1).numpy()  # shape: [83]

print("Shape:", y_pred_probs.shape)
print("Sample predictions:", y_pred_probs[:10])
y_pred = (y_pred_probs > 0.5).astype(int)
print("Predicted pocket residue indices:", np.where(y_pred == 1)[0])

Shape: (691,)
Sample predictions: [0.43492684 0.43918118 0.4395398  0.49329808 0.50368553 0.49490663
 0.5472627  0.47116613 0.5474742  0.49970692]
Predicted pocket residue indices: [  4   6   8  17  25  28  29  32  33  36  37  39  41  42  43  44  45  46
  47  48  49  52  53  54  55  56  57  58  59  60  62  63  64  65  66  67
  68  69  72  73  74  75  76  77  78  79  80  83  84  86  87  89  90  91
  92  94  95  96  98  99 100 102 103 104 105 106 109 110 111 112 114 116
 117 118 119 121 122 123 124 126 127 128 131 132 133 134 135 136 138 139
 140 141 142 143 144 145 146 147 148 150 151 152 154 156 159 160 161 162
 163 164 165 166 167 168 169 170 171 172 174 176 178 179 180 181 182 185
 187 189 190 194 195 196 197 198 201 202 203 204 205 206 207 208 212 213
 214 215 217 219 220 221 222 223 224 225 227 228 229 230 232 233 235 236
 237 238 239 240 241 242 243 244 246 247 249 250 252 253 254 255 256 258
 259 260 261 263 264 265 266 267 268 269 270 271 272 273 275 276 277 278
 279 280 281 284

In [102]:
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

# y_true: binary labels (0/1), y_pred_probs: predicted probabilities
query_seq = Seq(sequence)  # your actual 83-residue input
template_seq = Seq(sequences["5a22_a"])
binding_indices = binding_map["5a22_A"]

def transfer_binding_labels_full(query_seq, template_seq, binding_indices):
    alignment = pairwise2.align.globalxx(query_seq, template_seq, one_alignment_only=True)[0]
    aligned_query = alignment.seqA
    aligned_template = alignment.seqB

    y_true = []
    template_pos = -1
    query_pos = -1

    for q_char, t_char in zip(aligned_query, aligned_template):
        if t_char != "-":
            template_pos += 1
        if q_char != "-":
            query_pos += 1
            if template_pos in binding_indices:
                y_true.append(1)
            else:
                y_true.append(0)

    # Pad or truncate to match query length
    if len(y_true) < len(query_seq):
        y_true += [0] * (len(query_seq) - len(y_true))
    elif len(y_true) > len(query_seq):
        y_true = y_true[:len(query_seq)]

    return np.array(y_true)

y_true = transfer_binding_labels_full(query_seq, template_seq, binding_indices)

threshold = 0.5
y_pred = (y_pred_probs > threshold).astype(int)

roc_auc = roc_auc_score(y_true, y_pred_probs)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
print("y_true shape:", y_true.shape)
print("y_pred_probs shape:", y_pred_probs.shape)
print("y_pred shape:", y_pred.shape)
print("Positive labels in y_true:", np.sum(y_true))
print("Predicted positives:", np.sum(y_pred))
print(f"ROC AUC: {roc_auc:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")

y_true shape: (691,)
y_pred_probs shape: (691,)
y_pred shape: (691,)
Positive labels in y_true: 0
Predicted positives: 521
ROC AUC: nan
Precision: 0.000
Recall: 0.000
F1 Score: 0.000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [103]:
true_pocket = set(i for i, label in enumerate(y_true) if label == 1)
predicted_pocket = set(i for i, prob in enumerate(y_pred_probs) if prob > threshold)

coverage = len(true_pocket & predicted_pocket) / len(true_pocket)
overlap = len(true_pocket & predicted_pocket) / len(predicted_pocket)

ZeroDivisionError: division by zero

In [None]:
fetch 1XYZ, async=0
select predicted_pocket, resi 45+67+89  # replace with predicted residue indices
color red, predicted_pocket
show surface, predicted_pocket

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

def evaluate_pocket_predictions(y_true, y_pred_probs, threshold=0.5):
    """
    Evaluate binding pocket predictions.
    
    Args:
        y_true (np.array): Binary ground truth labels (0 = non-pocket, 1 = pocket)
        y_pred_probs (np.array): Predicted probabilities for each residue
        threshold (float): Classification threshold for binary decision

    Returns:
        dict: Evaluation metrics
    """
    y_pred = (y_pred_probs > threshold).astype(int)

    # Residue-level metrics
    roc_auc = roc_auc_score(y_true, y_pred_probs)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pocket-level coverage
    true_pocket = set(np.where(y_true == 1)[0])
    predicted_pocket = set(np.where(y_pred == 1)[0])
    coverage = len(true_pocket & predicted_pocket) / len(true_pocket) if true_pocket else 0.0
    overlap = len(true_pocket & predicted_pocket) / len(predicted_pocket) if predicted_pocket else 0.0

    return {
        "ROC-AUC": roc_auc,
        "Precision": precision,
        "Recall": recall,
        "F1": f1,
        "Pocket Coverage": coverage,
        "Pocket Overlap": overlap
    }

In [None]:
# Simulated predictions
y_true = np.array([0, 0, 1, 1, 0, 0, 1, 0, 0])
y_pred_probs = np.array([0.1, 0.2, 0.8, 0.7, 0.3, 0.2, 0.9, 0.1, 0.05])

metrics = evaluate_pocket_predictions(y_true, y_pred_probs)
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")

In [None]:
import nglview as nv
view = nv.show_file("1XYZ.pdb")
view.add_representation("spacefill", selection="45 or 67 or 89", color="red")
view