In [16]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print every file available under the read-only Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ephod-dataset/pHopt_data.csv


# # SequenceHacking: LSTM-RLAT for Enzyme pH Prediction
#
# This notebook trains a model combining an LSTM feature extractor with the Residual Light Attention (RLAT) architecture (adapted from EpHod) to predict enzyme optimal pH from amino acid sequences.
#
# **Steps:**
# 1.  **Setup:** Import libraries.
# 2.  **Configuration:** Set hyperparameters and file paths. **Upload your dataset to Kaggle first!**
# 3.  **Module Definitions:** Define LSTM, RLAT utilities, Combined Model, Dataset, Loss, and Training functions.
# 4.  **Training:** Run the main training loop.
# 5.  **Evaluation (Optional):** Load the best model and evaluate on the test set.

In [17]:
# =============================================================================
# 1. Setup Cell: Imports
# =============================================================================
import numpy as np
import pandas as pd
import time
import json
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm # Use notebook tqdm

# Imports potentially needed by trainutils (adjust if needed)
from scipy.stats import spearmanr, pearsonr
from scipy.ndimage import convolve1d, gaussian_filter1d
from sklearn import metrics

# Report the PyTorch build and whether a CUDA device is visible.
print(f"PyTorch Version: {torch.__version__}")
_cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

PyTorch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Name: Tesla P100-PCIE-16GB


In [18]:
# =============================================================================
# 2. Configuration Cell: Set Parameters
# =============================================================================

# --- Data Paths ---
# IMPORTANT: Upload your dataset (e.g., pHopt_data.csv) to Kaggle first.
# Then, find its path under /kaggle/input/your-dataset-name/
# Example: If you uploaded it as "ephod-data", the path might be:
# TARGET_DATA_PATH = '/kaggle/input/ephod-data/pHopt_data.csv'
TARGET_DATA_PATH = '/kaggle/input/ephod-dataset/pHopt_data.csv' # <<< CHANGE THIS to your dataset's path under /kaggle/input/

# --- Model Hyperparameters ---
LSTM_EMB_DIM = 128       # Size of the learned per-residue token embedding fed to the LSTM
LSTM_HIDDEN_DIM = 256    # Hidden size per direction of the BiLSTM (per-residue features are 2 * this wide)
LSTM_LAYERS = 2          # Number of stacked BiLSTM layers
LSTM_DROPOUT = 0.2       # Dropout on the embeddings and (only when LSTM_LAYERS > 1) between LSTM layers
RLAT_KERNEL_SIZE = 7     # Kernel size of the RLAT attention convolutions (odd values keep the sequence length)
RLAT_DROPOUT = 0.3       # Dropout after the pooled BatchNorm and inside each RLAT residual block (tune this)
RLAT_RES_BLOCKS = 4      # Number of residual dense blocks in the RLAT head
RLAT_ACTIVATION = 'elu'  # Residual-block activation: 'relu', 'leaky_relu', 'elu', 'selu' or 'gelu'

# --- Training Parameters ---
LEARNING_RATE = 1e-4       # Initial learning rate (tune this)
L2_REG = 1e-5            # L2 regularization (weight decay)
BATCH_SIZE = 32            # Batch size for training (adjust based on GPU memory)
EPOCHS = 200               # Maximum number of training epochs
SAMPLE_WEIGHT_METHOD = 'LDS_inv_sqrt' # Sample weighting: None/'None', 'bin_inv', 'bin_inv_sqrt', 'LDS_inv', 'LDS_inv_sqrt' or 'LDS_extreme'
REDUCE_LR_PATIENCE = 10      # Patience for ReduceLROnPlateau scheduler
STOP_PATIENCE = 30         # Patience for early stopping

# --- Infrastructure ---
NUM_WORKERS = 2            # Number of worker processes for the DataLoader (2 or 4 is typical on Kaggle)
SAVEDIR = '/kaggle/working/' # Directory for saved models and logs (Kaggle's writable, persisted directory)
MODEL_NAME = 'LSTM_RLAT_v1' # Base name for saved model files
RANDOM_SEED = 42           # Seed passed to the model constructors (they reseed torch's RNG with it)

# Derived config
LSTM_OUTPUT_DIM = LSTM_HIDDEN_DIM * 2  # Per-residue feature width: forward and backward states concatenated

# --- Create save directory ---
# exist_ok makes this idempotent and avoids the check-then-create race of a separate exists() test;
# it still fails loudly if SAVEDIR exists but is a regular file.
os.makedirs(SAVEDIR, exist_ok=True)

In [19]:
# =============================================================================
# 3. LSTM Module Definition Cell (from lstm_module.py)
# =============================================================================

# Define Vocabulary (Example - ensure it matches protein space + special tokens)
DEFAULT_AAS = 'ACDEFGHIKLMNPQRSTVWY'
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
# Token layout: 0 = padding, 1-20 = the standard amino acids in DEFAULT_AAS order, 21 = unknown.
VOCAB = [PAD_TOKEN, *DEFAULT_AAS, UNK_TOKEN]
AA_TO_ID = dict(zip(VOCAB, range(len(VOCAB))))
VOCAB_SIZE = len(VOCAB)
PAD_ID = AA_TO_ID[PAD_TOKEN]

class LSTMFeatureExtractor(nn.Module):
    """Bidirectional LSTM encoder that turns token IDs into per-residue features.

    The output is channel-first, ``[batch_size, 2 * lstm_hidden_dim, seq_len]``, so it
    can be fed directly to Conv1d-based heads.
    """
    def __init__(self,
                 vocab_size=VOCAB_SIZE,
                 embedding_dim=LSTM_EMB_DIM, # Use config
                 lstm_hidden_dim=LSTM_HIDDEN_DIM, # Use config
                 num_lstm_layers=LSTM_LAYERS, # Use config
                 dropout=LSTM_DROPOUT, # Use config
                 random_seed=RANDOM_SEED): # Use config
        super().__init__()
        # Reseed torch's global RNG so the layer initialization below is reproducible.
        _ = torch.manual_seed(random_seed)

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.dropout = dropout
        self.output_dim = lstm_hidden_dim * 2 # BiLSTM concatenates forward and backward states

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=PAD_ID
        )
        self.embed_dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_hidden_dim,
            num_layers=self.num_lstm_layers,
            bidirectional=True,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=self.dropout if self.num_lstm_layers > 1 else 0
        )

    def forward(self, input_ids, lengths=None):
        """Encode a batch of right-padded token sequences.

        Args:
            input_ids (torch.Tensor): Long tensor ``[batch_size, seq_len]`` of token IDs,
                real tokens first and PAD_ID afterwards.
            lengths (torch.Tensor | None): Optional ``[batch_size]`` count of real tokens
                per row. If None, it is inferred as the number of non-PAD_ID tokens.

        Returns:
            torch.Tensor: ``[batch_size, output_dim, seq_len]`` features. Positions at or
            beyond each row's length are zero.
        """
        seq_len = input_ids.size(1)
        if lengths is None:
            # Assumes right padding, as produced by pad_sequences.
            lengths = (input_ids != PAD_ID).sum(dim=1)
        # pack_padded_sequence requires int64 lengths on the CPU and rejects 0;
        # an all-padding row is encoded from its first (PAD) step only.
        lengths = torch.as_tensor(lengths).to(device='cpu', dtype=torch.long).clamp(min=1, max=seq_len)

        embedded = self.embed_dropout(self.embedding(input_ids))
        # embedded: [batch_size, seq_len, embedding_dim]

        # Packing stops both directions at each sequence's true end. Without it the
        # reverse direction starts at the last padded position and runs over all
        # padding before reaching the real residues, so features would depend on
        # how much padding the row received.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq_len
        )
        # lstm_out: [batch_size, seq_len, lstm_hidden_dim * 2]
        return lstm_out.transpose(1, 2)

# --- Helper functions ---
def tokenize_sequence(sequence, aa_to_id_map=AA_TO_ID):
    """Convert an amino-acid sequence string into a list of token IDs, one per character.

    The sequence is upper-cased first. Every character without an entry in
    ``aa_to_id_map`` (non-standard residues such as B/J/O/U/Z/X, '*', whitespace, ...)
    maps to exactly ONE unknown-token ID.

    The previous implementation substituted the literal string '<unk>' into the
    sequence, which turned each non-standard residue into five UNK tokens and
    lengthened the sequence.

    Args:
        sequence (str): amino-acid sequence.
        aa_to_id_map (dict): token -> ID mapping. Its UNK_TOKEN entry is used as the
            fallback ID when present, otherwise the global vocabulary's UNK ID.

    Returns:
        list[int]: token IDs, ``len(sequence)`` long.
    """
    unk_id = aa_to_id_map.get(UNK_TOKEN, AA_TO_ID[UNK_TOKEN])
    return [aa_to_id_map.get(aa, unk_id) for aa in sequence.upper()]

def pad_sequences(sequences_tokenized, max_len=1024, pad_value=PAD_ID):
    """Truncate and right-pad tokenized sequences into one rectangular batch.

    Args:
        sequences_tokenized: iterable of token-ID sequences (lists, tuples or 1-D arrays).
        max_len (int | None): output width. Longer sequences keep only their first
            ``max_len`` tokens. Pass None to pad to the longest sequence instead (no truncation).
        pad_value (int): token ID written into padded positions.

    Returns:
        tuple:
            padded_ids:     torch.long  ``[num_sequences, width]``
            masks:          torch.int32 ``[num_sequences, width]`` (1 = real token, 0 = padding)
            actual_lengths: torch.long  ``[num_sequences]``, tokens kept per sequence
    """
    # list() copies each sequence and normalizes tuples/arrays so ``seq + [pad] * n`` works.
    kept = [list(seq) if max_len is None else list(seq[:max_len]) for seq in sequences_tokenized]
    width = max((len(seq) for seq in kept), default=0) if max_len is None else max_len

    padded_sequences, masks, actual_lengths = [], [], []
    for seq in kept:
        n_real = len(seq)
        n_pad = width - n_real
        padded_sequences.append(seq + [pad_value] * n_pad)
        masks.append([1] * n_real + [0] * n_pad)
        actual_lengths.append(n_real)

    n_rows = len(kept)
    # reshape keeps the documented 2-D shape even when the input is empty.
    return (torch.tensor(padded_sequences, dtype=torch.long).reshape(n_rows, width),
            torch.tensor(masks, dtype=torch.int32).reshape(n_rows, width),
            torch.tensor(actual_lengths, dtype=torch.long))

print("LSTM Module and Helpers Defined.")


LSTM Module and Helpers Defined.


In [20]:
# %% [code]
# =============================================================================
# 4. Train Utils & RLAT Model Definition Cell
# =============================================================================
# --- Contains functions from trainutils.py and necessary RLAT definitions ---
# --- from nn_models.py provided by the user ---

# Re-import dependencies (mostly duplicates of the Cell 1 imports) so this cell is self-contained
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import spearmanr, pearsonr
from scipy.ndimage import convolve1d, gaussian_filter1d
from sklearn import metrics

# --- RLAT Components (Copied from nn_models.py) ---

def torchActivation(activation='elu'):
    '''Build a new torch.nn activation module from its lowercase name.

    Supported names: 'relu', 'leaky_relu', 'elu', 'selu', 'gelu'.
    A fresh module instance is returned on every call.

    Raises:
        ValueError: for any other name.
    '''
    factories = (
        ('relu', nn.ReLU),
        ('leaky_relu', nn.LeakyReLU),
        ('elu', nn.ELU),
        ('selu', nn.SELU),
        ('gelu', nn.GELU),
    )
    for name, factory in factories:
        if activation == name:
            return factory()
    raise ValueError(f"Unsupported activation: {activation}")

class ResidualDense(nn.Module):
    '''Residual block computing ``x + Dropout(Activation(BatchNorm(Linear(x))))``.

    Expects 2-D input ``[batch_size, dim]``: Linear acts on the last axis and
    BatchNorm1d on axis 1, and these coincide only for 2-D tensors.

    Args:
        dim: feature width (input and output).
        dropout: dropout probability on the transformed branch.
        activation: activation name understood by torchActivation.
        random_seed: torch's global RNG is reseeded with this before the Linear layer is created.
    '''
    def __init__(self, dim=2560, dropout=0.1, activation='elu', random_seed=0):
        super().__init__()
        torch.manual_seed(random_seed)
        # Submodules are registered in the same order and under the same names as before,
        # so parameter initialization and state_dict keys are unchanged.
        self.dense = nn.Linear(dim, dim)
        self.batchnorm = nn.BatchNorm1d(dim)
        self.activation = torchActivation(activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        branch = self.dropout(self.activation(self.batchnorm(self.dense(x))))
        return x + branch

class LightAttention(nn.Module):
    '''Convolution model with attention to learn pooled representations from embeddings.

    Two parallel Conv1d layers over the sequence axis produce per-position "values"
    and attention "weights". The pooled output concatenates:
      * a softmax-attention-weighted sum of the values, and
      * a max-pool of the values,
    both restricted to the non-padded positions given by ``mask``.
    '''
    def __init__(self, dim=1280, kernel_size=7, random_seed=0):
        super(LightAttention, self).__init__()
        _ = torch.manual_seed(random_seed)
        # An even kernel cannot be centred: floor padding makes Conv1d return
        # seq_len + 1 positions, which forward() trims back to seq_len.
        if kernel_size % 2 == 0:
            print(f"Warning: Even kernel_size ({kernel_size}) used in LightAttention. Using floor for padding.")
        samepad = kernel_size // 2
        self.values_conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=samepad)
        self.weights_conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=samepad)
        self.softmax = nn.Softmax(dim=-1) # Apply softmax along the sequence length dimension

    def forward(self, x, mask=None):
        '''Pool a batch of per-position features into one vector per sequence.

        Args:
            x: ``[batch_size, dim, seq_len]`` features.
            mask: optional ``[batch_size, seq_len]``; 0/False marks padding.
                Defaults to every position being valid.

        Returns:
            (x_pooled, attention_weights):
                x_pooled: ``[batch_size, 2 * dim]`` (attention sum, then max pool).
                    Rows with no valid position pool to all zeros.
                attention_weights: ``[batch_size, dim, seq_len]``; exactly zero at
                    padded positions of partially padded rows, uniform for rows with
                    no valid position.
        '''
        seq_len = x.shape[-1]
        if mask is None:
            mask = torch.ones(x.shape[0], seq_len, dtype=torch.int32, device=x.device)
        pad = (mask.to(x.device) == 0).unsqueeze(1)  # [batch_size, 1, seq_len]

        # Zero padded inputs so convolutions near a sequence's end see the same zeros
        # as Conv1d's own padding. Otherwise whatever the upstream encoder emitted at
        # padded positions leaks into the last real positions, and the result would
        # depend on how much padding the row received.
        x = x.masked_fill(pad, 0.0)

        # Slicing only changes anything for even kernel sizes (see __init__).
        values = self.values_conv(x)[..., :seq_len]
        weights = self.weights_conv(x)[..., :seq_len]

        # Use the most negative finite value instead of -inf. Softmax over an
        # all -inf row is NaN; a constant finite fill yields uniform weights, which
        # then multiply zeroed values below.
        weights = weights.masked_fill(pad, torch.finfo(weights.dtype).min)
        attention_weights = self.softmax(weights)

        # Attention-weighted sum over valid positions.
        x_sum = torch.sum(values.masked_fill(pad, 0.0) * attention_weights, dim=-1)  # [batch_size, dim]

        # Max pool over valid positions; rows with none give -inf, reported as 0.
        x_max, _ = torch.max(values.masked_fill(pad, -float('inf')), dim=-1)  # [batch_size, dim]
        x_max = torch.where(torch.isinf(x_max), torch.zeros_like(x_max), x_max)

        x_pooled = torch.cat([x_sum, x_max], dim=1)  # [batch_size, 2 * dim]
        return x_pooled, attention_weights

class ResidualLightAttention(nn.Module):
    '''LightAttention pooling, then residual dense blocks, then a linear regression head.

    LightAttention concatenates two pooled views, so every layer after it works on
    ``2 * dim`` features.

    forward(x, mask) takes ``x`` as ``[batch_size, dim, seq_len]`` and an optional
    ``[batch_size, seq_len]`` mask. It returns the list
    ``[prediction [batch_size, 1], final hidden embedding [batch_size, 2 * dim], attention weights]``.
    '''
    def __init__(self, dim=1280, kernel_size=9, dropout=0.5,
                 activation='relu', res_blocks=4, random_seed=0):
        super().__init__()
        torch.manual_seed(random_seed)
        pooled_dim = 2 * dim
        self.light_attention = LightAttention(dim, kernel_size, random_seed)
        self.batchnorm = nn.BatchNorm1d(pooled_dim)
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): ResidualDense reseeds torch's RNG with random_seed in its
        # constructor, so as written every residual block starts from identical weights.
        self.residual_dense = nn.ModuleList(
            [ResidualDense(pooled_dim, dropout, activation, random_seed) for _ in range(res_blocks)]
        )
        self.output = nn.Linear(pooled_dim, 1)

    def forward(self, x, mask=None):
        pooled, weights = self.light_attention(x, mask)
        hidden = self.dropout(self.batchnorm(pooled))
        for block in self.residual_dense:
            hidden = block(hidden)
        return [self.output(hidden), hidden, weights]

print("RLAT Components Defined.")

# --- Utility Functions from trainutils.py ---
# Sample-weighting (label distribution smoothing) and evaluation-metric helpers, adapted from trainutils.py.

def label_distribution_smoothing(y, bins=None, ks=5, sigma=2, normalize=True):
    """
    Return a per-sample density estimate obtained by convolving a symmetric Gaussian
    kernel with the empirical label histogram (label distribution smoothing, LDS).

    Binning: if ``bins`` is None, the label range is split into
    ``ceil(max(y) - min(y)) + 1`` equal-width bins (roughly one per pH unit); otherwise
    into ``int(bins)`` bins (at least one in either case).

    See the paper,
    Yang, Zha, Chen, et al, 2021. Delving into deep imbalanced regression.
    Code adapted from https://github.com/YyzHarry/imbalanced-regression

    Args:
        y: array-like of labels.
        bins: number of histogram bins, or None for the default above.
        ks: kernel size; the window has ``2 * ((ks - 1) // 2) + 1`` taps. ks <= 0 disables smoothing.
        sigma: Gaussian standard deviation, in bins.
        normalize: if True, divide by the smallest density so the rarest sample maps to 1.0.

    Returns:
        ndarray with each sample's smoothed bin density (plus 1e-12); empty for empty input.
    """
    labels = np.asarray(y)
    if labels.size == 0:
        return np.array([])

    lo, hi = np.min(labels), np.max(labels)
    if bins is None:
        n_bins = max(1, int(np.ceil(hi - lo)) + 1)
    else:
        n_bins = max(1, int(bins))

    # np.histogram cannot use a zero-width range, so widen it for constant labels.
    hist_range = (lo - 0.5, hi + 0.5) if lo == hi else (lo, hi)
    counts, edges = np.histogram(labels, range=hist_range, bins=n_bins)

    # Bin of every sample; side='right' on the interior edges reproduces np.histogram's
    # half-open bins (the top edge lands in the last bin).
    sample_bins = np.searchsorted(edges[1:-1], labels, side='right')

    if ks > 0:
        half = (ks - 1) // 2
        impulse = np.zeros(2 * half + 1)
        impulse[half] = 1.0
        window = gaussian_filter1d(impulse, sigma=sigma)
        total = np.sum(window)
        if total > 0:
            window = window / total
        else:
            # Degenerate filter output: fall back to a flat window.
            window = np.ones_like(impulse) / len(impulse)
    else:
        window = np.array([1.0])

    smoothed = convolve1d(np.array(counts, dtype=float), weights=window, mode='constant', cval=0.0)
    # The epsilon keeps every density strictly positive so it can be inverted safely.
    density = np.array([smoothed[b] for b in sample_bins]) + 1e-12

    smallest = np.min(density)
    if normalize:
        if smallest > 0:
            density = density / smallest
        else:
            print("Warning: Minimum KDE value is zero or negative, skipping normalization.")
    return density

def get_sample_weights(ydata, method='bin_inv', bin_borders=[5,9]):
    """
    Return per-sample weights for imbalanced regression, normalized to mean 1.

    Methods:
        None / 'None'  : uniform weights.
        'bin_inv'      : inverse frequency of the np.digitize(ydata, bin_borders) bin of each label.
        'bin_inv_sqrt' : square root of 'bin_inv'.
        'LDS_inv'      : inverse of the label-distribution-smoothed density (100 bins, ks=5, sigma=2).
        'LDS_inv_sqrt' : square root of 'LDS_inv'.
        'LDS_extreme'  : 'LDS_inv', with labels strictly between bin_borders[0] and
                         bin_borders[-1] down-weighted by a factor of 0.5.

    Non-finite weights are set to 0 before normalization. If the resulting mean is
    (near) zero, all weights fall back to 1. Empty input returns an empty array.
    An unknown method fails the assertion below.
    """
    assert method in ('None', 'bin_inv', 'bin_inv_sqrt', 'LDS_inv', 'LDS_inv_sqrt',
                      'LDS_extreme', None)

    labels = np.asarray(ydata)
    if labels.size == 0:
        return np.array([])

    if method in ('bin_inv', 'bin_inv_sqrt'):
        bin_ids = np.digitize(labels, bin_borders)
        classes, counts = np.unique(bin_ids, return_counts=True)
        inverse_freq = dict(zip(classes, 1 / (counts + 1e-9)))
        weights = np.array([inverse_freq.get(b, 1.0) for b in bin_ids])
    elif method in ('LDS_inv', 'LDS_inv_sqrt', 'LDS_extreme'):
        weights = 1.0 / label_distribution_smoothing(labels, bins=100, ks=5, sigma=2)
        if method == 'LDS_extreme':
            is_extreme = np.logical_or(labels <= bin_borders[0], labels >= bin_borders[-1]).astype(float)
            # 1.0 for extreme labels, 0.5 for labels inside the borders.
            weights = weights * (is_extreme * (1 - 0.5) + 0.5)
    else:
        # None / 'None': uniform weighting.
        weights = np.ones(len(labels))

    if method in ('bin_inv_sqrt', 'LDS_inv_sqrt'):
        weights = np.sqrt(np.maximum(weights, 0))

    weights[~np.isfinite(weights)] = 0.0
    mean_weight = np.mean(weights)
    if mean_weight > 1e-9:
        return weights / mean_weight
    print("Warning: Mean of weights is close to zero, using equal weights.")
    return np.ones(len(labels))


def performance(ytrue, ypred, weights=None, bins=[5,9]):
    '''Return a dictionary of performance metrics evaluated on predictions.

    Samples whose true value is NaN are dropped. NaN predictions are replaced by the
    mean of the remaining true values before scoring.

    Args:
        ytrue: array-like of true values.
        ypred: array-like of predictions aligned with ``ytrue``.
        weights: optional per-sample weights used by rmse, r2, mcc, f1score and auc
            (the correlations are unweighted). Non-finite and negative weights count
            as 0. The caller's array is never modified.
        bins: borders passed to ``np.digitize`` to turn values into classes for the
            classification metrics.

    Returns:
        dict of floats with keys 'rho' (Spearman), 'r' (Pearson), 'rmse', 'r2',
        'mcc', 'f1score' (support-weighted) and 'auc' (one-vs-rest mean over the
        classes present in ``ytrue``, computed from hard predicted classes).
        Degenerate inputs (empty, constant, single class) get neutral fallbacks.
    '''
    perf = {}
    ytrue, ypred = np.asarray(ytrue), np.asarray(ypred)

    if ytrue.size == 0 or ypred.size == 0: # Handle empty inputs
        return {'rmse': float('inf'), 'r2': 0.0, 'rho': 0.0, 'r': 0.0, 'mcc': 0.0, 'f1score': 0.0, 'auc': 0.0}

    # Work on a float copy. np.asarray would alias a caller's ndarray, and the
    # in-place clean-up below would then silently rewrite the caller's weights.
    if weights is None:
        weights = np.ones(ytrue.shape, dtype=float)
    else:
        weights = np.array(weights, dtype=float)
    weights[~np.isfinite(weights)] = 0.0
    weights = np.maximum(weights, 0.0)

    # Filter out samples where the true value is NaN.
    valid_idx = ~np.isnan(ytrue)
    ytrue = ytrue[valid_idx]
    ypred = ypred[valid_idx]
    weights = weights[valid_idx]

    if ytrue.size == 0: # Every true value was NaN
        return {'rmse': float('inf'), 'r2': 0.0, 'rho': 0.0, 'r': 0.0, 'mcc': 0.0, 'f1score': 0.0, 'auc': 0.0}

    # Score NaN predictions as if the model had predicted the mean true value.
    ypred = np.nan_to_num(ypred, nan=np.nanmean(ytrue))

    if np.sum(weights) <= 1e-9:
        print("Warning: Sum of weights is zero in performance calculation. Using equal weights.")
        weights = np.ones_like(weights) / len(weights)

    # Correlation (unweighted). Undefined for constant inputs, so it is reported as 0.
    try:
        varies = len(np.unique(ytrue)) > 1 and len(np.unique(ypred)) > 1
        perf['rho'] = float(spearmanr(ytrue, ypred)[0]) if varies else 0.0
        perf['r'] = float(pearsonr(ytrue, ypred)[0]) if varies else 0.0
    except ValueError:
        perf['rho'] = 0.0
        perf['r'] = 0.0

    # Sample-weighted regression metrics
    try:
        perf['rmse'] = float(np.sqrt(metrics.mean_squared_error(ytrue, ypred, sample_weight=weights)))
        perf['r2'] = float(metrics.r2_score(ytrue, ypred, sample_weight=weights))
    except ValueError:
        perf['rmse'] = float('inf')
        perf['r2'] = 0.0

    # Classification performance of binned data
    try:
        ytrue_binned = np.digitize(ytrue, bins)
        ypred_binned = np.digitize(ypred, bins)
        present_classes = sorted(np.unique(ytrue_binned))

        if len(present_classes) > 1: # MCC, F1 and AUC need at least two classes
            perf['mcc'] = float(metrics.matthews_corrcoef(ytrue_binned, ypred_binned, sample_weight=weights))
            f1score = float(metrics.f1_score(ytrue_binned, ypred_binned, sample_weight=weights, average='weighted', zero_division=0))
            # One-vs-rest AUC per class present in ytrue. The "scores" are hard 0/1
            # class memberships, so each AUC reduces to balanced accuracy for that class.
            auc_scores = []
            for cls in present_classes:
                ytrue_cls = (ytrue_binned == cls).astype(int)
                ypred_cls = (ypred_binned == cls).astype(int)
                if len(np.unique(ytrue_cls)) > 1:
                    auc_scores.append(metrics.roc_auc_score(ytrue_cls, ypred_cls, sample_weight=weights))
                else:
                    auc_scores.append(0.5) # Neutral score when only one class is present
            # Cast to float so every metric has the same plain-Python type.
            perf['auc'] = float(np.mean(auc_scores)) if auc_scores else 0.0
            perf['f1score'] = f1score
        else: # Only one class present: metrics are undefined
            perf['mcc'] = 0.0
            perf['f1score'] = 0.0
            perf['auc'] = 0.5
    except ValueError as e:
        print(f"Error calculating classification metrics: {e}")
        perf['mcc'] = 0.0
        perf['f1score'] = 0.0
        perf['auc'] = 0.0

    return perf

print("Training Utilities Defined.")

RLAT Components Defined.
Training Utilities Defined.


In [21]:
# =============================================================================
# 5. Combined Model Definition Cell (from combined_model.py)
# =============================================================================
# Import RLAT from the ephod package IF INSTALLED, otherwise paste definition in Cell 4
# Option 1: If ephod package is installed or accessible in Kaggle env
# try:
#     from ephod.training.nn_models import ResidualLightAttention
#     print("Imported ResidualLightAttention from ephod package.")
# except ImportError:
#     print("Could not import from ephod package. Make sure RLAT is defined in Cell 4.")
#     # Fallback to placeholder if needed - requires pasting the RLAT code in cell 4
#     if 'ResidualLightAttention' not in globals():
#          ResidualLightAttention = # ... Define placeholder or raise error ...

# Option 2: Assume RLAT code was pasted into Cell 4 (RECOMMENDED FOR NOTEBOOK)
# Make sure the class 'ResidualLightAttention' is defined above in Cell 4.
# Fail fast if the RLAT head from Cell 4 has not been defined in this kernel.
if 'ResidualLightAttention' in globals():
    print("Found ResidualLightAttention definition.")
else:
    raise NameError("ResidualLightAttention class definition not found. Paste it into Cell 4.")


class SequenceTopHModel(nn.Module):
    """
    End-to-end pH regressor: a BiLSTM encoder produces per-residue features, and a
    ResidualLightAttention (RLAT) head pools them into one predicted pH per sequence.
    """
    def __init__(self,
                 # LSTM args
                 vocab_size=VOCAB_SIZE, # Use global vocab size
                 lstm_embedding_dim=LSTM_EMB_DIM,
                 lstm_hidden_dim=LSTM_HIDDEN_DIM,
                 num_lstm_layers=LSTM_LAYERS,
                 lstm_dropout=LSTM_DROPOUT,
                 # RLAT args (must match LSTM output)
                 rlat_kernel_size=RLAT_KERNEL_SIZE,
                 rlat_dropout=RLAT_DROPOUT,
                 rlat_activation=RLAT_ACTIVATION,
                 rlat_res_blocks=RLAT_RES_BLOCKS,
                 random_seed=RANDOM_SEED):
        super().__init__()
        torch.manual_seed(random_seed)

        self.lstm_extractor = LSTMFeatureExtractor(
            vocab_size=vocab_size,
            embedding_dim=lstm_embedding_dim,
            lstm_hidden_dim=lstm_hidden_dim,
            num_lstm_layers=num_lstm_layers,
            dropout=lstm_dropout,
            random_seed=random_seed,
        )
        # The head must be sized for the encoder's actual feature width
        # (2 * lstm_hidden_dim for the BiLSTM).
        self.rlat_head = ResidualLightAttention(
            dim=self.lstm_extractor.output_dim,
            kernel_size=rlat_kernel_size,
            dropout=rlat_dropout,
            activation=rlat_activation,
            res_blocks=rlat_res_blocks,
            random_seed=random_seed,
        )

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids (torch.Tensor): token IDs, [batch_size, seq_len].
            attention_mask (torch.Tensor): [batch_size, seq_len], nonzero = real residue, 0 = padding.

        Returns:
            torch.Tensor: predicted pH values, [batch_size].
        """
        features = self.lstm_extractor(input_ids)  # [batch_size, output_dim, seq_len]
        head_outputs = self.rlat_head(features, attention_mask.bool())
        # head_outputs[0] is the prediction of shape [batch_size, 1]; drop the trailing dim.
        return head_outputs[0].squeeze(-1)

    def get_num_params(self):
        """Return the number of trainable (requires_grad) parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

print("Combined LSTM-RLAT Model Defined.")

Found ResidualLightAttention definition.
Combined LSTM-RLAT Model Defined.


In [22]:

# =============================================================================
# 6. Dataset & Collate Function Cell
# =============================================================================
class SequencepHDataset(Dataset):
    """Dataset of (token_ids, pHopt label, sample weight) built from a DataFrame.

    The DataFrame must have 'Sequence' and 'pHopt' columns. 'Accession' is optional;
    placeholder IDs 'Seq_<i>' are used when it is missing. Labels and weights are
    stored as float32, and weights are computed once up front with
    get_sample_weights (or set to 1 when no method is given).
    """
    def __init__(self, dataframe, sample_weight_method=SAMPLE_WEIGHT_METHOD): # Use config
        if not isinstance(dataframe, pd.DataFrame):
            raise ValueError("Input 'dataframe' must be a pandas DataFrame.")
        if any(column not in dataframe.columns for column in ('Sequence', 'pHopt')):
            raise ValueError("DataFrame must contain 'Sequence' and 'pHopt' columns.")

        self.sequences = dataframe['Sequence'].values
        self.labels = dataframe['pHopt'].values.astype(np.float32)
        if 'Accession' in dataframe.columns:
            self.accessions = dataframe['Accession'].values
        else:
            self.accessions = [f"Seq_{i}" for i in range(len(dataframe))]

        wants_weighting = sample_weight_method is not None and sample_weight_method != 'None'
        if wants_weighting:
            print(f"Calculating sample weights using method: {sample_weight_method}")
            self.weights = get_sample_weights(self.labels, method=sample_weight_method).astype(np.float32)
        else:
            print("No sample weighting applied.")
            self.weights = np.ones_like(self.labels, dtype=np.float32)

        print(f"Dataset created with {len(self.sequences)} samples. Weight method: {sample_weight_method}")
        if len(self.sequences) == 0:
            print("Warning: Dataset is empty!")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(token_ids, label, weight)`` for row ``idx``; token_ids is a list of ints."""
        if idx >= len(self.sequences):
            raise IndexError("Index out of bounds")
        seq, label, weight = self.sequences[idx], self.labels[idx], self.weights[idx]

        if not isinstance(seq, str):
            print(f"Warning: Sequence at index {idx} is not a string: {seq}. Attempting conversion.")
            seq = str(seq)

        return tokenize_sequence(seq), label, weight

def collate_fn(batch):
    """Turn a list of dataset items into padded model inputs.

    Args:
        batch: list of ``(token_ids, label, weight)`` tuples; ``None`` items are dropped.

    Returns:
        ``(padded_ids, masks, labels, weights)``: ids and masks padded (or truncated) to
        1024 positions, and labels and weights as float32 tensors. Negative, NaN or
        infinite weights become 0. If no items remain, four ``None`` values are returned.
    """
    samples = [item for item in batch if item is not None]
    if len(samples) == 0:
        return None, None, None, None

    token_ids_list, labels, weights = zip(*samples)
    padded_ids, masks, _lengths = pad_sequences(token_ids_list, max_len=1024, pad_value=PAD_ID)

    labels_tensor = torch.tensor(labels, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32).clamp(min=0.0)
    # clamp leaves NaN and +inf untouched, so zero out every non-finite entry afterwards.
    weights_tensor = weights_tensor.masked_fill(~torch.isfinite(weights_tensor), 0.0)

    return padded_ids, masks, labels_tensor, weights_tensor

print("Dataset and Collate Function Defined.")

Dataset and Collate Function Defined.


In [23]:
# =============================================================================
# 7. Loss Function Cell
# =============================================================================
def weighted_rmse_loss(y_pred, y_true, weight):
    """Weighted root-mean-squared error.

    Computes ``sqrt(mean(weight * (y_pred - y_true) ** 2) + 1e-9)``. The mean
    divides by the number of elements, not by ``weight.sum()``.

    Shapes are aligned before the error is computed, so no accidental
    broadcasting can occur (``[B]`` against ``[B, 1]`` would otherwise yield a
    ``[B, B]`` matrix):
      * ``y_pred`` is reshaped to ``y_true.shape`` when both hold the same number
        of elements. This covers ``[B, 1]`` vs ``[B]``, including ``B == 1``.
      * ``weight`` is reshaped to ``y_true.shape`` when it holds the same number
        of elements. A single-element weight is broadcast as a scalar.

    Raises:
        ValueError: if the element counts of the inputs are incompatible.
    """
    if y_pred.shape != y_true.shape:
        if y_pred.numel() != y_true.numel():
            raise ValueError(f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}")
        y_pred = y_pred.reshape(y_true.shape)

    if weight.shape != y_true.shape:
        if weight.numel() == y_true.numel():
            weight = weight.reshape(y_true.shape)
        elif weight.numel() != 1:
            raise ValueError(f"Shape mismatch: y_pred {y_pred.shape}, weight {weight.shape}")

    loss = torch.mean(((y_pred - y_true) ** 2) * weight)
    # Epsilon keeps sqrt (and its gradient) finite when the loss is exactly 0.
    return torch.sqrt(loss + 1e-9)

print("Loss Function Defined.")

Loss Function Defined.


In [24]:
# =============================================================================
# 8. Train/Validation Loop Cell
# =============================================================================
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Some batches are skipped: no backward pass and no optimizer step. These are
    batches for which ``collate_fn`` returned ``None`` tensors, and batches whose
    loss is NaN or infinite. Back-propagating an infinite loss would corrupt the
    parameters.

    The returned value is averaged only over batches that were actually
    optimized. Skipped batches therefore do not bias the reported loss toward 0.

    Returns:
        float: mean loss over processed batches; ``0.0`` if the dataloader is
        empty or no batch could be processed.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0

    total_loss = 0.0
    processed_batches = 0
    for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Training", leave=False)):
        # collate_fn signals an empty batch with (None, None, None, None)
        if batch_data[0] is None:
            print(f"Skipping empty batch {batch_idx+1}")
            continue

        ids, masks, labels, weights = (t.to(device) for t in batch_data)

        optimizer.zero_grad()
        outputs = model(ids, masks)

        # Align label shape with the model output for the loss computation
        if outputs.shape != labels.shape:
            labels = labels.view(outputs.shape)

        loss = loss_fn(outputs, labels, weights)

        if not torch.isfinite(loss):
            print(f"Warning: non-finite loss detected at batch {batch_idx}. Skipping batch.")
            continue

        loss.backward()
        # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        processed_batches += 1

    if processed_batches == 0:
        print("Warning: No training batch produced a finite loss this epoch.")
        return 0.0
    return total_loss / processed_batches


def validate_epoch(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        tuple: ``(avg_loss, perf_metrics)``.

        * ``avg_loss`` is the mean of the per-batch losses that are not NaN.
          It is ``nan`` if no batch produced a non-NaN loss, so callers' NaN
          checks can fire instead of a broken run showing up as a perfect
          ``0.0``. An empty dataloader returns ``(0.0, {})``.
        * ``perf_metrics`` is ``performance(labels, preds, weights=...)``
          computed over all collected predictions, with evaluation weights from
          ``get_sample_weights(labels, method='bin_inv')``. If that computation
          raises, a dict of default values is returned instead.
    """
    model.eval()
    num_batches = len(dataloader)
    if num_batches == 0:
        print("Warning: Validation dataloader is empty.")
        return 0.0, {}

    total_loss = 0.0
    counted_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc="Validation", leave=False):
            if batch_data[0] is None:  # Empty batch signalled by collate_fn
                continue
            ids, masks, labels, weights = (t.to(device) for t in batch_data)

            outputs = model(ids, masks)
            if outputs.shape != labels.shape:
                labels = labels.view(outputs.shape)

            # Same weighted loss as training, for consistency
            loss = loss_fn(outputs, labels, weights)
            if not torch.isnan(loss):  # Only average over valid loss values
                total_loss += loss.item()
                counted_batches += 1

            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / counted_batches if counted_batches > 0 else float('nan')

    # Evaluation metrics use 'bin_inv' weighting computed from the true labels.
    try:
        all_labels_np = np.array(all_labels)
        all_preds_np = np.array(all_preds)
        eval_weights = get_sample_weights(all_labels_np, method='bin_inv')
        perf_metrics = performance(all_labels_np, all_preds_np, weights=eval_weights)
    except Exception as e:
        print(f"Error during performance calculation: {e}")
        perf_metrics = {'rmse': float('inf'), 'r2': 0.0, 'rho': 0.0, 'r': 0.0, 'mcc': 0.0, 'f1score': 0.0, 'auc': 0.0}

    return avg_loss, perf_metrics


print("Train/Validation Functions Defined.")

Train/Validation Functions Defined.


In [25]:
# # %% [code]
# # =============================================================================
# # 9. Main Training Execution Cell (Verified Saving Logic)
# # =============================================================================

# print("--- Starting LSTM-RLAT Training ---")

# # --- Device Configuration ---
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# # --- Set Random Seed ---
# torch.manual_seed(RANDOM_SEED)
# np.random.seed(RANDOM_SEED)
# if device == torch.device("cuda"):
#     torch.cuda.manual_seed_all(RANDOM_SEED)
#     # Optional: For full reproducibility, may need these, but can slow down training
#     # torch.backends.cudnn.deterministic = True
#     # torch.backends.cudnn.benchmark = False

# # --- Load Data ---
# try:
#     print(f"Loading data from: {TARGET_DATA_PATH}")
#     if not os.path.exists(TARGET_DATA_PATH):
#          raise FileNotFoundError(f"Data file not found at {TARGET_DATA_PATH}. Please check path and ensure dataset is uploaded.")

#     full_df = pd.read_csv(TARGET_DATA_PATH, index_col=None) # Adjust index_col if needed
#     if 'Split' not in full_df.columns:
#          raise ValueError("CSV file must contain a 'Split' column with values like 'Training', 'Validation', 'Testing'.")

#     # Filter for train/validation splits
#     train_df = full_df[full_df['Split'] == 'Training'].reset_index(drop=True)
#     val_df = full_df[full_df['Split'] == 'Validation'].reset_index(drop=True)
#     test_df = full_df[full_df['Split'] == 'Testing'].reset_index(drop=True) # Load test data for final eval size check

#     print(f"Loaded {len(full_df)} total entries.")
#     print(f"Training set size: {len(train_df)}")
#     print(f"Validation set size: {len(val_df)}")
#     print(f"Test set size: {len(test_df)}")

#     if len(train_df) == 0 or len(val_df) == 0:
#          raise ValueError("Training or Validation split is empty. Check 'Split' column in CSV.")

# except FileNotFoundError as e:
#     print(f"ERROR: {e}")
#     print("Please ensure your data is uploaded to Kaggle and the TARGET_DATA_PATH variable is set correctly.")
#     # Stop execution if data isn't loaded
#     # In a notebook, you might just print the error and let the user fix it.
#     # For automatic runs, use sys.exit(1)
#     sys.exit(1) # Or comment out to allow fixing path and re-running cell
# except Exception as e:
#     print(f"An unexpected error occurred loading data: {e}")
#     sys.exit(1) # Or comment out

# # --- Create Datasets and Dataloaders ---
# try:
#     train_dataset = SequencepHDataset(train_df, sample_weight_method=SAMPLE_WEIGHT_METHOD)
#     # No sample weighting needed for calculating validation loss itself,
#     # but eval metrics will use weighted calculation inside validate_epoch
#     val_dataset = SequencepHDataset(val_df, sample_weight_method='None')

#     if len(train_dataset) == 0 or len(val_dataset) == 0:
#         raise ValueError("Dataset creation resulted in empty dataset(s).")

#     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True) # drop_last=True can help with batchnorm issues on last batch
#     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)

#     print(f"Train DataLoader: {len(train_loader)} batches")
#     print(f"Validation DataLoader: {len(val_loader)} batches")

# except Exception as e:
#     print(f"Error creating Datasets or DataLoaders: {e}")
#     sys.exit(1) # Or comment out


# # --- Build Model ---
# model = SequenceTopHModel(
#     vocab_size=VOCAB_SIZE,
#     lstm_embedding_dim=LSTM_EMB_DIM,
#     lstm_hidden_dim=LSTM_HIDDEN_DIM,
#     num_lstm_layers=LSTM_LAYERS,
#     lstm_dropout=LSTM_DROPOUT,
#     rlat_kernel_size=RLAT_KERNEL_SIZE,
#     rlat_dropout=RLAT_DROPOUT,
#     rlat_res_blocks=RLAT_RES_BLOCKS,
#     rlat_activation=RLAT_ACTIVATION,
#     random_seed=RANDOM_SEED
# ).to(device)
# print(f"Model created with {model.get_num_params():,} parameters.")
# # Optional: Print model summary
# # print(model)


# # --- Optimizer and Loss ---
# optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=L2_REG)
# loss_fn = weighted_rmse_loss # Use the function defined earlier
# # Scheduler monitors validation loss
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=REDUCE_LR_PATIENCE, verbose=True, min_lr=1e-7) # Added min_lr


# # --- Training Loop Variables ---
# best_val_loss = float('inf') # Initialize best loss to infinity
# epochs_no_improve = 0
# training_log = [] # Store epoch results

# # --- Training Loop ---
# print("\n--- Starting Training Loop ---")
# start_time_total = time.time()

# # Initialize epoch variable outside loop for final message
# epoch = 0

# for epoch in range(1, EPOCHS + 1):
#     start_time_epoch = time.time()
#     print(f"\n--- Epoch {epoch}/{EPOCHS} ---")

#     # Ensure model is in training mode
#     model.train()
#     train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)

#     # Ensure model is in evaluation mode for validation
#     model.eval()
#     # Get validation loss AND metrics dictionary
#     current_val_loss, val_metrics = validate_epoch(model, val_loader, loss_fn, device)

#     # Handle potential NaN in validation loss before using it
#     if np.isnan(current_val_loss):
#         print("Warning: Validation loss is NaN. Skipping scheduler step and model saving for this epoch.")
#         # Optionally stop training if validation becomes unstable
#         # epochs_no_improve += 1 # Consider if NaN counts as non-improvement
#     else:
#         # Scheduler steps based on the valid validation loss
#         scheduler.step(current_val_loss)

#         # --- Save model and handle early stopping BASED ON VALIDATION LOSS ---
#         if current_val_loss < best_val_loss:
#             print(f"  Validation loss decreased ({best_val_loss:.4f} --> {current_val_loss:.4f}). Saving model...")
#             best_val_loss = current_val_loss # Update best loss
#             epochs_no_improve = 0
#             # Save the best model
#             save_path = os.path.join(SAVEDIR, f"{MODEL_NAME}_best.pt")
#             # Ensure model is on CPU before saving to avoid potential GPU memory issues during saving
#             # model.to('cpu') # Move model to CPU temporarily
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'best_val_loss': best_val_loss, # Save the best loss value
#                 'val_metrics_at_best': val_metrics, # Optionally save metrics dict at best epoch
#                 'config': { # Save config used for this run
#                      'lstm_emb_dim': LSTM_EMB_DIM, 'lstm_hidden_dim': LSTM_HIDDEN_DIM,
#                      'lstm_layers': LSTM_LAYERS, 'lstm_dropout': LSTM_DROPOUT,
#                      'rlat_kernel_size': RLAT_KERNEL_SIZE, 'rlat_dropout': RLAT_DROPOUT,
#                      'rlat_res_blocks': RLAT_RES_BLOCKS, 'rlat_activation': RLAT_ACTIVATION,
#                      'learning_rate': LEARNING_RATE, 'l2_reg': L2_REG, 'batch_size': BATCH_SIZE,
#                      'sample_weight_method': SAMPLE_WEIGHT_METHOD, 'random_seed': RANDOM_SEED
#                 }
#             }, save_path)
#             # model.to(device) # Move model back to original device
#         else:
#             epochs_no_improve += 1
#             print(f"  Validation loss did not improve for {epochs_no_improve} epoch(s). Best: {best_val_loss:.4f}")

#     # Print Epoch Summary (regardless of NaN status, report metrics if available)
#     epoch_duration = time.time() - start_time_epoch
#     print(f"Epoch {epoch} Summary:")
#     print(f"  Train Loss: {train_loss:.4f}")
#     print(f"  Val Loss:   {current_val_loss:.4f}") # Report current loss, even if NaN
#     print(f"  Val RMSE:   {val_metrics.get('rmse', float('nan')):.4f}")
#     print(f"  Val R2:     {val_metrics.get('r2', float('nan')):.4f}")
#     print(f"  Val Rho:    {val_metrics.get('rho', float('nan')):.4f}")
#     print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
#     print(f"  Duration:   {epoch_duration:.2f}s")

#     # Log results
#     log_entry = {
#         'epoch': epoch,
#         'train_loss': train_loss,
#         'val_loss': current_val_loss, # Log current val loss
#         'val_rmse': val_metrics.get('rmse', float('nan')),
#         'val_r2': val_metrics.get('r2', float('nan')),
#         'val_rho': val_metrics.get('rho', float('nan')),
#         'val_r': val_metrics.get('r', float('nan')),
#         'val_mcc': val_metrics.get('mcc', float('nan')),
#         'val_f1score': val_metrics.get('f1score', float('nan')),
#         'val_auc': val_metrics.get('auc', float('nan')),
#         'learning_rate': optimizer.param_groups[0]['lr'],
#         'duration': epoch_duration
#     }
#     training_log.append(log_entry)


#     # Check for early stopping
#     if epochs_no_improve >= STOP_PATIENCE:
#         print(f"\nEarly stopping triggered after {STOP_PATIENCE} epochs without improvement on validation loss.")
#         break
#     # Optional: Stop if validation loss becomes NaN
#     # if np.isnan(current_val_loss):
#     #    print("\nStopping training due to NaN validation loss.")
#     #    break


# # --- End of Training ---
# total_duration = time.time() - start_time_total
# print(f"\n--- Training Finished ---")
# # Ensure epoch reports the actual last epoch number if early stopping occurred
# print(f"Training finished after {epoch} epochs.")
# print(f"Total training time: {total_duration/60:.2f} minutes.")
# print(f"Best validation loss achieved: {best_val_loss:.4f}")

# # Save the final training log
# log_df = pd.DataFrame(training_log)
# log_path = os.path.join(SAVEDIR, f"{MODEL_NAME}_training_log.csv")
# log_df.to_csv(log_path, index=False)
# print(f"Training log saved to {log_path}")

# print("\n --- LSTM-RLAT Training Script Completed ---")

In [26]:
# %% [code]
# =============================================================================
# 9. Optuna Hyperparameter Optimization Setup (Corrected & Tuned Ranges)
# =============================================================================
!pip install optuna -q # Install optuna quietly
import optuna
import gc # Garbage collector
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import os
import sys
from tqdm.notebook import tqdm # Ensure notebook tqdm is used

# --- Ensure Previous Cells (1-8) containing necessary definitions are executed ---
# Includes: LSTM Module, RLAT Components, Combined Model, Dataset, Loss, Train/Val Funcs

print("--- Setting up Optuna Hyperparameter Optimization (Tuned Ranges) ---")

# --- Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Load Data ONCE outside the objective function ---
# (Assuming train_df, val_df defined from previous cell execution)
if 'train_df' not in globals() or 'val_df' not in globals():
    print("ERROR: train_df or val_df not defined. Please execute data loading cell first.")
    # Add data loading code here again if necessary, or stop
    import sys
    sys.exit(1)
else:
    print(f"Using pre-loaded Training ({len(train_df)}) and Validation ({len(val_df)}) data.")


# --- Define the Objective Function for Optuna ---

def objective(trial):
    """Optuna objective: train one LSTM-RLAT configuration and return its best val loss.

    Hyperparameters are sampled from ``trial``. The model is trained for at most
    ``HPO_EPOCHS`` epochs, with per-trial early stopping (``STOP_PATIENCE``) and
    Optuna pruning via ``trial.report`` / ``trial.should_prune``. The
    module-level ``train_df`` and ``val_df`` are read.

    Returns:
        float: the best validation loss seen during the trial, or ``inf`` if
        the data loaders or the model could not be built.

    Raises:
        optuna.TrialPruned: when the validation loss is NaN/inf or the pruner fires.

    The model, optimizer, scheduler, loaders and datasets are released in a
    ``finally`` block, followed by ``gc.collect()`` and a CUDA cache flush. This
    cleanup also runs when training raises (e.g. CUDA out-of-memory).
    """
    global train_df, val_df  # Splits loaded once at the top of this cell

    # --- 1. Suggest Hyperparameters (with adjusted ranges) ---
    cfg = {
        # LSTM Args
        'lstm_emb_dim': trial.suggest_categorical('lstm_emb_dim', [128, 256]),
        'lstm_hidden_dim': trial.suggest_categorical('lstm_hidden_dim', [128, 256]),
        'lstm_layers': trial.suggest_int('lstm_layers', 1, 2),
        'lstm_dropout': trial.suggest_float('lstm_dropout', 0.1, 0.4, step=0.1),
        # RLAT Args
        'rlat_kernel_size': trial.suggest_categorical('rlat_kernel_size', [5, 7, 9]),
        'rlat_dropout': trial.suggest_float('rlat_dropout', 0.1, 0.5, step=0.1),
        'rlat_res_blocks': trial.suggest_int('rlat_res_blocks', 2, 4),
        'rlat_activation': trial.suggest_categorical('rlat_activation', ['relu', 'elu', 'gelu']),
        # Training Args
        'learning_rate': trial.suggest_float('learning_rate', 5e-5, 8e-4, log=True),
        'l2_reg': trial.suggest_float('l2_reg', 1e-6, 5e-5, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [32, 64]),
        'sample_weight_method': trial.suggest_categorical('sample_weight_method', ['None', 'bin_inv_sqrt', 'LDS_inv_sqrt']),
        'optimizer_name': trial.suggest_categorical('optimizer_name', ['Adam', 'AdamW']),
    }
    print(f"\n--- Starting Trial {trial.number} ---")
    print(f"Parameters: {cfg}")

    # --- 2. Seed: same base seed in every trial so configurations compare fairly ---
    trial_seed = RANDOM_SEED
    torch.manual_seed(trial_seed)
    np.random.seed(trial_seed)
    if device == torch.device("cuda"):
        torch.cuda.manual_seed_all(trial_seed)

    # Pre-bind every heavyweight local so the `finally` cleanup can always `del` them.
    current_train_dataset = current_val_dataset = None
    train_loader = val_loader = None
    model = optimizer = scheduler = None
    try:
        try:
            # Training data uses the suggested weighting; validation stays unweighted
            # so validation losses are comparable across trials.
            current_train_dataset = SequencepHDataset(train_df, sample_weight_method=cfg['sample_weight_method'])
            current_val_dataset = SequencepHDataset(val_df, sample_weight_method='None')
            train_loader = DataLoader(current_train_dataset, batch_size=cfg['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
            val_loader = DataLoader(current_val_dataset, batch_size=cfg['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)
            print(f"Trial {trial.number}: Train batches={len(train_loader)}, Val batches={len(val_loader)}")
        except Exception as e:
            print(f"Trial {trial.number} failed during data loading: {e}")
            return float('inf')  # Indicate failure

        try:
            model = SequenceTopHModel(
                vocab_size=VOCAB_SIZE, lstm_embedding_dim=cfg['lstm_emb_dim'], lstm_hidden_dim=cfg['lstm_hidden_dim'],
                num_lstm_layers=cfg['lstm_layers'], lstm_dropout=cfg['lstm_dropout'], rlat_kernel_size=cfg['rlat_kernel_size'],
                rlat_dropout=cfg['rlat_dropout'], rlat_res_blocks=cfg['rlat_res_blocks'], rlat_activation=cfg['rlat_activation'],
                random_seed=trial_seed
            ).to(device)
            print(f"Trial {trial.number}: Model params={model.get_num_params():,}")
        except Exception as e:
            print(f"Trial {trial.number} failed during model creation: {e}")
            return float('inf')  # Indicate failure

        optimizer_cls = optim.Adam if cfg['optimizer_name'] == 'Adam' else optim.AdamW
        optimizer = optimizer_cls(model.parameters(), lr=cfg['learning_rate'], weight_decay=cfg['l2_reg'])

        loss_fn = weighted_rmse_loss
        # `verbose` is omitted: it defaulted to False and is deprecated in recent PyTorch.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-7)

        # --- 3. Training Loop for this Trial (with reduced HPO epochs) ---
        best_trial_val_loss = float('inf')  # Best loss *for this trial*
        epochs_no_improve = 0
        trial_last_epoch = 0
        HPO_EPOCHS = 60  # Max epochs per HPO trial (e.g. 50-75) depending on the time budget

        print(f"Trial {trial.number}: Starting training for max {HPO_EPOCHS} epochs...")
        for epoch in range(1, HPO_EPOCHS + 1):
            trial_last_epoch = epoch
            epoch_start_time = time.time()

            # train_one_epoch / validate_epoch switch the model to train/eval mode themselves
            train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
            current_val_loss, val_metrics = validate_epoch(model, val_loader, loss_fn, device)

            epoch_duration = time.time() - epoch_start_time
            if epoch % 10 == 0 or epoch == 1:  # Print minimally during HPO
                print(f"  Epoch {epoch}/{HPO_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {current_val_loss:.4f} | Val RMSE: {val_metrics.get('rmse', float('nan')):.4f} | Time: {epoch_duration:.1f}s")

            if not np.isfinite(current_val_loss):
                print(f"Trial {trial.number} - Epoch {epoch}: Invalid validation loss ({current_val_loss}). Pruning.")
                raise optuna.TrialPruned()

            if current_val_loss < best_trial_val_loss:
                best_trial_val_loss = current_val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            scheduler.step(current_val_loss)

            # --- Optuna Pruning Check ---
            trial.report(current_val_loss, epoch)
            if trial.should_prune():
                print(f"Trial {trial.number} pruned at epoch {epoch} with val_loss {current_val_loss:.4f}.")
                raise optuna.TrialPruned()

            # --- Early Stopping for this Trial (global STOP_PATIENCE) ---
            if epochs_no_improve >= STOP_PATIENCE:
                print(f"Trial {trial.number} early stopped at epoch {epoch} due to lack of improvement (best loss: {best_trial_val_loss:.4f}).")
                break

        # --- 4. Return the best validation LOSS achieved during this trial ---
        print(f"Trial {trial.number} completed after {trial_last_epoch} epochs. Best val_loss: {best_trial_val_loss:.4f}")
        return best_trial_val_loss
    finally:
        # Drop references (the scheduler also holds the optimizer) before collecting,
        # so GPU memory is released on success, pruning and unexpected errors alike.
        del model, optimizer, scheduler, train_loader, val_loader, current_train_dataset, current_val_dataset
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# --- Run the Optuna Study ---
import traceback
from collections import Counter

N_TRIALS = 30  # Reduced number of trials initially to test stability, increase later (e.g., 50)
STUDY_NAME = f"{MODEL_NAME}_HPO_v2"  # New study name

# --- Start a NEW study ---
# If you interrupted a previous one, it's often cleaner to start fresh with adjusted ranges
study = optuna.create_study(
    direction='minimize',  # Minimize validation loss
    study_name=STUDY_NAME,
    sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
    # Median pruning: only after 5 startup trials and 15 warm-up epochs, checked every 3 epochs
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=15, interval_steps=3),
)

print(f"\nStarting Optuna optimization with {N_TRIALS} trials...")
try:
    study.optimize(
        objective,
        n_trials=N_TRIALS,
        timeout=None,  # e.g., 6 * 3600 for 6 hours
        gc_after_trial=True,  # Helps manage memory
    )
except KeyboardInterrupt:
    print("Optimization stopped manually.")
except Exception as e:
    print(f"An error occurred during Optuna optimization: {e}")
    traceback.print_exc()


# --- Optimization Finished ---
print("\n--- Optuna Optimization Finished ---")
# study.trials also contains pruned, failed and still-running trials, so report them by state.
state_counts = Counter(t.state.name for t in study.trials)
print(f"Number of trials: {len(study.trials)} "
      f"(complete={state_counts.get('COMPLETE', 0)}, "
      f"pruned={state_counts.get('PRUNED', 0)}, "
      f"failed={state_counts.get('FAIL', 0)})")

try:
    best_trial = study.best_trial
    print(f"Best trial number: {best_trial.number}")
    print(f"Best validation objective value (loss): {best_trial.value:.4f}")
    print("Best hyperparameters:")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")
    BEST_PARAMS = best_trial.params
    print("\nBest parameters stored in BEST_PARAMS dictionary.")
except ValueError:  # Raised by study.best_trial when no trial completed
    print("No successful trials completed. Cannot determine best parameters.")
    BEST_PARAMS = None

# Save study results (make sure the output directory exists first)
os.makedirs(SAVEDIR, exist_ok=True)
study_results_df = study.trials_dataframe()
study_results_path = os.path.join(SAVEDIR, f"{STUDY_NAME}_results.csv")
study_results_df.to_csv(study_results_path, index=False)
print(f"\nStudy results saved to {study_results_path}")


# --- Proceed to Cell 10 (Retraining with BEST_PARAMS if available) ---
if BEST_PARAMS is None:
    print("\nSkipping final retraining as no best parameters were found.")
else:
    print("\nProceeding to final retraining using BEST_PARAMS (in Cell 10).")

--- Setting up Optuna Hyperparameter Optimization (Tuned Ranges) ---
Using device: cuda
ERROR: train_df or val_df not defined. Please execute data loading cell first.


SystemExit: 1

In [None]:
# Inspect the Optuna study: the per-trial table, then the best completed trial (if any).
if 'study' not in globals():
    print("No Optuna study in this session; run the optimization cell first.")
else:
    display(study.trials_dataframe())
    print("\nBest Trial So Far:")
    try:
        print(study.best_trial)
    except ValueError:  # study.best_trial raises when no trial has completed
        print("No trial has completed yet.")

In [None]:
# %% [code]
# =============================================================================
# 10. Retrain Model with Best Hyperparameters & Evaluate on Test Set
# =============================================================================
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import os
import sys
import gc # Import garbage collector

# --- Ensure Previous Cells containing necessary definitions are executed ---
# Requires: BEST_PARAMS dict, train_df, val_df, test_df pandas DataFrames,
#           SequenceTopHModel, SequencepHDataset, collate_fn, weighted_rmse_loss,
#           train_one_epoch, validate_epoch, VOCAB_SIZE, PAD_ID, SAVEDIR,
#           MODEL_NAME, RANDOM_SEED, NUM_WORKERS, EPOCHS, REDUCE_LR_PATIENCE,
#           STOP_PATIENCE
#
# Flow: retrain on train_df with BEST_PARAMS, checkpoint the epoch with the lowest
# validation RMSE, then reload that checkpoint and report metrics on test_df.


def _build_final_model(hparams, seed, device):
    """Instantiate SequenceTopHModel from an Optuna hyperparameter dict and move it to `device`.

    Shared by the retraining run and by the checkpoint reload, so both always build
    the same architecture from the same keys.

    Args:
        hparams: dict with keys 'lstm_emb_dim', 'lstm_hidden_dim', 'lstm_layers',
            'lstm_dropout', 'rlat_kernel_size', 'rlat_dropout', 'rlat_res_blocks'
            and 'rlat_activation'.
        seed: random seed forwarded to the model constructor.
        device: torch.device to place the model on.

    Returns:
        The constructed model, already on `device`.
    """
    return SequenceTopHModel(
        vocab_size=VOCAB_SIZE,
        lstm_embedding_dim=hparams['lstm_emb_dim'],
        lstm_hidden_dim=hparams['lstm_hidden_dim'],
        num_lstm_layers=hparams['lstm_layers'],
        lstm_dropout=hparams['lstm_dropout'],
        rlat_kernel_size=hparams['rlat_kernel_size'],
        rlat_dropout=hparams['rlat_dropout'],
        rlat_res_blocks=hparams['rlat_res_blocks'],
        rlat_activation=hparams['rlat_activation'],
        random_seed=seed,
    ).to(device)


def _build_final_optimizer(model, cfg):
    """Create the optimizer named by cfg['optimizer_name'] ('Adam' or 'AdamW').

    Unknown names fall back to Adam after printing a warning. Learning rate and
    weight decay come from cfg['learning_rate'] and cfg['l2_reg'].
    """
    optimizer_classes = {'Adam': optim.Adam, 'AdamW': optim.AdamW}
    name = cfg['optimizer_name']
    if name not in optimizer_classes:
        print(f"Warning: Unknown optimizer '{name}'. Defaulting to Adam.")
    optimizer_cls = optimizer_classes.get(name, optim.Adam)
    return optimizer_cls(model.parameters(), lr=cfg['learning_rate'], weight_decay=cfg['l2_reg'])


print("\n--- Retraining model with best hyperparameters ---")

# --- Check if BEST_PARAMS exists ---
if 'BEST_PARAMS' not in globals() or BEST_PARAMS is None:
    print("ERROR: BEST_PARAMS dictionary not found or is None. Cannot proceed.")
    print("Please run the Optuna study successfully in Cell 9 first.")
    # Raising stops this cell (and "Run All") without killing the kernel, unlike sys.exit.
    raise NameError("BEST_PARAMS not defined. Run Optuna study cell.")

print("Using best parameters found by Optuna:")
for key, value in BEST_PARAMS.items():
    print(f"  {key}: {value}")

# --- Setup using BEST_PARAMS ---
final_cfg = BEST_PARAMS.copy()  # Work on a copy so BEST_PARAMS itself is never mutated
final_seed = RANDOM_SEED  # Use the global seed for consistency
final_model_name = f"{MODEL_NAME}_best_hpo"  # Distinct name for the final model

# Seed before building datasets/model so the retraining run is repeatable
torch.manual_seed(final_seed)
np.random.seed(final_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"
if use_cuda:
    torch.cuda.manual_seed_all(final_seed)

# --- Create Datasets and Dataloaders for final run ---
try:
    # Ensure dataframes exist
    if 'train_df' not in globals() or 'val_df' not in globals() or 'test_df' not in globals():
        raise NameError("train_df, val_df, or test_df not found. Ensure data loading in Cell 9 ran correctly.")

    print("Creating final datasets...")
    final_train_dataset = SequencepHDataset(train_df, sample_weight_method=final_cfg['sample_weight_method'])
    final_val_dataset = SequencepHDataset(val_df, sample_weight_method='None')
    final_test_dataset = SequencepHDataset(test_df, sample_weight_method='None')

    # Use the batch size found by Optuna
    final_batch_size = final_cfg['batch_size']
    print(f"Using final batch size: {final_batch_size}")

    # Pinned host memory only helps (and only avoids a warning) when copying to a GPU.
    final_train_loader = DataLoader(final_train_dataset, batch_size=final_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=use_cuda, drop_last=True)
    final_val_loader = DataLoader(final_val_dataset, batch_size=final_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=use_cuda)
    # Evaluation needs no gradients, so a larger batch (at least 64) is usually affordable.
    final_test_batch_size = max(64, final_batch_size)
    final_test_loader = DataLoader(final_test_dataset, batch_size=final_test_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=use_cuda)
    print("Final DataLoaders created.")

except Exception as e:
    print(f"Error creating final datasets/loaders: {e}")
    import traceback
    traceback.print_exc()
    raise RuntimeError("Failed to create final datasets/loaders.") from e

# --- Build final model using BEST_PARAMS ---
try:
    print("Building final model...")
    final_model = _build_final_model(final_cfg, final_seed, device)
    print(f"Final model created with {final_model.get_num_params():,} parameters.")
except Exception as e:
    print(f"Error building final model: {e}")
    import traceback
    traceback.print_exc()
    raise RuntimeError("Failed to build final model.") from e


# --- Final Optimizer and Loss ---
try:
    print("Setting up final optimizer and scheduler...")
    final_optimizer = _build_final_optimizer(final_model, final_cfg)
    final_loss_fn = weighted_rmse_loss
    # Scheduler monitors validation RMSE. `verbose` is deprecated since PyTorch 2.2;
    # the current learning rate is printed in every epoch summary instead.
    final_scheduler = ReduceLROnPlateau(final_optimizer, mode='min', factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-7)
    print("Optimizer and scheduler ready.")
except Exception as e:
    print(f"Error setting up optimizer/scheduler: {e}")
    raise RuntimeError("Failed to setup optimizer/scheduler.") from e

# --- Final Training Loop ---
print("\n--- Starting Final Training Loop ---")
best_final_val_rmse = float('inf')  # Track best RMSE for saving
final_epochs_no_improve = 0
final_training_log = []
final_model_save_path = os.path.join(SAVEDIR, f"{final_model_name}.pt")

start_time_total = time.time()
last_epoch = 0  # Epoch in progress when the loop ended (normally or via exception)

try:
    for epoch in range(1, EPOCHS + 1):
        last_epoch = epoch
        start_time_epoch = time.time()
        print(f"\n--- Final Training Epoch {epoch}/{EPOCHS} ---")

        # Training Step
        final_model.train()
        train_loss = train_one_epoch(final_model, final_train_loader, final_optimizer, final_loss_fn, device)

        # Validation Step
        final_model.eval()
        val_loss, val_metrics = validate_epoch(final_model, final_val_loader, final_loss_fn, device)
        # Validation RMSE drives LR scheduling, checkpointing and early stopping
        current_val_rmse = val_metrics.get('rmse', float('inf'))

        if np.isnan(current_val_rmse) or np.isinf(current_val_rmse):
            print(f"Warning: Validation RMSE is {current_val_rmse}. Skipping scheduler step and model saving.")
            # An invalid metric counts towards early stopping so a diverged run terminates
            final_epochs_no_improve += 1
        else:
            final_scheduler.step(current_val_rmse)

            # --- Save model and handle early stopping BASED ON VALIDATION RMSE ---
            if current_val_rmse < best_final_val_rmse:
                print(f"  Validation RMSE improved ({best_final_val_rmse:.4f} --> {current_val_rmse:.4f}). Saving model to {final_model_save_path}")
                best_final_val_rmse = float(current_val_rmse)  # plain float keeps the checkpoint simple to unpickle
                final_epochs_no_improve = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': final_model.state_dict(),
                    'optimizer_state_dict': final_optimizer.state_dict(),
                    'best_val_rmse': best_final_val_rmse,
                    'final_val_metrics': val_metrics,  # Save metrics dict at best point
                    'hyperparameters': BEST_PARAMS  # Save the HPO params used
                }, final_model_save_path)
            else:
                final_epochs_no_improve += 1
                print(f"  Validation RMSE did not improve for {final_epochs_no_improve} epoch(s). Best: {best_final_val_rmse:.4f}")

        # Print Epoch Summary
        epoch_duration = time.time() - start_time_epoch
        print(f"Epoch {epoch} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")  # Report raw validation loss
        print(f"  Val RMSE:   {current_val_rmse:.4f}")  # Report validation RMSE used for decisions
        print(f"  Val R2:     {val_metrics.get('r2', float('nan')):.4f}")
        print(f"  Learning Rate: {final_optimizer.param_groups[0]['lr']:.2e}")
        print(f"  Duration:   {epoch_duration:.2f}s")

        # Log results for this epoch
        final_training_log.append({
            'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss,
            'val_rmse': current_val_rmse, 'val_r2': val_metrics.get('r2', float('nan')),
            'val_rho': val_metrics.get('rho', float('nan')),
            'learning_rate': final_optimizer.param_groups[0]['lr'], 'duration': epoch_duration
        })

        # Check for early stopping
        if final_epochs_no_improve >= STOP_PATIENCE:
            print(f"\nFinal training early stopped after {STOP_PATIENCE} epochs without improvement on validation RMSE.")
            break

except Exception as e:
    # Best effort: report the failure, then still evaluate whatever checkpoint was saved.
    print(f"\nAn error occurred during the final training loop at epoch {last_epoch}: {e}")
    import traceback
    traceback.print_exc()
finally:
    # --- End of Final Training ---
    total_duration = time.time() - start_time_total
    print(f"\n--- Final Training Finished ---")
    print(f"Training ran for {last_epoch} epochs.")
    print(f"Total final training time: {total_duration/60:.2f} minutes.")
    if np.isinf(best_final_val_rmse):
        print("Best validation RMSE could not be determined (remained infinity).")
    else:
        print(f"Best final validation RMSE achieved: {best_final_val_rmse:.4f}")

    # Save final log regardless of how loop ended
    if final_training_log:
        log_df = pd.DataFrame(final_training_log)
        log_path = os.path.join(SAVEDIR, f"{final_model_name}_training_log.csv")
        log_df.to_csv(log_path, index=False)
        print(f"Final training log saved to {log_path}")
    else:
        print("No epochs completed, final training log not saved.")

    # Clean up GPU memory. The scheduler holds a reference to the optimizer, which in turn
    # references the model parameters and its state tensors, so it must be deleted too or
    # none of that memory is actually released.
    del final_scheduler, final_optimizer, final_model
    del final_train_loader, final_val_loader, final_train_dataset, final_val_dataset
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()


# --- Final Evaluation on Test Set ---
print("\n--- Evaluating final model on Test Set ---")
if os.path.exists(final_model_save_path) and not np.isinf(best_final_val_rmse):
    print(f"Loading best model from: {final_model_save_path}")
    try:
        # The checkpoint also pickles the metrics/hyperparameter dicts. Since PyTorch 2.6,
        # torch.load defaults to weights_only=True, which can reject such objects (e.g. numpy
        # scalars). The file was written by this cell, so full unpickling is trusted here.
        checkpoint = torch.load(final_model_save_path, map_location=device, weights_only=False)
        if checkpoint['hyperparameters'] != BEST_PARAMS:
            print("Warning: Hyperparameters in saved model do not match BEST_PARAMS from Optuna study!")

        # Rebuild the architecture from the hyperparameters stored *in the checkpoint*
        saved_hyperparams = checkpoint['hyperparameters']
        eval_model = _build_final_model(saved_hyperparams, final_seed, device)

        eval_model.load_state_dict(checkpoint['model_state_dict'])
        eval_model.eval()

        print("Model loaded successfully. Running test evaluation...")
        test_loss, test_metrics = validate_epoch(eval_model, final_test_loader, final_loss_fn, device)

        print("\n--- Final Model Test Set Performance ---")
        print(f"  Test Loss (Weighted RMSE): {test_loss:.4f}")  # weighted RMSE loss on the test set
        print(f"  Test RMSE (Eval Weighted): {test_metrics.get('rmse', float('nan')):.4f}")
        print(f"  Test R2   (Eval Weighted): {test_metrics.get('r2', float('nan')):.4f}")
        print(f"  Test Rho  (Spearman):      {test_metrics.get('rho', float('nan')):.4f}")
        print(f"  Test R    (Pearson):       {test_metrics.get('r', float('nan')):.4f}")
        print(f"  Test MCC  (Binned):        {test_metrics.get('mcc', float('nan')):.4f}")
        print(f"  Test F1   (Binned, W-Avg): {test_metrics.get('f1score', float('nan')):.4f}")
        print(f"  Test AUC  (Binned, OvR):   {test_metrics.get('auc', float('nan')):.4f}")

        # Clean up eval model
        del eval_model, checkpoint
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error during final test evaluation: {e}")
        import traceback
        traceback.print_exc()
elif np.isinf(best_final_val_rmse):
    print(f"Skipping test evaluation because final model training did not achieve a valid best validation RMSE.")
else:
    print(f"Final best model file not found at {final_model_save_path}. Cannot perform test evaluation.")

# ## Next Steps: Inference
#
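# After training, you can use the saved best checkpoint (`<MODEL_NAME>_best_hpo.pt` in `SAVEDIR`) to predict pH for new sequences. You would typically:
# 1.  Rebuild the model architecture (SequenceTopHModel) with the same hyperparameters (stored in the checkpoint under `'hyperparameters'`).
# 2.  Load the checkpoint with `torch.load(...)` and pass its `'model_state_dict'` entry to `model.load_state_dict(...)`.
# 3.  Set the model to evaluation mode (`model.eval()`).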
# 4.  Adapt the logic from `run.py` (reading FASTA, tokenizing, padding, running model inference) in a new cell or notebook.