In [1]:
import torch, sys

print("Python:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA compiled:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
if torch.cuda.is_available():
    print("GPU count:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}:", torch.cuda.get_device_name(i))
else:
    print("No CUDA-visible device in this environment.")


Python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:20:11) [MSC v.1938 64 bit (AMD64)]
Torch version: 2.6.0+cu124
CUDA available: True
CUDA compiled: 12.4
cuDNN version: 90100
GPU count: 1
Device 0: NVIDIA GeForce RTX 4090 Laptop GPU


In [2]:
%%time

from Bio import SeqIO
import gzip
import numpy as np

# Canonical ID helper: use consistently for sequences & binding_map
def canonical_id(pdb_id, chain_id):
    """
    Normalize a PDB+chain ID to a consistent form, e.g. '1abc_A' -> '1abc_a'.
    """
    return f"{pdb_id.lower()}_{chain_id.lower()}"

# Load protein sequences
def load_fasta_sequences(fasta_path):
    sequences = {}
    with gzip.open(fasta_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            raw_id = record.id.split()[0]  # remove any trailing description

            # ---- ID PARSING: this may need adjusting depending on your FASTA headers ----
            # First, handle simple "1abcA" style IDs
            if "_" not in raw_id and len(raw_id) >= 5:
                pdb_id = raw_id[:4]
                chain_id = raw_id[4:]
            # If there *is* an underscore, assume "1abc_A" style and use first 4 chars as pdb,
            # first char after "_" as chain, ignore trailing stuff
            elif "_" in raw_id:
                left, right = raw_id.split("_", 1)
                pdb_id = left[:4]
                chain_id = right[0] if len(right) > 0 else "A"
            else:
                # Anything weird, skip
                continue

            key = canonical_id(pdb_id, chain_id)
            sequences[key] = str(record.seq)
    return sequences

# Parse BioLiP.txt
def parse_biolip_annotations(biolip_path):
    """
    Returns:
        binding_map: dict[canonical_id] -> set(0-based residue indices)

    Uses column 9 from BioLiP (binding residues renumbered starting from 1),
    as described in the official readme.txt.
    """
    binding_map = {}

    with gzip.open(biolip_path, "rt") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            # BioLiP is tab-delimited
            parts = line.split("\t")
            # According to the readme, there are 21 columns; we just need the first 9+
            if len(parts) < 9:
                continue

            pdb_id = parts[0]          # column 1
            chain_id = parts[1]        # column 2
            key = canonical_id(pdb_id, chain_id)

            # Column 9 (index 8): binding residues, renumbered starting from 1
            # Example: "N73 L74 A75 V108 H111 ..."
            binding_field = parts[8]

            indices = []
            for token in binding_field.split():
                # Tokens are usually like "N73" or "A182". Extract the numeric part.
                digits = "".join(ch for ch in token if ch.isdigit())
                if not digits:
                    continue
                idx = int(digits) - 1   # convert from 1-based → 0-based
                if idx >= 0:
                    indices.append(idx)

            if not indices:
                continue

            # Merge indices if the same chain appears multiple times
            binding_map.setdefault(key, set()).update(indices)

    return binding_map

# Generate labeled output
def generate_labeled_sequences(sequences, binding_map, output_path):
    with open(output_path, "w") as out:
        for protein_id, seq in sequences.items():
            if protein_id not in binding_map:
                continue
            labels = ["0"] * len(seq)
            for idx in binding_map[protein_id]:
                if 0 <= idx < len(seq):
                    labels[idx] = "1"
            out.write(f">{protein_id}\n{seq}\n{''.join(labels)}\n")

# Paths to your files
fasta_path = "protein.fasta.gz"
biolip_path = "BioLiP.txt.gz"
output_path = "biolip_labeled.txt"

# Run preprocessing
sequences = load_fasta_sequences(fasta_path)
binding_map = parse_biolip_annotations(biolip_path)
generate_labeled_sequences(sequences, binding_map, output_path)

def load_labeled_dataset(labeled_path):
    """
    Read the labeled BioLiP file and return a list of (sequence, labels).
    Assumes the file is in repeating 3-line blocks:
      >protein_id
      SEQUENCE
      001010...
    """
    dataset = []
    with open(labeled_path) as f:
        # Remove empty lines to avoid misalignment
        raw_lines = f.read().splitlines()
        lines = [ln for ln in raw_lines if ln.strip()]

    if len(lines) % 3 != 0:
        print(f"Warning: labeled file has {len(lines)} non-empty lines, "
              f"which is not a multiple of 3. Truncating the last incomplete record.")

    # Go only up to len(lines)-2 so i+1 and i+2 are always valid
    for i in range(0, len(lines) - 2, 3):
        header = lines[i]
        if not header.startswith(">"):
            raise ValueError(f"Expected header starting with '>' at line {i}, got: {header!r}")

        protein_id = header[1:]   # remove ">"
        seq = lines[i + 1]
        labels_line = lines[i + 2]

        labels = np.array([int(x) for x in labels_line])
        # Sanity check: labels and sequence must match in length
        if len(labels) != len(seq):
            raise ValueError(
                f"Length mismatch for {protein_id}: seq={len(seq)}, labels={len(labels)}"
            )

        dataset.append((seq, labels))

    return dataset

# Now actually set it
output_path = "biolip_labeled.txt"
dataset = load_labeled_dataset(output_path)

# --- Debug stats so we can see what's going on ---
print("Total sequences parsed from FASTA:", len(sequences))
print("Total chains with binding annotations:", len(binding_map))
common_ids = set(sequences.keys()) & set(binding_map.keys())
print("Number of overlapping IDs:", len(common_ids))
print("Example overlapping IDs:", list(common_ids)[:10])

print("Dataset size:", len(dataset))
if len(dataset) > 0:
    print("First entry sequence length:", len(dataset[0][0]))
    print("First entry labels sum (pocket residues):", np.sum(dataset[0][1]))
else:
    print("No labeled sequences were written to biolip_labeled.txt")


Total sequences parsed from FASTA: 475001
Total chains with binding annotations: 474607
Number of overlapping IDs: 474607
Example overlapping IDs: ['1bi3_b', '4tvw_d', '4m19_a', '6jjk_c', '1zjb_b', '8fc5_2v', '4qij_e', '4rdq_b', '6zqe_dn', '6ndk_yw']
Dataset size: 474607
First entry sequence length: 154
First entry labels sum (pocket residues): 8
CPU times: total: 35.4 s
Wall time: 36.2 s


In [3]:
%%time

import torch
import torch.nn as nn
import numpy as np

# Amino acid vocabulary (20 standard AAs)
AA_VOCAB = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {aa: i for i, aa in enumerate(AA_VOCAB)}

class PocketPredictor(nn.Module):
    def __init__(self, vocab_size=20, embed_dim=64, hidden_dim=128):
        super().__init__()
        # Embedding layer for amino acids
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Simple feedforward layers
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        # input_ids: [batch, seq_len]
        x = self.embedding(input_ids)           # [batch, seq_len, embed_dim]
        x = self.fc1(x)                         # [batch, seq_len, hidden_dim]
        x = self.relu(x)
        x = self.fc2(x)                         # [batch, seq_len, 1]
        return self.sigmoid(x).squeeze(-1)      # [batch, seq_len]

# Your query sequence (the one you want to predict pockets on)
sequence = "RRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDEAERDEYELLCPDNTRKPVDKFKDCHLARVPSHAVVARSVNGKEDAIWNLLRQAQEKFGKDKSPKFQLFGSPSGQKDLLFKDSAIGFSRVPPRIDSGLYLGSGYFTAIQNLRKSEEEVAARRARVVWCAVGEQELRKCNQWSGLSEGSVTCSSASTTEDCIALVLKGEADAMSLDEGYVYTAGKCGLVPVLAENYKSQQSSDPDPNCVDRPVEGYLAVAVVRRSDTSLTWNSVKGKKSCHTAVDRTAGWNIPMGLLFNQTGSCKFDEYFSQSCAPGSDPRSNLCALCIGDEQGENKCVPNSNERYYGYTGAFRCLAENAGDVAFVKDVTVLQNTDGNNNEAWAKDLKLADFALLCLDGKRKPVTEARSCHLAMAPNHAVVSRMDKVERLKQVLLHQQAKFGRNGSDCPDKFCLFQSETKNLLFNDNTECLARLHGKTTYEKYLGPQYVAGITNLKKCSTSPLLEACEFLRK"

# Check that the sequence only contains amino acids in our vocab
invalid_aas = sorted({aa for aa in sequence if aa not in AA_VOCAB})
if invalid_aas:
    raise ValueError(f"Sequence contains unknown AAs not in AA_VOCAB: {invalid_aas}")

# Encode to indices
input_ids = torch.tensor([[aa_to_idx[aa] for aa in sequence]], dtype=torch.long)

print("Input shape:", input_ids.shape)

# Decode back just to sanity-check mapping
decoded_seq = "".join(AA_VOCAB[idx] for idx in input_ids[0].tolist())
print("Decoded sequence matches original:", decoded_seq == sequence)
print("First 30 token IDs:", input_ids[0][:30].tolist())

Input shape: torch.Size([1, 691])
Decoded sequence matches original: True
First 30 token IDs: [14, 14, 14, 14, 15, 17, 13, 18, 1, 0, 17, 15, 13, 12, 3, 0, 16, 8, 1, 4, 13, 18, 13, 14, 11, 10, 14, 8, 17, 14]
CPU times: total: 0 ns
Wall time: 1 ms


In [4]:
%%time

from collections import Counter

def normalized_kmer_score(seq1, seq2, k=3):
    kmers1 = Counter(seq1[i:i+k] for i in range(len(seq1)-k+1))
    kmers2 = Counter(seq2[i:i+k] for i in range(len(seq2)-k+1))
    shared = sum((kmers1 & kmers2).values())
    denom = min(sum(kmers1.values()), sum(kmers2.values()))
    return shared / denom if denom > 0 else 0.0

scores = []
total = len(sequences)
with_binding = 0
with_length = 0

for pid, target_seq in sequences.items():
    # sequences and binding_map now share the same canonical IDs
    if pid not in binding_map:
        continue
    with_binding += 1

    ratio = len(target_seq) / len(sequence)
    if ratio < 0.3 or ratio > 10.0:
        continue
    with_length += 1

    score = normalized_kmer_score(sequence, target_seq, k=3)
    scores.append((pid, score))

scores.sort(key=lambda x: x[1], reverse=True)
top_candidates = scores[:50]

print("Total sequences:", total)
print("Survive binding filter (existence only):", with_binding)
print("Survive length filter:", with_length)
print("Final candidates:", len(scores))
print("Top 10 candidates:", top_candidates[:10])


Total sequences: 475001
Survive binding filter (existence only): 474607
Survive length filter: 263046
Final candidates: 263046
Top 10 candidates: [('1bka_a', 0.9956268221574344), ('7n88_b', 0.9956268221574344), ('1cb6_a', 0.9941944847605225), ('1fck_a', 0.9912917271407837), ('2pms_a', 0.9908814589665653), ('2pms_b', 0.9908814589665653), ('1lct_a', 0.9906832298136646), ('1h44_a', 0.990625), ('1b0l_a', 0.9898403483309144), ('2bjj_x', 0.9869375907111756)]
CPU times: total: 1min 11s
Wall time: 1min 12s


In [5]:
%%time

from Bio import pairwise2
from Bio.Seq import Seq

best_match = None
best_score = -1

for pid, _ in top_candidates:
    target_seq = sequences[pid]
    alignments = pairwise2.align.localms(sequence, target_seq,
                                         2, -1, -0.5, -0.1,
                                         one_alignment_only=True)
    score = alignments[0].score
    if score > best_score:
        best_score = score
        best_match = pid

print("Best match after prefilter:", best_match)
print("Local alignment score:", best_score)




Best match after prefilter: 1cb6_a
Local alignment score: 1377.0
CPU times: total: 3.64 s
Wall time: 3.68 s


In [6]:
from Bio import pairwise2
from Bio.Seq import Seq
import numpy as np

def transfer_binding_labels(query_seq, template_seq, binding_indices):
    """
    Global alignment-based transfer:
      - query_seq: your sequence (string or Seq)
      - template_seq: template sequence (string or Seq)
      - binding_indices: set of 0-based residue indices on the template
    """
    alignment = pairwise2.align.globalms(query_seq, template_seq,
                                         2, -1, -5, -1,
                                         one_alignment_only=True)[0]
    aligned_query = alignment.seqA
    aligned_template = alignment.seqB

    y_true = []
    template_pos = -1

    for q_char, t_char in zip(aligned_query, aligned_template):
        if t_char != "-":
            template_pos += 1
        if q_char == "-":
            # gap in query, skip (no label)
            continue
        y_true.append(1 if template_pos in binding_indices else 0)

    return np.array(y_true), alignment

# Example usage with the best_match chosen above
query_seq = Seq(sequence)
template_seq = Seq(sequences[best_match])
binding_indices = binding_map[best_match]   # 0-based indices for this canonical ID

y_true_global, alignment = transfer_binding_labels(query_seq, template_seq, binding_indices)

print("Aligned query (first 120 chars):   ", alignment.seqA)
print("Aligned template (first 120 chars):", alignment.seqB)
print("Global alignment score: ", alignment.score)
print("y_true_global length:", len(y_true_global), "positives:", np.sum(y_true_global))


Aligned query (first 120 chars):    RRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDEAERDEYELLCPDNTRKPVDKFKDCHLARVPSHAVVARSVNGKEDAIWNLLRQAQEKFGKDKSPKFQLFGSPSGQKDLLFKDSAIGFSRVPPRIDSGLYLGSGYFTAIQNLRKSEEEVAARRARVVWCAVGEQELRKCNQWSGLSEGSVTCSSASTTEDCIALVLKGEADAMSLDEGYVYTAGKCGLVPVLAENYKSQQSSDPDPNCVDRPVEGYLAVAVVRRSDTSLTWNSVKGKKSCHTAVDRTAGWNIPMGLLFNQTGSCKFDEYFSQSCAPGSDPRSNLCALCIGDEQGENKCVPNSNERYYGYTGAFRCLAENAGDVAFVKDVTVLQNTDGNNNEAWAKDLKLADFALLCLDGKRKPVTEARSCHLAMAPNHAVVSRMDKVERLKQVLLHQQAKFGRNGSDCPDKFCLFQSETKNLLFNDNTECLARLHGKTTYEKYLGPQYVAGITNLKKCSTSPLLEACEFLRK
Aligned template (first 120 chars): GRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDEAERDEYELLCPDNTR

In [7]:
%%time

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

# Make sure AA_VOCAB, aa_to_idx, PocketPredictor, dataset are already defined

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

pocket_model = PocketPredictor(
    vocab_size=len(AA_VOCAB),
    embed_dim=64,
    hidden_dim=256
).to(device)

optimizer = optim.Adam(pocket_model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

# For speed: train on a subset first (you can increase this later)
MAX_TRAIN_SEQS = 50000   # try 10_000 if you want it really quick
NUM_EPOCHS = 2           # bump to 5–10 later if you like

indices = list(range(len(dataset)))
random.shuffle(indices)
train_indices = indices[:MAX_TRAIN_SEQS]

print(f"Training on {len(train_indices)} sequences out of {len(dataset)}")

for epoch in range(NUM_EPOCHS):
    pocket_model.train()
    epoch_loss = 0.0
    used = 0

    for idx in train_indices:
        seq, labels_np = dataset[idx]

        # Skip sequences that contain non-standard AAs
        if any(aa not in aa_to_idx for aa in seq):
            continue

        # Encode sequence into indices 0–19
        inputs = torch.tensor(
            [aa_to_idx[aa] for aa in seq],
            dtype=torch.long,
            device=device
        ).unsqueeze(0)  # [1, L]

        labels = torch.tensor(
            labels_np,
            dtype=torch.float32,
            device=device
        ).unsqueeze(0)  # [1, L]

        optimizer.zero_grad()
        outputs = pocket_model(inputs)  # [1, L]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        used += 1

    avg_loss = epoch_loss / max(used, 1)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, examples used: {used}, avg loss: {avg_loss:.4f}")


Using device: cuda
Training on 50000 sequences out of 474607
Epoch 1/2, examples used: 49999, avg loss: 0.3294
Epoch 2/2, examples used: 49999, avg loss: 0.3287
CPU times: total: 2min 2s
Wall time: 2min 18s


In [8]:
%%time

# Ensure your 'sequence' string and aa_to_idx are already defined

# Sanity: skip if sequence has unknown AAs
invalid_aas = sorted({aa for aa in sequence if aa not in aa_to_idx})
if invalid_aas:
    raise ValueError(f"Query sequence contains unknown AAs: {invalid_aas}")

pocket_model.eval()
with torch.no_grad():
    input_ids = torch.tensor(
        [[aa_to_idx[aa] for aa in sequence]],
        dtype=torch.long,
        device=device
    )
    y_pred_probs = pocket_model(input_ids).squeeze(0).cpu().numpy()

print("Trained prediction shape:", y_pred_probs.shape)
print("Sample trained predictions:", y_pred_probs[:10])


Trained prediction shape: (691,)
Sample trained predictions: [0.22895603 0.22895603 0.22895603 0.22895603 0.10172698 0.05561124
 0.09778902 0.1491107  0.23421668 0.06217684]
CPU times: total: 0 ns
Wall time: 2.46 ms


In [9]:
%%time

from Bio import pairwise2
from Bio.Seq import Seq
import numpy as np

def transfer_binding_labels_global(query_seq, template_seq, binding_indices):
    """
    Global alignment-based label transfer.

    binding_indices: set/list of 0-based residue indices in the *template* sequence.
    Returns:
        y_true: np.array of shape [len(query_seq)], 0/1 per residue.
        alignment: Biopython alignment object.
    """
    alignment = pairwise2.align.globalms(
        query_seq, template_seq,
        2, -1,    # match, mismatch
        -5, -1,   # gap open, gap extend
        one_alignment_only=True
    )[0]

    aligned_query = alignment.seqA
    aligned_template = alignment.seqB

    y_true = []
    template_pos = -1

    for q_char, t_char in zip(aligned_query, aligned_template):
        if t_char != "-":
            template_pos += 1
        if q_char == "-":
            # gap in query: skip (no residue)
            continue
        y_true.append(1 if template_pos in binding_indices else 0)

    y_true = np.array(y_true, dtype=int)

    # If alignment weirdness causes small off-by-one issues, be forgiving:
    if len(y_true) < len(query_seq):
        y_true = np.pad(y_true, (0, len(query_seq) - len(y_true)))
    elif len(y_true) > len(query_seq):
        y_true = y_true[:len(query_seq)]

    return y_true, alignment

# Use your best_match template
query_seq = Seq(sequence)
template_seq = Seq(sequences[best_match])
binding_indices = binding_map[best_match]   # 0-based indices

y_true, alignment = transfer_binding_labels_global(
    query_seq, template_seq, binding_indices
)

print("y_true length:", len(y_true), "positives:", np.sum(y_true))
print("Alignment score:", alignment.score)
print("Alignment:")
print("Query:   ", alignment.seqA)
print("Template:", alignment.seqB)


y_true length: 691 positives: 6
Alignment score: 1376.0
Alignment:
Query:    RRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDEAERDEYELLCPDNTRKPVDKFKDCHLARVPSHAVVARSVNGKEDAIWNLLRQAQEKFGKDKSPKFQLFGSPSGQKDLLFKDSAIGFSRVPPRIDSGLYLGSGYFTAIQNLRKSEEEVAARRARVVWCAVGEQELRKCNQWSGLSEGSVTCSSASTTEDCIALVLKGEADAMSLDEGYVYTAGKCGLVPVLAENYKSQQSSDPDPNCVDRPVEGYLAVAVVRRSDTSLTWNSVKGKKSCHTAVDRTAGWNIPMGLLFNQTGSCKFDEYFSQSCAPGSDPRSNLCALCIGDEQGENKCVPNSNERYYGYTGAFRCLAENAGDVAFVKDVTVLQNTDGNNNEAWAKDLKLADFALLCLDGKRKPVTEARSCHLAMAPNHAVVSRMDKVERLKQVLLHQQAKFGRNGSDCPDKFCLFQSETKNLLFNDNTECLARLHGKTTYEKYLGPQYVAGITNLKKCSTSPLLEACEFLRK
Template: GRRRSVQWCAVSQPEATKCFQWQRNMRKVRGPPVSCIKRDSPIQCIQAIAENRADAVTLDGGFIYEAGLAPYKLRPVAAEVYGTERQPRTHYYAVAVVKKGGSFQLNELQGLKSCHTGLRRTAGWNVPIGTLRPFLNWTGPPEPIEAAVARFFSASCVPGADKGQFPNLCRLCAGTGENKCAFSSQEPYFSYSGAFKCLRDGAGDVAFIRESTVFEDLSDE

In [10]:
%%time

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

def evaluate_pocket_predictions(y_true, y_pred_probs, threshold=0.5):
    """
    Evaluate binding pocket predictions.
    
    Args:
        y_true (np.array): Binary ground truth labels (0 = non-pocket, 1 = pocket)
        y_pred_probs (np.array): Predicted probabilities for each residue
        threshold (float): Classification threshold for binary decision

    Returns:
        dict: Evaluation metrics
    """
    if len(y_true) != len(y_pred_probs):
        raise ValueError(
            f"Shape mismatch: y_true={len(y_true)}, y_pred_probs={len(y_pred_probs)}"
        )

    y_pred = (y_pred_probs > threshold).astype(int)

    # Residue-level metrics
    roc_auc = roc_auc_score(y_true, y_pred_probs)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pocket-level coverage / overlap
    true_pocket = set(np.where(y_true == 1)[0])
    predicted_pocket = set(np.where(y_pred == 1)[0])
    intersection = true_pocket & predicted_pocket

    coverage = len(intersection) / len(true_pocket) if true_pocket else 0.0
    overlap = len(intersection) / len(predicted_pocket) if predicted_pocket else 0.0

    return {
        "ROC-AUC": roc_auc,
        "Precision": precision,
        "Recall": recall,
        "F1": f1,
        "Pocket Coverage": coverage,
        "Pocket Overlap": overlap,
        "True pocket residues": len(true_pocket),
        "Predicted pocket residues": len(predicted_pocket),
    }

metrics = evaluate_pocket_predictions(y_true, y_pred_probs, threshold=0.5)

for k, v in metrics.items():
    if isinstance(v, float):
        print(f"{k}: {v:.3f}")
    else:
        print(f"{k}: {v}")


ROC-AUC: 0.698
Precision: 0.000
Recall: 0.000
F1: 0.000
Pocket Coverage: 0.000
Pocket Overlap: 0.000
True pocket residues: 6
Predicted pocket residues: 0
CPU times: total: 750 ms
Wall time: 773 ms


In [11]:
import numpy as np

print("Min prob:", float(y_pred_probs.min()))
print("Max prob:", float(y_pred_probs.max()))
print("Mean prob:", float(y_pred_probs.mean()))
print("Quantiles:", np.quantile(y_pred_probs, [0.5, 0.9, 0.95, 0.99]))


Min prob: 0.055070921778678894
Max prob: 0.23421667516231537
Mean prob: 0.10539519786834717
Quantiles: [0.09119008 0.22895603 0.23192109 0.23421668]


In [12]:
from sklearn.metrics import precision_score, recall_score, f1_score

def metrics_at_threshold(y_true, y_pred_probs, threshold):
    y_pred = (y_pred_probs > threshold).astype(int)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    num_pred = int(y_pred.sum())
    return precision, recall, f1, num_pred

for thr in [0.1, 0.2, 0.3, 0.4, 0.5]:
    p, r, f, n = metrics_at_threshold(y_true, y_pred_probs, thr)
    print(f"thr={thr:.2f} -> P={p:.3f}, R={r:.3f}, F1={f:.3f}, predicted+={n}")


thr=0.10 -> P=0.016, R=0.667, F1=0.032, predicted+=244
thr=0.20 -> P=0.023, R=0.333, F1=0.043, predicted+=86
thr=0.30 -> P=0.000, R=0.000, F1=0.000, predicted+=0
thr=0.40 -> P=0.000, R=0.000, F1=0.000, predicted+=0
thr=0.50 -> P=0.000, R=0.000, F1=0.000, predicted+=0


In [13]:
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

k = int(y_true.sum())  # number of true pockets = 6
topk_idx = np.argsort(-y_pred_probs)[:k]  # indices of top-k residues

y_pred_topk = np.zeros_like(y_true)
y_pred_topk[topk_idx] = 1

p = precision_score(y_true, y_pred_topk, zero_division=0)
r = recall_score(y_true, y_pred_topk, zero_division=0)
f1 = f1_score(y_true, y_pred_topk, zero_division=0)

print(f"Top-k strategy (k = {k}):")
print(f"Precision: {p:.3f}, Recall: {r:.3f}, F1: {f1:.3f}")
print("Top-k indices:", topk_idx)
print("True pocket indices:", np.where(y_true == 1)[0])


Top-k strategy (k = 6):
Precision: 0.000, Recall: 0.000, F1: 0.000
Top-k indices: [ 44 458 574 588 426   8]
True pocket indices: [120 122 191 460 464 527]


In [14]:
import torch.nn as nn

class PocketPredictorLogits(nn.Module):
    def __init__(self, vocab_size=20, embed_dim=64, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids):
        # [batch, seq_len] -> [batch, seq_len, embed_dim]
        x = self.embedding(input_ids)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)      # [batch, seq_len, 1]
        return x.squeeze(-1) # [batch, seq_len] logits


In [15]:
%%time

import torch
import torch.optim as optim
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Rough estimate of positive / negative ratio on a subset
sample_for_ratio = 10000
all_pos = 0
all_total = 0
for i in random.sample(range(len(dataset)), min(sample_for_ratio, len(dataset))):
    _, labels_np = dataset[i]
    all_pos += labels_np.sum()
    all_total += len(labels_np)

pos_frac = all_pos / all_total
neg_frac = 1.0 - pos_frac
print(f"Estimated positive fraction: {pos_frac:.5f}")

# pos_weight > 1 means we penalize missing a positive more than missing a negative
pos_weight = torch.tensor([neg_frac / max(pos_frac, 1e-6)], device=device)
print("Using pos_weight:", float(pos_weight))

pocket_model2 = PocketPredictorLogits(
    vocab_size=len(AA_VOCAB),
    embed_dim=64,
    hidden_dim=256
).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(pocket_model2.parameters(), lr=1e-4)

MAX_TRAIN_SEQS = 50000
NUM_EPOCHS = 3

indices = list(range(len(dataset)))
random.shuffle(indices)
train_indices = indices[:MAX_TRAIN_SEQS]

print(f"Training (logits model) on {len(train_indices)} sequences out of {len(dataset)}")

for epoch in range(NUM_EPOCHS):
    pocket_model2.train()
    epoch_loss = 0.0
    used = 0

    for idx in train_indices:
        seq, labels_np = dataset[idx]

        if any(aa not in aa_to_idx for aa in seq):
            continue

        inputs = torch.tensor(
            [aa_to_idx[aa] for aa in seq],
            dtype=torch.long,
            device=device
        ).unsqueeze(0)

        labels = torch.tensor(
            labels_np,
            dtype=torch.float32,
            device=device
        ).unsqueeze(0)

        optimizer.zero_grad()
        logits = pocket_model2(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        used += 1

    avg_loss = epoch_loss / max(used, 1)
    print(f"[logits] Epoch {epoch+1}/{NUM_EPOCHS}, examples used: {used}, avg loss: {avg_loss:.4f}")


Using device: cuda
Estimated positive fraction: 0.06433
Using pos_weight: 14.543944744092196
Training (logits model) on 50000 sequences out of 474607
[logits] Epoch 1/3, examples used: 49998, avg loss: 1.5366
[logits] Epoch 2/3, examples used: 49998, avg loss: 1.5351
[logits] Epoch 3/3, examples used: 49998, avg loss: 1.5346
CPU times: total: 3min 38s
Wall time: 3min 54s


In [16]:
%%time

# Predictions on your query sequence using the logits model
invalid_aas = sorted({aa for aa in sequence if aa not in aa_to_idx})
if invalid_aas:
    raise ValueError(f"Query sequence contains unknown AAs: {invalid_aas}")

pocket_model2.eval()
with torch.no_grad():
    input_ids = torch.tensor(
        [[aa_to_idx[aa] for aa in sequence]],
        dtype=torch.long,
        device=device
    )
    logits = pocket_model2(input_ids)  # [1, L]
    y_pred_probs2 = torch.sigmoid(logits).squeeze(0).cpu().numpy()

print("New prediction shape:", y_pred_probs2.shape)
print("New prob stats -> min/mean/max:",
      float(y_pred_probs2.min()),
      float(y_pred_probs2.mean()),
      float(y_pred_probs2.max()))


New prediction shape: (691,)
New prob stats -> min/mean/max: 0.47272753715515137 0.6249222755432129 0.8391643166542053
CPU times: total: 0 ns
Wall time: 1 ms


In [17]:
metrics2 = evaluate_pocket_predictions(y_true, y_pred_probs2, threshold=0.2)

print("\nMetrics with logits + pos_weight (threshold=0.2):")
for k, v in metrics2.items():
    if isinstance(v, float):
        print(f"{k}: {v:.3f}")
    else:
        print(f"{k}: {v}")



Metrics with logits + pos_weight (threshold=0.2):
ROC-AUC: 0.737
Precision: 0.009
Recall: 1.000
F1: 0.017
Pocket Coverage: 1.000
Pocket Overlap: 0.009
True pocket residues: 6
Predicted pocket residues: 691


In [18]:
for thr in [0.1, 0.2, 0.3, 0.4, 0.5]:
    p, r, f, n = metrics_at_threshold(y_true, y_pred_probs2, thr)
    print(f"thr={thr:.2f} -> P={p:.3f}, R={r:.3f}, F1={f:.3f}, predicted+={n}")


thr=0.10 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.20 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.30 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.40 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.50 -> P=0.011, R=1.000, F1=0.022, predicted+=543


In [19]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [20]:
from transformers import AutoTokenizer, AutoModel

MODEL_NAME = "Rostlab/prot_bert"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
base_model = AutoModel.from_pretrained(MODEL_NAME).to(device)
print("Hidden size:", base_model.config.hidden_size)


Hidden size: 1024


In [21]:
import torch.nn as nn

class ProtBertPocket(nn.Module):
    def __init__(self, base_model, dropout=0.1):
        super().__init__()
        self.bert = base_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_output = outputs.last_hidden_state      # [B, T, H]
        seq_output = self.dropout(seq_output)
        logits = self.classifier(seq_output).squeeze(-1)  # [B, T]
        return logits


In [22]:
pocket_protbert = ProtBertPocket(base_model).to(device)

# (optional) freeze encoder for cheap training
for p in pocket_protbert.bert.parameters():
    p.requires_grad = False
for p in pocket_protbert.classifier.parameters():
    p.requires_grad = True

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, pocket_protbert.parameters()),
    lr=1e-4
)


In [23]:
%%time

from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import random

# Define max token / residue lengths for ProtBert
MAX_TOKENS = 1024
MAX_RESIDUES = MAX_TOKENS - 2  # CLS + SEP take 2 slots

class BioLiPPocketProtDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_residues=None):
        """
        dataset: list of (seq, labels_np)
        """
        self.tokenizer = tokenizer
        self.max_residues = max_residues if max_residues is not None else MAX_RESIDUES

        self.data = []
        for seq, labels in dataset:
            if len(seq) != len(labels):
                continue
            if len(seq) <= self.max_residues:
                self.data.append((seq, labels))

        print(f"ProtBert dataset size (after length filter): {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq, labels = self.data[idx]

        # ProtBert expects space-separated AAs
        seq_spaced = " ".join(list(seq))

        enc = tokenizer(
            seq_spaced,
            return_tensors="pt",
            max_length=self.max_residues + 2,  # CLS + residues + SEP
            padding="max_length",
            truncation=True
        )
        input_ids = enc["input_ids"].squeeze(0)            # [T]
        attention_mask = enc["attention_mask"].squeeze(0)  # [T]

        T = input_ids.size(0)

        labels_full = torch.zeros(T, dtype=torch.float32)
        residue_mask = torch.zeros(T, dtype=torch.bool)

        L = min(len(labels), T - 2)  # residues kept

        labels_full[1:1+L] = torch.tensor(labels[:L], dtype=torch.float32)
        residue_mask[1:1+L] = True

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels_full,
            "residue_mask": residue_mask,
            "seq_len": L,
        }

# Build dataset  (assumes `dataset` and `tokenizer` already exist)
prot_dataset = BioLiPPocketProtDataset(dataset, tokenizer, max_residues=MAX_RESIDUES)

# Subsample for training
train_subset = 1000   # keep it small to test GPU speed; bump later
BATCH_SIZE = 8        # adjust if you hit OOM

indices = list(range(len(prot_dataset)))
random.shuffle(indices)
train_indices = indices[:train_subset]

class SubsetDataset(Dataset):
    def __init__(self, base_ds, indices):
        self.base_ds = base_ds
        self.indices = indices
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        return self.base_ds[self.indices[idx]]

train_ds = SubsetDataset(prot_dataset, train_indices)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

print("Train batches:", len(train_loader))


ProtBert dataset size (after length filter): 467723
Train batches: 125
CPU times: total: 344 ms
Wall time: 338 ms


In [24]:
%%time

NUM_EPOCHS = 3  # start small to sanity-check speed

pocket_protbert.train()

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0

    for i, batch in enumerate(train_loader, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)            # [B, T]
        residue_mask = batch["residue_mask"].to(device)  # [B, T]

        optimizer.zero_grad()

        logits = pocket_protbert(input_ids, attention_mask)  # [B, T]

        # Only compute loss on residue positions
        logits_res = logits[residue_mask]
        labels_res = labels[residue_mask]

        loss = criterion(logits_res, labels_res)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

        if i % 20 == 0:
            print(f"Batch {i}/{len(train_loader)}")

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, avg loss: {epoch_loss / max(n_batches,1):.4f}")


Batch 20/125
Batch 40/125
Batch 60/125
Batch 80/125
Batch 100/125
Batch 120/125
Epoch 1/3, avg loss: 0.6473
Batch 20/125
Batch 40/125
Batch 60/125
Batch 80/125
Batch 100/125
Batch 120/125
Epoch 2/3, avg loss: 0.5590
Batch 20/125
Batch 40/125
Batch 60/125
Batch 80/125
Batch 100/125
Batch 120/125
Epoch 3/3, avg loss: 0.4952
CPU times: total: 2min 48s
Wall time: 2min 51s


In [25]:
%%time

prot_dataset = BioLiPPocketProtDataset(dataset, tokenizer)

BATCH_SIZE = 4  # ProtBert is big; adjust based on GPU memory
train_subset = 5000  # number of sequences to actually use for quick fine-tuning

indices = list(range(len(prot_dataset)))
random.shuffle(indices)
train_indices = indices[:train_subset]

class SubsetDataset(Dataset):
    def __init__(self, base_ds, indices):
        self.base_ds = base_ds
        self.indices = indices
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        return self.base_ds[self.indices[idx]]

train_ds = SubsetDataset(prot_dataset, train_indices)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
print("Train batches:", len(train_loader))


ProtBert dataset size (after length filter): 467723
Train batches: 1250
CPU times: total: 328 ms
Wall time: 332 ms


In [26]:
class ProtBertPocket(nn.Module):
    def __init__(self, base_model, dropout=0.1):
        super().__init__()
        self.bert = base_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_output = outputs.last_hidden_state      # [B, T, H]
        seq_output = self.dropout(seq_output)
        logits = self.classifier(seq_output).squeeze(-1)  # [B, T]
        return logits


In [27]:
%%time

pocket_protbert = ProtBertPocket(base_model).to(device)

# Option 1: freeze the BERT encoder, train only the head
for param in pocket_protbert.bert.parameters():
    param.requires_grad = False

# If you want to fine-tune the whole model, comment out the loop above.

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, pocket_protbert.parameters()),
    lr=1e-4
)


CPU times: total: 15.6 ms
Wall time: 5 ms


In [28]:
%%time

pocket_protbert.eval()

seq_spaced = " ".join(list(sequence))
enc = tokenizer(
    seq_spaced,
    return_tensors="pt",
    max_length=MAX_RESIDUES + 2,
    padding="max_length",
    truncation=True
)

input_ids = enc["input_ids"].to(device)
attention_mask = enc["attention_mask"].to(device)

with torch.no_grad():
    logits_q = pocket_protbert(input_ids, attention_mask)  # [1, T]
    probs_q = torch.sigmoid(logits_q).squeeze(0).cpu().numpy()  # [T]

# Extract only the residue positions (skip CLS at 0, SEP at L+1, and padding)
Lq = min(len(sequence), MAX_RESIDUES)
y_pred_probs_prot = probs_q[1:1+Lq]  # [Lq]

print("ProtBert per-residue probs shape:", y_pred_probs_prot.shape)
print("First 10 probs:", y_pred_probs_prot[:10])


ProtBert per-residue probs shape: (691,)
First 10 probs: [0.49345312 0.48102584 0.49102947 0.47661766 0.47565004 0.49132064
 0.48090282 0.5065823  0.51055634 0.4988366 ]
CPU times: total: 156 ms
Wall time: 137 ms


In [29]:
%%time

print("Check shapes:", len(y_true), len(y_pred_probs_prot))

metrics_prot = evaluate_pocket_predictions(y_true, y_pred_probs_prot, threshold=0.5)

print("=== ProtBert pocket model metrics (threshold=0.5) ===")
for k, v in metrics_prot.items():
    if isinstance(v, float):
        print(f"{k}: {v:.3f}")
    else:
        print(f"{k}: {v}")


Check shapes: 691 691
=== ProtBert pocket model metrics (threshold=0.5) ===
ROC-AUC: 0.420
Precision: 0.004
Recall: 0.167
F1: 0.008
Pocket Coverage: 0.167
Pocket Overlap: 0.004
True pocket residues: 6
Predicted pocket residues: 249
CPU times: total: 15.6 ms
Wall time: 9.53 ms


In [30]:
for thr in [0.1, 0.2, 0.3, 0.4, 0.5]:
    p, r, f, n = metrics_at_threshold(y_true, y_pred_probs_prot, thr)
    print(f"ProtBert thr={thr:.2f} -> P={p:.3f}, R={r:.3f}, F1={f:.3f}, predicted+={n}")


ProtBert thr=0.10 -> P=0.009, R=1.000, F1=0.017, predicted+=691
ProtBert thr=0.20 -> P=0.009, R=1.000, F1=0.017, predicted+=691
ProtBert thr=0.30 -> P=0.009, R=1.000, F1=0.017, predicted+=691
ProtBert thr=0.40 -> P=0.009, R=1.000, F1=0.017, predicted+=691
ProtBert thr=0.50 -> P=0.004, R=0.167, F1=0.008, predicted+=249


In [31]:
k = int(y_true.sum())
topk_idx_prot = np.argsort(-y_pred_probs_prot)[:k]

y_pred_topk_prot = np.zeros_like(y_true)
y_pred_topk_prot[topk_idx_prot] = 1

p_k = precision_score(y_true, y_pred_topk_prot, zero_division=0)
r_k = recall_score(y_true, y_pred_topk_prot, zero_division=0)
f1_k = f1_score(y_true, y_pred_topk_prot, zero_division=0)

print(f"ProtBert Top-k (k={k}) -> P={p_k:.3f}, R={r_k:.3f}, F1={f1_k:.3f}")
print("Top-k indices:", topk_idx_prot)
print("True pocket indices:", np.where(y_true == 1)[0])


ProtBert Top-k (k=6) -> P=0.000, R=0.000, F1=0.000
Top-k indices: [409  60  55 390 395 663]
True pocket indices: [120 122 191 460 464 527]


In [32]:
print("len(sequence):", len(sequence))
print("len(y_true):", len(y_true))
print("len(y_pred_probs_prot):", len(y_pred_probs_prot))

metrics_prot = evaluate_pocket_predictions(y_true, y_pred_probs_prot, threshold=0.5)
print("ProtBert ROC-AUC:", metrics_prot["ROC-AUC"])


len(sequence): 691
len(y_true): 691
len(y_pred_probs_prot): 691
ProtBert ROC-AUC: 0.4204379562043795


In [33]:
# choose your simple model probs
y_pred_probs_simple = y_pred_probs2  # or y_pred_probs

print("=== AUC comparison ===")
m_simple = evaluate_pocket_predictions(y_true, y_pred_probs_simple, threshold=0.5)
m_prot   = evaluate_pocket_predictions(y_true, y_pred_probs_prot,   threshold=0.5)
print("Simple model ROC-AUC:", m_simple["ROC-AUC"])
print("ProtBert ROC-AUC:    ", m_prot["ROC-AUC"])

thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]

print("\n=== Threshold sweep: Simple model ===")
for thr in thresholds:
    p, r, f, n = metrics_at_threshold(y_true, y_pred_probs_simple, thr)
    print(f"thr={thr:.2f} -> P={p:.3f}, R={r:.3f}, F1={f:.3f}, predicted+={n}")

print("\n=== Threshold sweep: ProtBert ===")
for thr in thresholds:
    p, r, f, n = metrics_at_threshold(y_true, y_pred_probs_prot, thr)
    print(f"thr={thr:.2f} -> P={p:.3f}, R={r:.3f}, F1={f:.3f}, predicted+={n}")


=== AUC comparison ===
Simple model ROC-AUC: 0.7369829683698297
ProtBert ROC-AUC:     0.4204379562043795

=== Threshold sweep: Simple model ===
thr=0.10 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.20 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.30 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.40 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.50 -> P=0.011, R=1.000, F1=0.022, predicted+=543

=== Threshold sweep: ProtBert ===
thr=0.10 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.20 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.30 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.40 -> P=0.009, R=1.000, F1=0.017, predicted+=691
thr=0.50 -> P=0.004, R=0.167, F1=0.008, predicted+=249
