# Linear Model for structure prediction

In [2]:
# Needed to import modules from helpers
import sys
import os
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.cm as cm
import random
import gzip
import urllib.request # TODO: Outsource data download to helper script

import esm
from io import StringIO
from Bio import SeqIO
from Bio import AlignIO

current_dir = os.getcwd()
# Gehe einen Ordner nach oben
project_root = os.path.abspath(os.path.join(current_dir, '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from helpers import helper

from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity

# Load ESM model to GPU

In [3]:
# kleineres Modell 'esm2_t6_8M_UR50D' zum testen 
# verwendet 36-layer Transformer trained on UniParc" (ca. 670 Mio. Parameter ) im Paper.
model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()

if torch.cuda.is_available():
    model = model.cuda()
    print("Modell auf GPU geladen.")

# Get Data

In [4]:
# Download, Parse and Filter SCOP data
helper.download_no_requests()
df_scop = helper.parse_and_filter_scop()

if not df_scop.empty:
    print(f"Total sequences retained: {len(df_scop)}")
    print(df_scop.head())

File already exists. Skipping download.
Parsing and filtering sequences...
Total sequences retained: 12261
  domain_id                                           sequence class fold  \
0   d1dlwa_  slfeqlggqaavqavtaqfyaniqadatvatffngidmpnqtnkta...     a  a.1   
1   d2gkma_  gllsrlrkrepisiydkiggheaievvvedffvrvladdqlsaffs...     a  a.1   
2   d1ngka_  ksfydavggaktfdaivsrfyaqvaedevlrrvypeddlagaeerl...     a  a.1   
3   d2bkma_  eqwqtlyeaiggeetvaklveafyrrvaahpdlrpifpddltetah...     a  a.1   
4   d4i0va_  aslyeklggaaavdlavekfygkvladervnrffvntdmakqkqhq...     a  a.1   

  superfamily   family  
0       a.1.1  a.1.1.1  
1       a.1.1  a.1.1.1  
2       a.1.1  a.1.1.1  
3       a.1.1  a.1.1.1  
4       a.1.1  a.1.1.1  


In [5]:
from sklearn.model_selection import GroupKFold

# Wir erstellen ein Dictionary, das für jedes Level ('family', 'superfamily', 'fold')
# die entsprechenden 5-Fold Indizes speichert.
levels = ['family', 'superfamily', 'fold']
partitions = {}

for level in levels:
    # Initialisiere GroupKFold mit 5 Splits
    gkf = GroupKFold(n_splits=5)
    
    # Die "Gruppen" sind die Labels in der jeweiligen Spalte (z.B. 'a.1.1.1' für family)
    groups = df_scop[level].values
    
    # Erstelle die Liste der (train_index, test_index) Tupel
    # list() materialisiert den Generator, damit wir die Indizes später wiederverwenden können
    partitions[level] = list(gkf.split(df_scop, groups=groups))

print("Partitionierung abgeschlossen.")
print(f"Verfügbare Split-Level: {list(partitions.keys())}")
print(f"Anzahl Folds pro Level: {len(partitions['fold'])}")

Partitionierung abgeschlossen.
Verfügbare Split-Level: ['family', 'superfamily', 'fold']
Anzahl Folds pro Level: 5


In [None]:
# partitions['family'][4] 5. Fold für 'family' Level
# partitions['family'][4][0]  -> train indices
# partitions['family'][4][1]  -> test indices
# These here are the non intersecting train and test indices for the 5th fold of 'family' level
# print(partitions['family'][4][0])
# print(partitions['family'][4][1])

[]


In [None]:
df_scop.iloc[partitions['family'][4][0]] # Train set for 5th fold of 'family' level

Unnamed: 0,domain_id,sequence,class,fold,superfamily,family
5,d1asha_,anktrelcmkslehakvdtsnearqdgidlykhmfenypplrkyfk...,a,a.1,a.1.1,a.1.1.2
6,d2dc3a_,eelseaerkavqamwarlyancedvgvailvrffvnfpsakqyfsq...,a,a.1,a.1.1,a.1.1.2
7,d4hswa_,gfkqdiatirgdlrtyaqdiflaflnkypderryfknyvgksdqel...,a,a.1,a.1.1,a.1.1.2
8,d1ecaa_,lsadqistvqasfdkvkgdpvgilyavfkadpsimakftqfagkdl...,a,a.1,a.1.1,a.1.1.2
9,d1x9fd_,eclvteslkvklqwasafghahervafglelwrdiiddhpeikapf...,a,a.1,a.1.1,a.1.1.2
...,...,...,...,...,...,...
12255,d2ggva_,tdmwiertadiswesdaeitgsservdvrldddgnfqlmndpga,g,g.96,g.96.1,g.96.1.1
12256,d3s2ra_,ginpeirknedkvvdsvvvtelsknitpycrcwrsgtfplcdgscv...,g,g.97,g.97.1,g.97.1.0
12257,d4c3hd_,sattlntpvvihatqlpqhvstdevlqflesfidekeniidsttmn...,g,g.98,g.98.1,g.98.1.1
12258,d3jb9e_,rlrtsrtkrppdgfdeieptliefqdrmrqientmgkgtktemlap...,g,g.99,g.99.1,g.99.1.1


In [31]:
import pandas as pd
from Bio.PDB import PDBList, PDBParser, DSSP
import os

# 1. Hilfsfunktion zum Laden der Struktur und Extrahieren von DSSP
def get_dssp_labels(pdb_id, chain_id, pdb_dir="./pdb_files"):
    """
    Lädt PDB (falls nötig), berechnet DSSP und gibt den 8-Class-String zurück.
    """
    pdbl = PDBList()
    parser = PDBParser(QUIET=True)
    
    # PDB-Datei herunterladen (oder lokalen Pfad nutzen)
    # SCOP IDs wie 'd1asha_' basieren auf PDB ID '1ash'
    ent_filename = pdbl.retrieve_pdb_file(pdb_id, pdir=pdb_dir, file_format="pdb")
    
    try:
        structure = parser.get_structure(pdb_id, ent_filename)
        model = structure[0]
        
        # DSSP berechnen
        # Hinweis: Stelle sicher, dass 'mkdssp' oder 'dssp' im System-Pfad ist
        dssp = DSSP(model, ent_filename)
        
        # DSSP Keys sind Tupel (Chain, ResID). Wir filtern nach unserer Chain.
        # Die Eigenschaft für Sekundärstruktur ist Index 2 in der DSSP-Liste
        sec_structure = []
        
        for key in dssp.keys():
            if key[0] == chain_id:
                ss_code = dssp[key][2]
                # Leere Rückgaben oder '-' werden oft als Coil ('C') gewertet
                if ss_code == '-' or ss_code == ' ':
                    ss_code = 'C' 
                sec_structure.append(ss_code)
                
        return "".join(sec_structure)
        
    except Exception as e:
        print(f"Fehler bei {pdb_id}: {e}")
        return None

# 2. Anwendung auf deinen DataFrame
# Deine Domain ID 'd1asha_' ist im SCOP-Format.
# Konvention: d + PDB_ID (4 chars) + Chain (1 char) + _ (optional)
# Beispiel: d1asha_ -> PDB: 1ash, Chain: A

def parse_scop_id(domain_id):
    # Einfache Heuristik für Standard SCOP IDs
    # d1asha_ -> 1ash, A
    clean_id = domain_id[1:] # Entferne 'd'
    pdb_id = clean_id[:4]
    chain_id = clean_id[4]
    if chain_id == '_': 
        # Manchmal bedeutet _ "keine Chain" oder "Chain A", 
        # in SCOP oft Chain A oder die einzige vorhandene.
        # Hier muss man eventuell aufpassen.
        chain_id = 'A' 
    return pdb_id, chain_id.upper()

# Beispiel-Loop (Vorsicht: Das Herunterladen dauert!)
# Erstelle erst den Ordner für PDBs
os.makedirs("./pdb_files", exist_ok=True)

# Neue Spalte erstellen
df_scop['ss8_label'] = ""

for index, row in df_scop.iterrows():
    dom_id = row['domain_id']
    pdb_id, chain_id = parse_scop_id(dom_id)
    
    print(f"Bearbeite {dom_id} -> PDB: {pdb_id}, Chain: {chain_id}")
    
    ss_string = get_dssp_labels(pdb_id, chain_id)
    
    if ss_string:
        # WICHTIG: Die Länge des DSSP-Strings muss mit der Sequenz übereinstimmen!
        # PDB-Dateien haben oft fehlende Residues. Du musst eventuell alignen.
        # Für den Anfang speichern wir einfach das Ergebnis.
        df_scop.at[index, 'ss8_label'] = ss_string

print(df_scop.head())

Bearbeite d1dlwa_ -> PDB: 1dlw, Chain: A
Downloading PDB structure '1dlw'...
Fehler bei 1dlw: File type must be PDB, mmCIF or DSSP
Bearbeite d2gkma_ -> PDB: 2gkm, Chain: A
Downloading PDB structure '2gkm'...
Fehler bei 2gkm: File type must be PDB, mmCIF or DSSP
Bearbeite d1ngka_ -> PDB: 1ngk, Chain: A
Downloading PDB structure '1ngk'...
Fehler bei 1ngk: File type must be PDB, mmCIF or DSSP
Bearbeite d2bkma_ -> PDB: 2bkm, Chain: A
Downloading PDB structure '2bkm'...
Fehler bei 2bkm: File type must be PDB, mmCIF or DSSP
Bearbeite d4i0va_ -> PDB: 4i0v, Chain: A
Downloading PDB structure '4i0v'...
Desired structure not found or download failed. '4i0v': HTTP Error 403: Forbidden
Fehler bei 4i0v: 'NoneType' object has no attribute 'readlines'
Bearbeite d1asha_ -> PDB: 1ash, Chain: A
Downloading PDB structure '1ash'...
Fehler bei 1ash: File type must be PDB, mmCIF or DSSP
Bearbeite d2dc3a_ -> PDB: 2dc3, Chain: A
Downloading PDB structure '2dc3'...
Fehler bei 2dc3: File type must be PDB, mmCIF

KeyboardInterrupt: 

In [None]:
# Extract the domains_id's and the actual domain sequences from the DataFrame
labels = df_scop['domain_id'].tolist()
seqs = [s.upper() for s in df_scop['sequence'].tolist()]

In [None]:
# 1. Embeddings mit dem TRAINIERTEN Modell holen
print(f"Berechne TRAINIERTE Embeddings für {len(seqs)} Sequenzen...")

# Schritt 1: Hidden Representations holen
token_reps_trained, batch_strs_trained = helper.get_hidden_representations(model, alphabet, labels, seqs)

# Schritt 2: Mean Pooling durchführen
emb_trained = helper.get_protein_embedding(token_reps_trained, batch_strs_trained)

print("Fertig!")
print(f"Erstellte Embeddings: {len(emb_trained)}")

Berechne TRAINIERTE Embeddings für 100 Sequenzen...
Processing 100 sequences in batches of 1...
Fertig!
Erstellte Embeddings: 100


In [None]:
# 2. Get embeddings before pretraining (natürlich ist hier ein Problem, dass wir den seed nicht kennen alleine deshalb werden sich hier Sachen vom original Paper unterscheiden)
print("Calculating UNTRAINED Embeddings...")
untrained_model = helper.randomize_model(model)

if torch.cuda.is_available(): 
    untrained_model = untrained_model.cuda()

# Schritt 1: Calculate final hidden representations
token_reps_untrained, batch_strs_untrained = helper.get_hidden_representations(untrained_model, alphabet, labels, seqs)

# Schritt 2: Calculate Untrained Embeddings
emb_untrained = helper.get_protein_embedding(token_reps_untrained, batch_strs_untrained)

Calculating UNTRAINED Embeddings...
Processing 200 sequences in batches of 1...


### Motivation

One of the oldest assumptions of sequencing biology: The underlying structure of a protein is a hidden variable that influences the patterns observed in sequence data. And vice versa the patterns observed in the sequence data influence the dtructure of a protein.
In short: Structural information is encoded in the sequences.

- secondary structure decides local choice and order of sequences
- tertiary decides over long range choice and order of sequences

Underlying general Hypothesis: Since 3d struture is encoded in the sequences. It is a logical hypothesis that via unsupervised learning the model learns to decode the hidden information about the secondary and tertiary strucure of the protein implicitly. 

In the paper they start by using simple linear models on top of the learned respresentations to see whether or not even simple models can infer about structure using the learned representations. If they are able to do that that would be very impressive.

enabling a direct inspection of the structural content of representations.

By comparing representations of the Transformer before and after pretraining, we can identify the information that emerges as a result of the unsupervised learning

fivefold cross validation experiment to study generalization of structural information at the family, superfamily, and fold level.
-	For each of the three levels, we construct a dataset of 15,297 protein structures using the SCOPe database.

### Hypothesis 1:


### Hypothesis 2:
final hidden representations of a sequence encode information about the family it belongs to.

### Method:

- Get Dataset (Pfam)
- compare the distribution of cosine similarities of representations between pairs of residues that are aligned in the family’s MSA background distribution of cosine similarities between unaligned pairs of residues.
- Compare with distributions befor learning (We need the embeddings befor pretraining (randomize model))


In [None]:
# %%
# 1. Bereinigung: Filterung von Sequenzen, bei denen Länge(Seq) != Länge(DSSP)
# und Mapping der 8 Klassen auf Integers.

valid_indices = []
ss8_mapping = {
    'H': 0, 'B': 1, 'E': 2, 'G': 3, 'I': 4, 'T': 5, 'S': 6, 
    'C': 7, '-': 7, ' ': 7
}

# Liste für die bereinigten Daten
cleaned_data = []

# Wir gehen davon aus, dass token_reps_trained in der gleichen Reihenfolge vorliegt wie df_scop
# Hinweis: Das erste Token (Start-Token) und letzte Token (End-Token) vom ESM-Modell müssen beachtet werden.
# Die helper-Funktion gibt oft schon die gekürzte Repräsentation zurück. Wir prüfen die Längen.

print("Starte Datenbereinigung und Mapping...")

count_mismatch = 0
for idx, (index, row) in enumerate(df_scop.iterrows()):
    dssp_str = row['ss8_label']
    
    # Überspringe leere DSSP Ergebnisse
    if not dssp_str or len(dssp_str) == 0:
        continue
        
    # Hole das Embedding für dieses Protein
    # token_reps_trained ist eine Liste von Tensoren.
    rep = token_reps_trained[idx] # Shape: (Seq_Len, Hidden_Dim)
    
    # Check lengths
    # Manchmal enthalten Reps noch Start/End Tokens, das muss man prüfen.
    # Hier nehmen wir an, rep hat die Länge der Sequenz.
    if rep.shape[0] != len(dssp_str):
        # Versuch: ESM fügt oft Start/End Token hinzu. 
        # Wenn rep = len(seq) + 2, schneiden wir ab.
        if rep.shape[0] == len(dssp_str) + 2:
            rep = rep[1:-1]
        else:
            count_mismatch += 1
            continue

    # Wenn Längen jetzt stimmen, Label encoden
    try:
        label_indices = [ss8_mapping.get(char, 7) for char in dssp_str] # 7 ist fallback für Coil
        label_tensor = torch.tensor(label_indices, dtype=torch.long)
        
        # Speichern: (Embedding_Tensor, Label_Tensor, Original_Index)
        cleaned_data.append({
            'embedding': rep.cpu(), # Verschiebe in CPU RAM um GPU für Training freizuhalten
            'label': label_tensor,
            'df_index': index
        })
        valid_indices.append(index)
        
    except Exception as e:
        print(f"Error at index {index}: {e}")
        continue

print(f"Bereinigung fertig.")
print(f"Valid Samples: {len(cleaned_data)}")
print(f"Verworfene Samples (Length Mismatch): {count_mismatch}")

# Erstelle ein Mapping von df_index zu Listen-Index in cleaned_data für schnellen Zugriff
df_idx_to_list_idx = {item['df_index']: i for i, item in enumerate(cleaned_data)}

In [None]:
# %%
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Einfaches Lineares Modell (Logistic Regression in PyTorch)
class LinearProbe(nn.Module):
    def __init__(self, input_dim, num_classes=8):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        return self.linear(x)

# Dataset Wrapper für effizientes Laden
class ProteinResidueDataset(Dataset):
    def __init__(self, data_list):
        """
        data_list: Liste von Dicts {'embedding': tensor, 'label': tensor}
        Wir flachen hier NICHT alles sofort ab, um Speicher zu sparen, 
        sondern geben pro Item ein Protein zurück.
        Der DataLoader muss einen custom collate_fn nutzen um Batches zu bauen.
        """
        self.data = data_list
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        return self.data[idx]['embedding'], self.data[idx]['label']

def collate_residues(batch):
    # Stapelt alle Residues aus mehreren Proteinen in einen großen Tensor (Batch)
    embeddings = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    
    # Concatenate along dimension 0
    embeddings_cat = torch.cat(embeddings, dim=0)
    labels_cat = torch.cat(labels, dim=0)
    
    return embeddings_cat, labels_cat

def train_linear_probe(train_data, test_data, input_dim, device='cuda', epochs=5, batch_size=32):
    model = LinearProbe(input_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    train_dataset = ProteinResidueDataset(train_data)
    # Shuffle ist wichtig für Training
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_residues)
    
    # Training Loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
    # Evaluation
    model.eval()
    correct = 0
    total = 0
    
    test_dataset = ProteinResidueDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_residues)
    
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            _, predicted = torch.max(outputs.data, 1)
            total += batch_y.size(0)
            correct += (predicted == batch_y).sum().item()
            
    accuracy = 100 * correct / total
    return accuracy

In [None]:
# %%
# Durchführung des Experiments auf dem 'fold' Level (strengster Split)
target_level = 'fold'
folds = partitions[target_level]

accuracies_trained = []
# Optional: Hier könntest du auch accuracies_untrained tracken

input_dim = cleaned_data[0]['embedding'].shape[1] # z.B. 320 oder 768
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Starte 5-Fold Cross Validation auf Level: {target_level}")
print(f"Input Dimension: {input_dim}")

for fold_idx, (train_indices, test_indices) in enumerate(folds):
    print(f"\n--- Fold {fold_idx + 1} / 5 ---")
    
    # 1. Daten selektieren basierend auf den Indizes aus GroupKFold
    # Wir müssen prüfen, ob die Indizes in unseren 'cleaned_data' noch existieren
    train_subset = []
    test_subset = []
    
    for idx in train_indices:
        if idx in df_idx_to_list_idx:
            train_subset.append(cleaned_data[df_idx_to_list_idx[idx]])
            
    for idx in test_indices:
        if idx in df_idx_to_list_idx:
            test_subset.append(cleaned_data[df_idx_to_list_idx[idx]])
            
    print(f"Train samples: {len(train_subset)}, Test samples: {len(test_subset)}")
    
    if len(train_subset) == 0 or len(test_subset) == 0:
        print("Warnung: Leeres Set nach Filterung. Überspringe Fold.")
        continue

    # 2. Trainieren & Evaluieren
    acc = train_linear_probe(train_subset, test_subset, input_dim, device=device, epochs=3)
    accuracies_trained.append(acc)
    print(f"Fold {fold_idx + 1} Accuracy (Q8): {acc:.2f}%")

# %%
# Ergebnis visualisieren
mean_acc = np.mean(accuracies_trained)
std_acc = np.std(accuracies_trained)

print("\nResultate:")
print(f"Mean Q8 Accuracy: {mean_acc:.2f}% (+/- {std_acc:.2f})")

plt.figure(figsize=(8, 6))
plt.bar(['Trained ESM'], [mean_acc], yerr=[std_acc], capsize=10, color='skyblue', alpha=0.7)
plt.ylabel('Q8 Accuracy (%)')
plt.title(f'Secondary Structure Prediction (Linear Probe) - {target_level} split')
plt.ylim(0, 100)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()