# Linear Model for structure prediction

In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
import esm
from esm.data import ESMStructuralSplitDataset
import matplotlib.pyplot as plt

# Make the project root (the parent directory of the notebook's working
# directory) importable so the local `helpers` package can be found.
current_dir = os.getcwd()
project_root = os.path.dirname(os.path.abspath(current_dir))
if project_root not in sys.path:
    sys.path.append(project_root)

from helpers import helper

### Config

In [2]:
# Runtime configuration shared by all experiments below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
LR = 1e-4
NUM_EPOCHS = 3

# Seed the RNGs that drive probe initialisation and DataLoader shuffling so a
# full re-run of the notebook reproduces the same probes and numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load model

In [3]:
# Run this block to probe the *pretrained* network: ESM-1 (6 layers, 43M
# parameters) plus its matching alphabet/tokenizer, moved to DEVICE and
# switched to eval mode. The last expression displays the architecture.
model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
model = model.to(DEVICE).eval()
model

ProteinBertModel(
  (embed_tokens): Embedding(35, 768, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): ESM1LayerNorm()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): ESM1LayerNorm()
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=72, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (embed_positions): SinusoidalPositionalEmbedding()
)

In [4]:
# Run this block to probe the *randomly initialised* baseline instead of the
# pretrained weights (requires the pretrained-model cell to have run first).
# Every downstream cell reads `model`, so the randomized network must be bound
# to that name -- otherwise the probes silently keep using pretrained weights.
untrained_model = helper.randomize_model(model)
# NOTE(review): if `randomize_model` works in place and returns None, `model`
# itself is already randomized; keep it in that case -- TODO confirm.
if untrained_model is not None:
    model = untrained_model
model = model.to(DEVICE)
model.eval()

ProteinBertModel(
  (embed_tokens): Embedding(35, 768, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): ESM1LayerNorm()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): ESM1LayerNorm()
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=72, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (embed_positions): SinusoidalPositionalEmbedding()
)

In [5]:
# Width of each per-residue representation, and the index of the last
# Transformer layer (passed as `repr_layers` to select which hidden states are returned).
EMBED_DIM, LAYER_IDX = model.args.embed_dim, model.num_layers

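Define constants used later: the embedding width (`EMBED_DIM`) and the index of the last Transformer layer (`LAYER_IDX`), whose representations are probed.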

### Motivation

One of the oldest assumptions of sequence biology is that the underlying structure of a protein is a hidden variable that shapes the patterns observed in sequence data. Conversely, the sequence determines the structure the protein folds into.
In short: Structural information is encoded in the sequences.

- Secondary structure constrains the local choice and order of residues.
- Tertiary structure constrains the long-range choice and order of residues (residues far apart in sequence but close in space).

Underlying general hypothesis: since 3D structure is encoded in the sequences, it is plausible that, through unsupervised learning, the model implicitly learns to decode the hidden information about the secondary and tertiary structure of the protein.

In the paper, the authors start by training simple linear models on top of the learned representations to test whether even such simple models can infer structure from them. Success would be very impressive, because a linear model can only read out information that is already (linearly) present in the representation.

This enables a direct inspection of the structural content of the representations.

By comparing the Transformer's representations before and after pretraining, we can identify the information that emerges as a result of the unsupervised learning.

A five-fold cross-validation experiment is used to study how structural information generalizes at the family, superfamily, and fold level.
- For each of the three levels, a dataset of 15,297 protein structures is constructed from the SCOPe database.


# Hypothesis 1: Linear Logistic Regression for secondary structure prediction
Via fivefold cross validation experiment

### Define the DSSP mapping that encodes the secondary-structure labels as class indices for logistic regression

Which DSSP mapping is used?

- H: α-Helix (Alpha helix) 
- E: Extended Strand (Participates in a β-ladder/sheet) 
- T: Turn (Hydrogen-bonded turn) 
- S: Bend 
- G: $3_{10}$-helix (three-ten helix)
- B: β-Bridge (Residue in an isolated β-bridge) 
- I: π-Helix (Pi helix) 
- C or -: Coil (Irregular/Loop/None)
- P: κ-helix (polyproline II helix; only emitted by DSSP ≥ 4.0 and not mapped below)

In [6]:
# DSSP 8-state code -> class index used as the probe's target.
# Coil appears as either '-' or 'C'; both share class 7. 'P' is not mapped.
SS8_MAPPING = {
    'H': 0,  # alpha-helix
    'E': 1,  # extended strand (beta-ladder / sheet)
    'T': 2,  # hydrogen-bonded turn
    'S': 3,  # bend
    'G': 4,  # 3_10-helix
    'B': 5,  # isolated beta-bridge
    'I': 6,  # pi-helix
    '-': 7,  # coil
    'C': 7,  # coil
}

### Define the probe (the head model placed on top of the embeddings)

Trained with a cross-entropy loss, this single linear layer is equivalent to (multinomial) logistic regression.

In [7]:
class LinearProbeSSP(nn.Module):
    """Per-residue linear read-out producing secondary-structure class logits.

    A single affine map from each embedding vector to ``num_classes`` scores.
    Trained with cross-entropy, this is multinomial logistic regression on the
    (frozen) representations.

    Args:
        input_dim: size of the last dimension of the incoming embeddings.
        num_classes: number of secondary-structure classes (8 for Q8).
    """

    def __init__(self, input_dim, num_classes=8):
        super().__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Map ``(..., input_dim)`` embeddings to ``(..., num_classes)`` logits."""
        logits = self.proj(x)
        return logits

Fun fact: the `collate_fn` is the bridge between the batch as the `DataLoader` delivers it, a list of per-sample tuples `[(x1, y1), (x2, y2), ...]`, and a batch representation as one `X`, `y` pair of tensors. While building these tensors it resolves issues such as varying sequence lengths (`len(x1) != len(x2)`) so that our model can handle the batch: `[(x1, y1), (x2, y2), ...] -> X, y`.

In [8]:
class LinearProbeCollator:
    """Collate ESMStructuralSplitDataset items into tokens plus per-token SS8 targets.

    Each item is a dict with a ``'seq'`` amino-acid string and a DSSP string under
    ``'ssp'`` (or, as a fallback, ``'label'``). Calling the collator returns a dict:

      * ``'input_ids'``: padded token tensor from the alphabet's batch converter, (B, T);
      * ``'labels'``: long tensor (B, T) aligned with the tokens. Residue positions
        hold the SS8 class index; special-token and padding positions hold
        ``ignore_index`` (-100), which ``nn.CrossEntropyLoss(ignore_index=-100)`` skips.

    Returns ``None`` when every item in the batch is ``None``.

    Args:
        alphabet: ESM alphabet. Its ``prepend_bos`` / ``append_eos`` flags (both
            assumed True when absent) decide how many label slots are reserved for
            special tokens before / after the residues, so labels line up with the
            tokenizer instead of assuming a fixed ``[<cls>, seq, <eos>]`` layout.
        ss8_mapping: optional DSSP code -> class index mapping. Defaults to the
            8-state map with '-' and 'C' merged into coil (class 7).
        unknown_class: class index for codes missing from the mapping (default 7).
    """

    DEFAULT_SS8_MAPPING = {
        'H': 0, 'E': 1, 'T': 2, 'S': 3,
        'G': 4, 'B': 5, 'I': 6, '-': 7, 'C': 7,
    }

    def __init__(self, alphabet, ss8_mapping=None, unknown_class=7):
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.ss8_mapping = dict(
            self.DEFAULT_SS8_MAPPING if ss8_mapping is None else ss8_mapping
        )
        self.unknown_class = unknown_class
        self.ignore_index = -100
        # Label slots masked before / after the residues, mirroring the tokenizer.
        self.n_prefix = int(getattr(alphabet, 'prepend_bos', True))
        self.n_suffix = int(getattr(alphabet, 'append_eos', True))

    def _encode_ssp(self, ssp_str, max_len):
        """Encode one DSSP string as a length-``max_len`` long tensor of class indices."""
        residues = [self.ss8_mapping.get(code, self.unknown_class) for code in ssp_str]
        labels = (
            [self.ignore_index] * self.n_prefix
            + residues
            + [self.ignore_index] * self.n_suffix
        )
        # Truncate (in case the batch converter truncated the tokens), then
        # right-pad up to the token width of the batch.
        labels = labels[:max_len]
        labels.extend([self.ignore_index] * (max_len - len(labels)))
        return torch.tensor(labels, dtype=torch.long)

    @staticmethod
    def _get_ssp(item):
        """Return the DSSP label string of a dataset item ('ssp', falling back to 'label')."""
        ssp_str = item.get('ssp')
        if ssp_str is None:
            ssp_str = item.get('label')
        if ssp_str is None:
            raise KeyError("dataset item has neither an 'ssp' nor a 'label' entry")
        return ssp_str

    def __call__(self, batch):
        batch = [item for item in batch if item is not None]
        if not batch:
            # Nothing to collate; the training/validation loops skip `None` batches.
            return None

        raw_inputs = [(str(i), item['seq']) for i, item in enumerate(batch)]
        _, _, tokens = self.batch_converter(raw_inputs)

        # The token tensor's width is the ground truth for the label width.
        max_len = tokens.size(1)
        target_labels = [self._encode_ssp(self._get_ssp(item), max_len) for item in batch]

        return {
            'input_ids': tokens,
            'labels': torch.stack(target_labels),
        }

# Shared collate_fn for the secondary-structure DataLoaders, built from the ESM
# alphabet so the per-residue labels line up with the tokenizer's output.
collator = LinearProbeCollator(alphabet)

# Training loop over every split level and CV fold (trains 15 probes when all five folds are enabled)

In [None]:
# Train one linear SS8 probe per (split level, CV fold) on frozen ESM embeddings
# and record its Q8 accuracy on the held-out partition.
SPLIT_LEVELS = ['superfamily', 'family', 'fold']
FOLDS = ['0']  # ['0', '1', '2', '3', '4'] runs the full five-fold experiment (15 probes)
IGNORE_INDEX = -100  # must match LinearProbeCollator.ignore_index

# split level -> list of per-fold Q8 accuracies. Lists are registered before the
# folds run, so an interrupted run still keeps the folds that finished.
final_results = {}

criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

for split_level in SPLIT_LEVELS:
    print(f"\n{'='*20}\nStarting Split Level: {split_level}\n{'='*20}")

    fold_accuracies = []
    final_results[split_level] = fold_accuracies

    for fold in FOLDS:
        print(f"  > Fold {fold}...")

        # 1. Datasets and loaders for this CV partition
        train_ds = ESMStructuralSplitDataset(
            split_level=split_level,
            cv_partition=fold,
            split='train',
            root_path='./data',
            download=True,
        )
        valid_ds = ESMStructuralSplitDataset(
            split_level=split_level,
            cv_partition=fold,
            split='valid',
            root_path='./data',
            download=True,
        )
        train_loader = DataLoader(
            train_ds, batch_size=BATCH_SIZE, shuffle=True,
            collate_fn=collator, num_workers=0,
        )
        valid_loader = DataLoader(
            valid_ds, batch_size=BATCH_SIZE, shuffle=False,
            collate_fn=collator, num_workers=0,
        )

        # 2. A fresh probe per fold
        probe = LinearProbeSSP(input_dim=EMBED_DIM).to(DEVICE)
        optimizer = optim.Adam(probe.parameters(), lr=LR)

        # 3. Train the probe; the ESM forward runs under no_grad, so only the probe learns
        for epoch in range(NUM_EPOCHS):
            probe.train()
            for batch in tqdm(train_loader, desc=f"    Epoch {epoch+1} Train", leave=False):
                if batch is None:
                    continue

                tokens = batch['input_ids'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)

                with torch.no_grad():
                    results = model(tokens, repr_layers=[LAYER_IDX], return_contacts=False)
                token_embeddings = results["representations"][LAYER_IDX]

                optimizer.zero_grad()
                logits = probe(token_embeddings)
                # Flatten (batch, seq, classes) -> (batch * seq, classes); the class
                # count comes from the probe rather than a hard-coded 8.
                loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                loss.backward()
                optimizer.step()

        # 4. Q8 accuracy on the held-out partition, over labelled residues only
        probe.eval()
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in tqdm(valid_loader, desc=f"    Validating Fold {fold}", leave=False):
                if batch is None:
                    continue

                tokens = batch['input_ids'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)

                results = model(tokens, repr_layers=[LAYER_IDX], return_contacts=False)
                logits = probe(results["representations"][LAYER_IDX])
                batch_preds = logits.argmax(dim=-1)

                # Drop special-token / padding positions before scoring
                mask = labels != IGNORE_INDEX
                all_preds.extend(batch_preds[mask].cpu().numpy())
                all_targets.extend(labels[mask].cpu().numpy())

        acc = accuracy_score(all_targets, all_preds)
        fold_accuracies.append(acc)
        print(f"Fold {fold} Accuracy: {acc:.4f}")


Starting Split Level: superfamily
  > Fold 0...
Files already downloaded and verified
Files already downloaded and verified


                                                           

KeyboardInterrupt: 

In [None]:
# Box plot of per-fold Q8 accuracy for each split level, plus a summary table.
level_names = list(final_results.keys())
level_accs = list(final_results.values())

fig, ax = plt.subplots(figsize=(10, 6))

# patch_artist=True lets us fill the boxes with colour
bplot = ax.boxplot(
    level_accs,
    tick_labels=level_names,
    patch_artist=True,
    medianprops=dict(color="black", linewidth=1.5),
)
for patch, color in zip(bplot['boxes'], ['lightblue', 'lightgreen', 'lightcoral']):
    patch.set_facecolor(color)

# Overlay the individual folds with a small horizontal jitter; the RNG is seeded
# so the figure is identical across re-runs.
jitter_rng = np.random.default_rng(0)
for pos, accs in enumerate(level_accs, start=1):
    ax.plot(jitter_rng.normal(pos, 0.04, size=len(accs)), accs, 'r.', alpha=0.6)

ax.set_title('ESM Linear Probe Q8 Accuracy by Split Level')
ax.set_ylabel('Q8 Accuracy')
ax.set_xlabel('Structural Split Level')
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1.0)  # accuracy range: 0 to 100%

# Summary table; levels without finished folds (e.g. an interrupted run) show n/a
print(f"{'Split Level':<15} | {'Mean Acc':<10} | {'Std Dev':<10}")
print("-" * 40)
for level, accs in final_results.items():
    if accs:
        print(f"{level:<15} | {np.mean(accs):.4f}     | {np.std(accs):.4f}")
    else:
        print(f"{level:<15} | {'n/a':<10} | {'n/a':<10}")

plt.show()

# Hypothesis 2: Linear Binary Map prediction for contact point prediction
Via fivefold cross validation experiment

### Define the probe (the head model placed on top of the embeddings)

Two separate linear projections, combined via a dot product.

In [10]:
class LinearContactProbe(nn.Module):
    """Linear contact probe: two independent projections combined by a dot product.

    Following Rives et al., every residue embedding is projected twice, by two
    *separate* linear maps W1 and W2. The score of residue pair (i, j) is the dot
    product of W1-projected residue i with W2-projected residue j. It is then
    averaged with the (j, i) score so the output is symmetric, as physical
    contacts are (C_ij = C_ji).

    Args:
        input_dim: size of the embedding dimension of the input.
        projection_dim: width of each projection (default 128).

    Input: ``(batch, seq_len, input_dim)`` tensor (``torch.bmm`` needs 3-D inputs).
    Output: ``(batch, seq_len, seq_len)`` symmetric tensor of contact logits.
    """

    def __init__(self, input_dim, projection_dim=128):
        super().__init__()
        self.proj1 = nn.Linear(input_dim, projection_dim)
        self.proj2 = nn.Linear(input_dim, projection_dim)

    def forward(self, x):
        """Return symmetric pairwise contact logits for embeddings ``x``."""
        left, right = self.proj1(x), self.proj2(x)
        pair_scores = torch.bmm(left, right.transpose(1, 2))
        return (pair_scores + pair_scores.transpose(1, 2)) / 2

In [11]:
# --- COLLATOR ---
class ContactProbeCollator:
    """Collate ESMStructuralSplitDataset items into tokens plus padded binary contact maps.

    Each item is expected to be a dict with a ``'seq'`` string of length L and a
    ``'dist'`` (L, L) array of pairwise residue distances. Calling the collator
    returns a dict with

      * ``'input_ids'``: padded token tensor from the alphabet's batch converter, (B, T);
      * ``'contacts'``: float32 tensor (B, T, T) in token coordinates. 1.0 marks a
        contact (distance < ``contact_threshold``) and 0.0 a non-contact. -1.0 fills
        everything outside the residue block (special tokens, padding); the
        training loop masks those entries out of the loss.

    Returns ``None`` when every item in the batch is ``None``.

    Args:
        alphabet: ESM alphabet; ``alphabet.prepend_bos`` (assumed True when absent)
            decides whether residue i sits at token position i + 1 or i.
        contact_threshold: distance cut-off that defines a contact (default 8.0 Angstroms).

    Raises:
        ValueError: if an item's distance map is not (L, L) for its sequence length L.
    """

    IGNORE_VALUE = -1.0  # "no label" sentinel; the training loop masks `contacts != -1`

    def __init__(self, alphabet, contact_threshold=8.0):
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.contact_threshold = contact_threshold
        # Residue i lives at token position i + offset (1 when a BOS/<cls> token is prepended).
        self.offset = int(getattr(alphabet, 'prepend_bos', True))

    def _contact_map(self, item, max_len):
        """Build one (max_len, max_len) padded contact map tensor from a dataset item."""
        seq_len = len(item['seq'])
        dist_map = np.asarray(item['dist'])
        if dist_map.shape != (seq_len, seq_len):
            raise ValueError(
                f"distance map shape {dist_map.shape} does not match sequence length {seq_len}"
            )
        # NOTE(review): NaN distances, if present, compare False and are therefore
        # labelled as non-contacts rather than ignored -- confirm this is intended.
        contacts = (dist_map < self.contact_threshold).astype(np.float32)

        padded = np.full((max_len, max_len), self.IGNORE_VALUE, dtype=np.float32)
        start, stop = self.offset, self.offset + seq_len
        padded[start:stop, start:stop] = contacts
        return torch.from_numpy(padded)

    def __call__(self, batch):
        batch = [item for item in batch if item is not None]
        if not batch:
            # Nothing to collate; the training/validation loops skip `None` batches.
            return None

        raw_inputs = [(str(i), item['seq']) for i, item in enumerate(batch)]
        _, _, tokens = self.batch_converter(raw_inputs)

        # The token tensor's width is the ground truth for the map size.
        max_len = tokens.size(1)
        batch_contacts = [self._contact_map(item, max_len) for item in batch]

        return {
            'input_ids': tokens,
            'contacts': torch.stack(batch_contacts),
        }

# Shared collate_fn for the contact-map DataLoaders, built from the ESM alphabet
# so the padded contact maps line up with the tokenizer's output.
contact_collator = ContactProbeCollator(alphabet)

In [12]:
def compute_precision_at_l(logits, labels, seq_len, min_sep=24):
    """Top-L precision for long-range contacts of a single protein.

    Args:
        logits: square (T, T) tensor of contact scores in token coordinates; the
            residue block is taken to be positions 1..seq_len (index 0 is skipped
            as the prepended special token).
        labels: (T, T) tensor in the same coordinates; 1 = contact, 0 = no contact,
            negative values = unknown / padding (never scored).
        seq_len: number of residues L.
        min_sep: minimum sequence separation j - i for a pair to count as long
            range (default 24).

    Returns:
        float: fraction of true contacts among the top-L scoring long-range pairs,
        or 0.0 when no long-range pair can be scored.

    Every residue pair is scored once (upper triangle). Scanning the full matrix
    would list both (i, j) and (j, i); with a symmetric predictor those tie, so
    "top-L" would really only cover L/2 distinct pairs.
    """
    # Keep only the residue block
    valid_logits = logits[1:seq_len + 1, 1:seq_len + 1]
    valid_labels = labels[1:seq_len + 1, 1:seq_len + 1]

    n_res = valid_logits.shape[0]
    idx = torch.arange(n_res, device=logits.device)
    separation = idx.unsqueeze(0) - idx.unsqueeze(1)  # separation[i, j] == j - i

    # Long-range pairs from the upper triangle whose label is known
    mask = (separation >= min_sep) & (valid_labels >= 0)

    masked_logits = valid_logits[mask]
    masked_labels = valid_labels[mask]

    # Top-L, where L is the number of residues (capped by the number of candidates)
    k = min(n_res, masked_logits.numel())
    if k == 0:
        return 0.0

    _, top_indices = torch.topk(masked_logits, k)
    return (masked_labels[top_indices].sum() / k).item()

In [14]:
# Train one linear contact probe per (split level, CV fold) on frozen ESM
# embeddings and record its mean long-range top-L precision on held-out proteins.
SPLIT_LEVELS = ['superfamily', 'family', 'fold']
FOLDS = ['0']  # ['0', '1', '2', '3', '4'] runs the full five-fold experiment
LONG_RANGE_MIN_SEP = 24  # matches the |i - j| >= 24 long-range cut-off of compute_precision_at_l

# split level -> list of per-fold precisions. Lists are registered before the
# folds run, so an interrupted run still keeps the folds that finished.
contact_results = {}
# Element-wise loss so the -1 (unlabelled) entries can be masked out before averaging
contact_criterion = nn.BCEWithLogitsLoss(reduction='none')

# Special tokens around every sequence, read from the alphabet rather than
# assuming both <cls> and <eos> are present.
n_special_tokens = int(alphabet.prepend_bos) + int(alphabet.append_eos)

for split_level in SPLIT_LEVELS:
    print(f"\n{'='*20}\nStarting Contact Split Level: {split_level}\n{'='*20}")

    fold_precisions = []
    contact_results[split_level] = fold_precisions

    for fold in FOLDS:
        print(f"  > Fold {fold}...")

        # 1. Datasets and loaders for this CV partition
        train_ds = ESMStructuralSplitDataset(split_level=split_level, cv_partition=fold, split='train', root_path='./data', download=True)
        valid_ds = ESMStructuralSplitDataset(split_level=split_level, cv_partition=fold, split='valid', root_path='./data', download=True)

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=contact_collator, num_workers=0)
        valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=contact_collator, num_workers=0)

        # 2. A fresh probe per fold
        probe = LinearContactProbe(input_dim=EMBED_DIM, projection_dim=128).to(DEVICE)
        optimizer = optim.Adam(probe.parameters(), lr=LR)

        # 3. Train; the ESM forward runs under no_grad, so only the probe learns
        for epoch in range(NUM_EPOCHS):
            probe.train()
            for batch in tqdm(train_loader, desc=f"    Epoch {epoch+1} Train", leave=False):
                if batch is None:
                    continue

                tokens = batch['input_ids'].to(DEVICE)
                contacts = batch['contacts'].to(DEVICE)

                with torch.no_grad():
                    results = model(tokens, repr_layers=[LAYER_IDX], return_contacts=False)
                token_embeddings = results["representations"][LAYER_IDX]

                optimizer.zero_grad()
                logits = probe(token_embeddings)

                # Average the loss over labelled residue pairs only
                loss_mask = contacts != -1
                if loss_mask.any():
                    masked_loss = contact_criterion(logits, contacts)[loss_mask].mean()
                    masked_loss.backward()
                    optimizer.step()

        # 4. Validate: mean long-range top-L precision over held-out proteins
        probe.eval()
        precisions = []
        with torch.no_grad():
            for batch in tqdm(valid_loader, desc=f"    Validating Fold {fold}", leave=False):
                if batch is None:
                    continue

                tokens = batch['input_ids'].to(DEVICE)
                contacts = batch['contacts'].to(DEVICE)

                results = model(tokens, repr_layers=[LAYER_IDX], return_contacts=False)
                logits = probe(results["representations"][LAYER_IDX])

                for row in range(tokens.size(0)):
                    # Residue count = non-padding tokens minus the special tokens
                    seq_len = int((tokens[row] != alphabet.padding_idx).sum().item()) - n_special_tokens
                    if seq_len <= LONG_RANGE_MIN_SEP:
                        # No pair reaches the long-range separation, so precision is
                        # undefined; skip instead of averaging in a spurious 0.0.
                        continue
                    precisions.append(compute_precision_at_l(logits[row], contacts[row], seq_len))

        avg_p = float(np.mean(precisions)) if precisions else float('nan')
        fold_precisions.append(avg_p)
        print(f"Fold {fold} Top-L Precision: {avg_p:.4f}")


Starting Contact Split Level: superfamily
  > Fold 0...
Files already downloaded and verified
Files already downloaded and verified


    Epoch 1 Train:   0%|          | 0/3008 [00:00<?, ?it/s]

                                                                     

KeyboardInterrupt: 

In [None]:
# Box plot of per-fold long-range top-L precision per split level, plus a summary table.
level_names = list(contact_results.keys())
level_precisions = list(contact_results.values())

fig, ax = plt.subplots(figsize=(10, 6))

# patch_artist=True lets us fill the boxes with colour
bplot = ax.boxplot(
    level_precisions,
    tick_labels=level_names,
    patch_artist=True,
    medianprops=dict(color="black", linewidth=1.5),
)
for patch, color in zip(bplot['boxes'], ['lightblue', 'lightgreen', 'lightcoral']):
    patch.set_facecolor(color)

# Overlay the individual folds with a small horizontal jitter; the RNG is seeded
# so the figure is identical across re-runs.
jitter_rng = np.random.default_rng(0)
for pos, precs in enumerate(level_precisions, start=1):
    ax.plot(jitter_rng.normal(pos, 0.04, size=len(precs)), precs, 'r.', alpha=0.6)

ax.set_title('ESM Linear Contact Top-L Precision by Split Level')
ax.set_ylabel('Top-L Long-Range Precision')
ax.set_xlabel('Structural Split Level')
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0, 1.0)
plt.show()

# Summary table; levels without finished folds (e.g. an interrupted run) show n/a
print(f"{'Split Level':<15} | {'Mean Prec':<10} | {'Std Dev':<10}")
print("-" * 40)
for level, precs in contact_results.items():
    if precs:
        print(f"{level:<15} | {np.mean(precs):.4f}     | {np.std(precs):.4f}")
    else:
        print(f"{level:<15} | {'n/a':<10} | {'n/a':<10}")