In [19]:
%pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [20]:
import numpy as np
import pandas as pd
from Bio import SeqIO
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import h5py
import tqdm
import math
import gc # Garbage collector
import os # For creating directories
import json # For saving results
import time # For timestamping runs
import random # For seeding

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F # For Swish if needed
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.tensorboard import SummaryWriter # For TensorBoard

from sklearn.model_selection import KFold, ParameterSampler # For CV and Random Search
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score

# Building Our Models

## Loading in data from HDF5 files:

In [22]:
class Embedding_retriever(Dataset): 
    def __init__(self, h5_path, protein_keys = None):
        self.h5_path = h5_path
        with h5py.File(self.h5_path, 'r') as f:
            all_keys = list(f['embeddings_folder'].keys())
            all_keys.sort()
            self.protein_keys = protein_keys if protein_keys is not None else all_keys

    def __len__(self): 
        return len(self.protein_keys) 

    def __getitem__(self, index): 
        protein_key = self.protein_keys[index]
        with h5py.File(self.h5_path, 'r') as f:
            protein_group = f['embeddings_folder'][protein_key]
            embeddings = protein_group['embeddings'][:]
            labels = protein_group['labels'][:]
            labels = labels.astype(np.float32) 

        return {
            'name': protein_key,
            'embeddings': torch.tensor(embeddings, dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.float32),
            'length': len(labels) 
        }

In [10]:
## testing that we can retrieve specified protein from HDF5
protein_data = Embedding_retriever('/work/jgg/exp1137-improving-b-cell-antigen-predictions-using-plms/esm2_protein_embeddings.h5', ['3b9k_B'])
protein_data[0]

{'name': '3b9k_B',
 'embeddings': tensor([[-0.0792, -0.0822,  0.0584,  ..., -0.0358,  0.2468,  0.0965],
         [ 0.2719,  0.1316, -0.1275,  ..., -0.0662,  0.0838,  0.0270],
         [ 0.0752, -0.1247, -0.3128,  ..., -0.3670, -0.0709, -0.1302],
         ...,
         [-0.0930,  0.1506,  0.3534,  ..., -0.0738, -0.3277, -0.1105],
         [ 0.0638,  0.1243,  0.2399,  ..., -0.2450, -0.2691,  0.0887],
         [ 0.0787, -0.0101,  0.3330,  ..., -0.0309, -0.1629,  0.1119]]),
 'labels': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
         1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1.,
         1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 

In [26]:
protein_data_esmc = Embedding_retriever('/work/jgg/exp1137-improving-b-cell-antigen-predictions-using-plms/esmc_protein_embeddings.h5', ['5ggv_Y'])
protein_data_esmc[0]

BlockingIOError: [Errno 11] Unable to synchronously open file (unable to lock file, errno = 11, error message = 'Resource temporarily unavailable')

In [18]:
print(protein_data_esmc[0]['embeddings'].shape)


BlockingIOError: [Errno 11] Unable to synchronously open file (unable to lock file, errno = 11, error message = 'Resource temporarily unavailable')

In [18]:
test_data = Embedding_retriever('/work/jgg/exp1137-improving-b-cell-antigen-predictions-using-plms/esm2_test_protein_embeddings.h5', ['7lj4_B'])
test_data[0]

{'name': '7lj4_B',
 'embeddings': tensor([[ 0.0967,  0.0548, -0.0380,  ...,  0.3126,  0.1504, -0.0625],
         [-0.0258, -0.1512,  0.0948,  ...,  0.2948, -0.0251, -0.0958],
         [-0.0322, -0.0087, -0.1973,  ...,  0.2294, -0.1310, -0.1817],
         ...,
         [-0.2181, -0.1279, -0.0083,  ...,  0.0957, -0.0185,  0.0791],
         [-0.2594, -0.0837,  0.0623,  ...,  0.0057, -0.0654, -0.1096],
         [-0.0031,  0.1199,  0.1547,  ...,  0.0080, -0.1935, -0.0970]]),
 'labels': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
         0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 

### Padding Collate Function

In [None]:
def collate_fn(batch):
    embeddings = [item['embeddings'] for item in batch]
    labels = [item['labels'] for item in batch]
    lengths = [item['length'] for item in batch]
    names = [item['name'] for item in batch]

    padded_embeddings = nn.utils.rnn.pad_sequence(embeddings, batch_first=True, padding_value= 0.0)
    padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)

    max_len = padded_embeddings.size(1)
    lengths_tensor = torch.tensor(lengths, dtype=torch.long)
    padding_mask  =torch.arange(max_len)[None, :] >= lengths_tensor[:, None]

    return {
        'names': names,
        'embeddings': padded_embeddings,
        'labels': padded_labels,
        'padding_mask': padding_mask,
        'lengths': lengths_tensor
    }

# Building Our Models

### DEMO MLP MODEL

In [19]:
cls = torch.nn.Sequential(
    torch.nn.Linear(embedding_size_esm2, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid()
)

optimizer = torch.optim.Adam(cls.parameters(), lr=1e-3)
loss_fn = torch.nn.BCELoss(reduction='none')  # Remove the space

first_10_items = list(preprocessed_data.items())[:10]

for epoch in range(5):
    total_loss = 0  # Initialize the loss counter
    
    # Extract sequences and labels first
    sequences = []
    labels = []
    for name, obs in first_10_items:
        sequences.append(obs['sequence'])
        labels.append(obs['labels'])
    
    # Now train on each sequence
    for i in range(len(sequences)):
        sequence = sequences[i]
        label = labels[i]
        
        inputs = tokenizer_650m(sequence, return_tensors="pt")  # Remove **
        with torch.no_grad():
            esm_output = model_650m(**inputs)
        embeddings = esm_output.last_hidden_state

        seq_len = len(sequence)

        predictions = cls(embeddings[0, 1:seq_len+1])
        predictions = predictions.squeeze(-1)

        target = torch.tensor(label, dtype=torch.float32)

        loss = loss_fn(predictions, target)
        mean_loss = loss.mean()

        optimizer.zero_grad()
        mean_loss.backward()
        optimizer.step()

        total_loss += mean_loss.item()

    print(f"epoch {epoch+1}, loss: {total_loss/len(sequences):.4f}")

NameError: name 'embedding_size_esm2' is not defined

In [None]:
def predict_epitopes(sequence, model, tokenizer, esm_model):
    # Tokenize sequence
    inputs = tokenizer(sequence, return_tensors="pt")
    
    # Get ESM embeddings
    with torch.no_grad():
        esm_output = esm_model(**inputs)
    embeddings = esm_output.last_hidden_state
    
    # Get predictions
    predictions = model(embeddings[0, 1:len(sequence)+1]).squeeze(-1)
    
    # Convert to probabilities
    probs = predictions.detach().numpy()
    
    # Convert to binary predictions
    binary_preds = (probs > 0.2).astype(int)
    
    return probs, binary_preds

# Example usage
test_seq = preprocessed_data['7rk1_A']['sequence']
probs, binary_preds = predict_epitopes(test_seq, cls, tokenizer_650m, model_650m)

# Visualize predictions compared to actual labels
actual_labels = preprocessed_data['7rk1_A']['labels']
print(f"Sequence length: {len(test_seq)}")
print(f"outpur {binary_preds}" )
print(f"Predicted epitope positions: {[i for i, p in enumerate(binary_preds) if p == 1]}")
print(f"Actual epitope positions: {[i for i, l in enumerate(actual_labels) if l == 1]}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

def plot_epitope_probabilities(probs, actual_labels=None, sequence=None, threshold=0.5):
    positions = np.arange(1, len(probs) + 1)
    
    # Create a colormap for probabilities
    cmap = plt.cm.get_cmap('RdYlGn_r')
    colors = [cmap(p) for p in probs]
    
    plt.figure(figsize=(14, 6))
    bars = plt.bar(positions, probs, color=colors, width=0.8, alpha=0.7)
    
    # Add threshold line
    plt.axhline(y=threshold, color='black', linestyle='--', alpha=0.7, label=f'Threshold ({threshold})')
    
    # Add actual epitope labels if provided
    if actual_labels is not None:
        for i, label in enumerate(actual_labels):
            if label == 1:
                plt.axvspan(i+0.5, i+1.5, color='blue', alpha=0.2)
    
    # Set x-axis ticks to show position numbers
    if len(positions) > 50:
        # For longer sequences, show every 10th position
        plt.xticks(positions[::10], positions[::10])
    else:
        # For shorter sequences, show all positions
        plt.xticks(positions, positions)
    
    # Add colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    cbar = plt.colorbar(sm)
    cbar.set_label('Probability of being an epitope')
    
    plt.xlabel('Amino Acid Position')
    plt.ylabel('Probability')
    plt.title('Epitope Prediction Probabilities')
    plt.tight_layout()
    
    return plt.gcf()

# Example usage
test_seq = preprocessed_data['7rk1_A']['sequence']
actual_labels = preprocessed_data['7rk1_A']['labels']
probs, _ = predict_epitopes(test_seq, cls, tokenizer_650m, model_650m)

# Plot the first 100 positions for better visibility
plot_window = 100
plot_epitope_probabilities(
    probs[:plot_window], 
    actual_labels[:plot_window], 
    test_seq[:plot_window],
    threshold=0.3
)
plt.show()

## Transformer Model

In [None]:
MAX_LEN = 5000

In [23]:
# --- Positional Encoding --- 
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len = MAX_LEN):
        super().__init__()
            #super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model) #generating empty tensor to later populate with positional encoding values.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # Create a column vector of token positions, which will later be used to calculate sinusoidal encodings 
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # the denominator of PE formula.

        #Slicing of both even and uneven terms (removed if/else function)
        pe[:, 0::2] = torch.sin(position * div_term) # generators tensor using sine function
        pe[:, 1::2] = torch.cos(position * div_term) # generators tensor using sine function

        self.register_buffer('pe', pe) # fixed paramter, not trainable.

    def forward(self, x):
        # apply positional encoding to input.
        x = x + self.pe[:x.size(1), :] # slicing ensures correct dimension and adds positional encoding to input.
        return self.dropout(x) # regularization to input
    
        # OBS!!!: dimension output of PE is now [batch_size, seq_len, embed_dim]

# --- SiLu (Swish) Feed Forward Network ---
class SiLuFFN(nn.Module):
    def __init__(self, embed_dim, ffn_hidden_dim, dropout=0.1):
        super().__init__()

        self.w1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=True) # defines 1st layer
        self.w2 = nn.Linear(ffn_hidden_dim, embed_dim, bias=True) # defines 2nd layer
        self.dropout = nn.Dropout(dropout) # a good ol' dropout layer
        self.activation = F.silu # stores SWISH activation function

    def forward(self ,x):
        hidden = self.w1(x) #passing input through first (w1) layer    
        activated = self.activation(hidden) # applying SWISH    
        dropped = self.dropout(activated) # applying dropout
        output = self.w2(dropped) # pass through second (w2) layer

        return output # returns final processed tensor


# --- Transformer Encoder Layer with SiLU FFN ---
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, nhead, ffn_hidden_dim, dropout=0.1):
        super().__init__()

        #Sub layers
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout, batch_first=True) # 1. MHA
        self.ffn = SiLuFFN(embed_dim, ffn_hidden_dim, dropout) # FFN part of transformer block.T_destination

        # Layer norm & dropout
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim) # will have different weights
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout) # will have different weights

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # MHA applied in encoder
        attention_output, _ = self.self_attn(src, src, src, # these are for the Q, K and V tensors
                                             attn_mask=src_mask,
                                             key_padding_mask=src_key_padding_mask,
                                             need_weights=False)
        
        # Add & Norm (Residual Connection 1) 
        src = src + self.dropout1(attention_output) # We dropout some of our output from the attention layer
        src = self.norm1(src) # and we apply layer normalization
        
        # FF block
        ffn_output = self.ffn(src)

        # Add & Norm (Residual Connection 2) 
        src = src + self.dropout2(ffn_output) # We dropout some of our output from the FFN layer
        src = self.norm2(src) # and we apply layer normalization

        return src # this output goes to decoder block
    


# --- Main Epitope Transformer Model using SiLU ---
class EpitopeTransformer(nn.Module):
    def __init__(self, embed_dim, nhead, num_encoder_layers, ffn_hidden_dim,
                 dropout=0.1, max_len=MAX_LEN):
        
        super().__init__()
        self.embed_dim = embed_dim

        self.pos_encoder = PositionalEncoder(embed_dim, dropout, max_len) # first PE layer

        # Encoder Stack
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, nhead, ffn_hidden_dim, dropout)
            for _ in range(num_encoder_layers)
        ])

        self.norm_final = nn.LayerNorm(embed_dim) # normalized before added to next encoder layer. 
        self.output_layer = nn.Linear(embed_dim, 1) # transformed to 1 dimension
        self.init_weights() # optional, but good practice

    def init_weights(self):
        initrange = 0.1
        self.output_layer.bias.data.zero_()
        self.output_layer.bias.data.uniform_(-initrange, initrange) #we initialize between -0.1 to 0.1. Biased to be 0

    def forward(self, src, src_key_padding_mask=None): #padding mask=None as we're only in encoder layer
        src = self.pos_encoder(src) # we encode input

        output = src # we initialize output value and loop through each encoder layer passing the output from layer to layer.
        for layer in self.layers:
            output = layer(output, src_key_padding_mask=src_key_padding_mask)
        output = self.norm_final(output) # apply final norm
        output_logits = self.output_layer(output) # Apply the final output layer (Prediction Head)
        # Output shape: [batch_size, seq_len, 1]
        
        return output_logits


### Transformer Functions

In [1]:
# --- Training Function ---
def train_epoch(model, dataloader, optimizer, criterion, device, epoch, writer=None):
    model.train()
    total_loss = 0
    num_batches = len(dataloader)
    progress_bar = tqdm.tqdm(dataloader, desc=f'Epoch {epoch+1} Training', leave=False, ncols=100)
    for i, batch in enumerate(progress_bar):
        embeddings = batch['embeddings'].to(device)
        labels = batch['labels'].to(device)
        padding_mask = batch['padding_mask'].to(device)
        optimizer.zero_grad() # resets gradiants from prev batch
        outputs = model(embeddings, src_key_padding_mask=padding_mask)
        outputs = outputs.squeeze(-1)
        active_mask = (labels != -1)
        if active_mask.sum() == 0: continue
        active_logits = outputs[active_mask]
        active_labels = labels[active_mask]
        loss = criterion(active_logits, active_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    if writer: writer.add_scalar('Loss/train_epoch', avg_loss, epoch) # Log average epoch loss
    return avg_loss


# --- Evaluation Function ---      
def evaluate(model, dataloader, criterion, device, epoch, writer=None):
    model.eval()
    total_loss = 0
    all_preds_prob = []
    all_labels_list = []
    num_batches = len(dataloader)
    with torch.no_grad():
        progress_bar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1} Evaluating", leave=False, ncols=100)
        for batch in progress_bar:
            embeddings = batch['embeddings'].to(device)
            labels = batch['labels'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            outputs = model(embeddings,src_key_padding_mask=padding_mask).squeeze(-1)
            active_mask = (labels != -1)
            if active_mask.sum() == 0: continue
            active_logits = outputs[active_mask]
            active_labels = labels[active_mask]
            loss = criterion(active_logits, active_labels)
            total_loss += loss.item()
            probs = torch.sigmoid(active_logits).cpu().numpy() # SIGMOID??? CPU???
            all_preds_prob.extend(probs)
            all_labels_list.extend(active_labels.cpu().numpy())
        
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    precision, recall, f1, auc_roc, auc_pr, auc10 = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    if len(all_labels_list) > 0:
        all_labels_np = np.array(all_labels_list)
        all_preds_prob_np = np.array(all_preds_prob)
        if len(np.unique(all_labels_np)) < 2:
            print("Warning: Only one class present in validation fold/batch.")
            all_preds_binary = (all_preds_prob_np >= 0.5).astype(int)
            precision, recall, f1, _ = precision_recall_fscore_support(all_labels_np, all_preds_binary, average='binary', zero_division=0)
            auc_roc = 0.0
            auc10 = 0.0
            try:
                auc_pr = average_precision_score(all_labels_np, all_preds_prob_np)
            except ValueError:
                auc_pr = 0.0
        else:
            all_preds_binary = (all_preds_prob_np >= 0.5).astype(int)
            precision, recall, f1, _ = precision_recall_fscore_support(all_labels_np, all_preds_binary, average='binary', zero_division=0)
            auc_pr = average_precision_score(all_labels_np, all_preds_prob_np) # den her var lidt mystisk
            try:
                auc_roc = roc_auc_score(all_labels_np, all_preds_prob_np)
                auc10 = roc_auc_score(all_labels_np, all_preds_prob_np, max_fpr=0.1)
            except ValueError:
                print("Warning: ValueError during AUC calculation despite unqie check.")
                auc_roc = 0.0
                auc10 = 0.0

    print(f"Eval Loss: {avg_loss:.4f}, F1: {f1:.4f}, AUC-PR: {auc_pr:.4f}, AUC-ROC: {auc_roc:.4f}, AUC10: {auc10:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    if writer:
        writer.add_scalar('Loss/val', avg_loss, epoch)
        writer.add_scalar('Metrics/F1', f1, epoch)
        writer.add_scalar('Metrics/AUC-PR', auc_pr, epoch)
        writer.add_scalar('Metrics/AUC-ROC', auc_roc, epoch)
        writer.add_scalar('Metrics/AUC10', auc10, epoch)
        writer.add_scalar('Metrics/Precision', precision, epoch)
        writer.add_scalar('Metrics/Recall', recall, epoch)
    return avg_loss, precision, recall, f1, auc_roc, auc_pr, auc10

### ESM2

In [2]:
# --- Configuration ---

# EMBEDDING TYPE
EMBEDDING_TYPE = 'esm2' 


# MAPPING
EMBEDDING_CONFIG = {
    'esm2': {
        'h5_path': 'esm2_protein_embeddings.h5',
        'embed_dim': 1280
    },
    'esmc': {
        'h5_path': 'esmc_protein_embeddings.h5',
        'embed_dim': 960
    }
}


# PATHS AND DATA
config = EMBEDDING_CONFIG[EMBEDDING_TYPE]
H5_FILE_PATH = config['h5_path']
EMBED_DIM = config['embed_dim']

BASE_RUNS_DIR = 'transformer_runs' #overall base dir
RUN_TYPE_DIR = os.path.join(BASE_RUNS_DIR, EMBEDDING_TYPE) # subfolder in dir
MODEL_SAVE_DIR = os.path.join(RUN_TYPE_DIR, 'saved models') #subfolder with saved models
TENSORBOARD_BASE_DIR = os.path.join(RUN_TYPE_DIR, 'tensorboard')
RESULTS_FILE = os.path.join(RUN_TYPE_DIR, f'{EMBEDDING_TYPE}_hyperparam_search_results.json')


# Fixed Training Settings
N_SPLITS = 5 # number of K-Fold splits
N_EPOCHS = 10 # Max epochs per fold
BATCH_SIZE = 8 # !!!maybe adjust for GPU memory!!!!
RANDOM_SEED = 17 #bc 17 is a cool number 8-)


# Model Architecture Base (Embed dim now comes from config)
MAX_LEN = 5000 # max seq len for positional encoding # can be changed


# --- Random Search Space ---
param_dist = {
    'learning_rate': [1e-5, 5e-5, 1e-4, 5e-4], # !!!! come back to this
    'dropout': [0.1, 0.15, 0.2, 0.25],
    'num_encoder_layers': [4, 5, 6], # they had 8 in Attention is all you need
    'nhead': [4, 8],
    'ffn_hidden_dim_factor': [2, 3 ,4]
}
N_SEARCH_ITERATIONS = 10 # number of random combinations to try


# --- Device Setup ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'--- Running Experiment for_ {EMBEDDING_TYPE} ---')
print(f'Using device: {DEVICE}')
print(f'Embeddings Path: {H5_FILE_PATH}')
print(f'Embeddings Dimensin: {EMBED_DIM}')
print(f'Output Directory: {RUN_TYPE_DIR}')


# --- SEEDING ---
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seed(RANDOM_SEED)

# --- Creating Directories ---
os.makedirs(RUN_TYPE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(RUN_TYPE_DIR, exist_ok=True)

NameError: name 'os' is not defined

#### Main execution

In [None]:
# 1. Loading all protein keys (from hfd5 file)
try:
    with h5py.File(H5_FILE_PATH, 'r') as f:
        if 'embeddings_folder' not in f: raise KeyError("Not found")
        all_protein_keys = list(f['embeddings_folder'].keys())
        all_protein_keys.sort()
        all_protein_keys = np.array(all_protein_keys)
    print(f"Loaded {len(all_protein_keys)} keys from {H5_FILE_PATH}")

except FileNotFoundError:
    print("File not found")
    exit()

except KeyError as e:
    print(f"Error loading from HDF {e}")
except Exception as e:
    print(f"Unexpected error loading keys; {e}")
    exit()
    
# Hyperparameter Search Initialization
sampler = ParameterSampler(param_dist, n_iter=N_SEARCH_ITERATIONS, random_state=RANDOM_SEED)
all_trial_results = []
trial_num = 0

print(f"\n--- Starting Hyperparameter Search ({N_SEARCH_ITERATIONS} trials) for {EMBEDDING_TYPE} ---")

for params in sampler:
    trial_num += 1
    print(f"\n----- Trial {trial_num}/{N_SEARCH_ITERATIONS} ({EMBEDDING_TYPE}) -----")
    print(f"Parameters: {params}")

    # Extract params for this trial
    LEARNING_RATE = params['learning_rate']
    DROPOUT = params['dropout']
    NUM_ENCODER_LAYERS = params['num_encoder_layers']
    N_HEAD = params['nhead']
    FFN_HIDDEN_DIM_FACTOR = params['ffn_hidden_dim_factor']
    FFN_HIDDEN_DIM = int(EMBED_DIM * FFN_HIDDEN_DIM_FACTOR)

    # Check nhead validity
    if EMBED_DIM % N_HEAD != 0:
        print(f"Skipping trial: embed_dim ({EMBED_DIM}) not divisible by nhead ({N_HEAD}).")
        continue

    # 3. K-Fold Cross-Validation Loop for this set of parameters
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(all_protein_keys)):
        print(f"\n--- Fold {fold + 1}/{N_SPLITS} ---")
        set_seed(RANDOM_SEED + fold)

        # Create distinct TensorBoard log directory for this fold/trial
        # Group by trial first, then fold
        fold_log_dir = os.path.join(TENSORBOARD_BASE_DIR, f'trial_{trial_num}', f'fold_{fold+1}')
        writer = SummaryWriter(log_dir=fold_log_dir)

        train_keys = all_protein_keys[train_idx]
        val_keys = all_protein_keys[val_idx]

        # Create datasets and dataloaders
        train_dataset = Embedding_retriever(H5_FILE_PATH, protein_keys=list(train_keys))
        val_dataset = Embedding_retriever(H5_FILE_PATH, protein_keys=list(val_keys))
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True if DEVICE == torch.device("cuda") else False)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True if DEVICE == torch.device("cuda") else False)

        # Calculate pos_weight
        print("Calculating pos_weight...")
        num_pos, num_neg = 0, 0
        # Safer iteration to calculate pos_weight
        temp_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
        for batch in tqdm.tqdm(temp_loader, desc="Calculating pos_weight", leave=False, ncols=80):
            labels = batch['labels'] # Labels are already tensors
            active_mask = (labels != -1)
            active_labels = labels[active_mask].numpy() # Calculate on active labels
            num_pos += np.sum(active_labels == 1)
            num_neg += np.sum(active_labels == 0)
        del temp_loader # Free memory

        if num_pos == 0 or num_neg == 0:
            print(f"Warning: Fold {fold+1} - num_pos={num_pos}, num_neg={num_neg}. Using pos_weight=1.0")
            pos_weight = 1.0
        else:
            pos_weight = num_neg / num_pos
        print(f"Fold {fold+1} - pos_weight: {pos_weight:.2f}")
        pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32).to(DEVICE)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) # Good for imbalanced data sets

        # Instantiate NEW model and optimizer
        model = EpitopeTransformer(
            embed_dim=EMBED_DIM,
            nhead=N_HEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            ffn_hidden_dim=FFN_HIDDEN_DIM,
            dropout=DROPOUT
        ).to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        best_val_f1 = -1
        best_fold_metrics = {}

        # 4. Epoch Loop for the Fold
        print(f"Starting training for {N_EPOCHS} epochs...")
        for epoch in range(N_EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE, epoch, writer)
            print(f"Epoch {epoch+1}/{N_EPOCHS} - Train Loss: {train_loss:.4f}")
            val_loss, precision, recall, f1, auc_roc, auc_pr, auc10 = evaluate(model, val_loader, criterion, DEVICE, epoch, writer)

            if f1 > best_val_f1:
                print(f"Epoch {epoch+1} - Val F1 improved ({best_val_f1:.4f} -> {f1:.4f}). Saving model...")
                best_val_f1 = f1
                best_fold_metrics = {
                    'loss': val_loss, 'precision': precision, 'recall': recall,
                    'f1': f1, 'auc_roc': auc_roc, 'auc_pr': auc_pr, 'epoch': epoch+1
                }
                # Save model to the type-specific directory
                model_save_path = os.path.join(MODEL_SAVE_DIR, f'trial_{trial_num}_fold_{fold+1}_best.pth')
                torch.save(model.state_dict(), model_save_path)

        print(f"\nBest Validation Metrics for Fold {fold + 1} (Epoch {best_fold_metrics.get('epoch', 'N/A')}):")
        print(best_fold_metrics)
        if best_fold_metrics: # Only append if a best model was found
             fold_results.append(best_fold_metrics)
        else:
             print(f"Fold {fold+1} did not yield improving metrics.")



        writer.close()
        del model, optimizer, train_loader, val_loader, train_dataset, val_dataset, criterion, pos_weight_tensor
        gc.collect()
        if DEVICE == torch.device("cuda"): torch.cuda.empty_cache()
    # --- End of Fold Loop ---


    # --- Summarize Fold Results for the Current Trial ---
    if fold_results:
        # Calculate average metrics only on folds that produced results
        avg_metrics = {key: np.mean([fold[key] for fold in fold_results if key != 'epoch'])
                       for key in fold_results[0] if key != 'epoch'}
        std_metrics = {key: np.std([fold[key] for fold in fold_results if key != 'epoch'])
                       for key in fold_results[0] if key != 'epoch'}
        print(f"\n----- Trial {trial_num} ({EMBEDDING_TYPE}) Cross-Validation Summary -----")
        print("Average Metrics across folds:")
        for key, value in avg_metrics.items():
            print(f"  Avg {key}: {value:.4f} (+/- {std_metrics[key]:.4f})")
        trial_summary = {'trial_num': trial_num, 'params': params, 'avg_metrics': avg_metrics, 'std_metrics': std_metrics, 'individual_fold_metrics': fold_results}
    else:
        print(f"----- Trial {trial_num} ({EMBEDDING_TYPE}) had no valid fold results -----")
        trial_summary = {'trial_num': trial_num, 'params': params, 'avg_metrics': {}, 'std_metrics': {}, 'individual_fold_metrics': []}

    all_trial_results.append(trial_summary)


    # Save results incrementally
    try:
        with open(RESULTS_FILE, 'w') as f:
            json.dump(all_trial_results, f, indent=4, default=lambda x: float(x) if isinstance(x, (np.float32, np.float64)) else x) # Handle numpy floats
        print(f"Trial {trial_num} results saved to {RESULTS_FILE}")
    except IOError as e:
        print(f"Error saving results to {RESULTS_FILE}: {e}")
    except TypeError as e:
        print(f"Error serializing results to JSON: {e}. Check data types.")

# --- End of Trial Loop ---


# --- Final Hyperparameter Search Summary ---
print(f"\n--- Hyperparameter Search Complete for {EMBEDDING_TYPE} ---")
if all_trial_results:
    # Filter out trials with no avg_metrics before finding max
    valid_trials = [t for t in all_trial_results if t.get('avg_metrics')]
    if valid_trials:
        best_trial = max(valid_trials, key=lambda x: x.get('avg_metrics', {}).get('f1', -1))
        print("\nBest Trial Found:")
        print(f"  Trial Number: {best_trial['trial_num']}")
        print(f"  Parameters: {best_trial['params']}")
        print(f"  Avg F1 Score: {best_trial.get('avg_metrics', {}).get('f1', 'N/A'):.4f}")
        print(f"  Avg AUC-PR: {best_trial.get('avg_metrics', {}).get('auc_pr', 'N/A'):.4f}")
    else:
        print("No trials yielded valid average metrics.")

    print(f"\nFull results saved in: {RESULTS_FILE}")
    print(f"TensorBoard logs in: {TENSORBOARD_BASE_DIR}")
    print(f"Best model checkpoints saved in: {MODEL_SAVE_DIR}")
else:
    print("No successful trials were completed.")

print("\nScript Finished.")