In [2]:
! nvidia-smi

Fri Jul  4 02:41:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:4F:00.0 Off |                    0 |
| N/A   28C    P0             44W /  270W |       4MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import argparse
import os
import time
from typing import Dict, Tuple, Union, Optional, Callable, List, Any
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import torch
import torch.distributed as dist
import transformers
import yaml
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    IterableDatasetDict,
    IterableDataset,
)
from datasets import Dataset as HFDataset, DatasetDict
from sklearn.metrics import f1_score, matthews_corrcoef
from sklearn.model_selection import KFold
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding,
    PreTrainedModel,
    AutoConfig,
)




In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device 

device(type='cuda')

In [5]:
root = '/work/magroup/shared/DNA_LLM/DNALongBench/'

In [6]:
import torch
import dnalongbench
from dnalongbench.utils import load_data, BasenjiDataSet
import numpy as np
import tensorflow as tf

In [7]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [8]:
tf.config.set_visible_devices([], 'GPU')

In [9]:
# human_fasta_path = root + 'regulatory_sequence_activity_prediction/human/seqs/hg38.ml.fa'
# data_path = root + 'regulatory_sequence_activity_prediction/'
# organism = 'human'

In [10]:
# train_dataset = BasenjiDataSet(data_path, organism, 'train', 196608, human_fasta_path)

In [11]:
# def custom_collate(batch):
#     x, y = zip(*batch)
#     x = torch.tensor(np.stack(x), dtype=torch.float32)
#     y = torch.tensor(np.stack(y), dtype=torch.float32)
#     return x, y

# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=4, num_workers=0, collate_fn=custom_collate
# )

In [12]:
# for batch in train_loader: 
#         x, y = batch
#         print('x:',x.size())
#         print('y:',y.size())
#         break


In [13]:
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=0)

In [14]:
def load_data(root='/work/magroup/shared/DNA_LLM/DNALongBench/', task_name = 'regulatory_sequence_activity', organism = 'human', cell_type=None, batch_size=16, sequence_length=196608):
    if task_name == 'regulatory_sequence_activity':
        assert organism == "human" or organism == "mouse"
        human_fasta_path = root + 'regulatory_sequence_activity_prediction/human/seqs/hg38.ml.fa'
        mouse_fasta_path = root + 'regulatory_sequence_activity_prediction/mouse/seqs/mm10.ml.fa'
        data_path = root + 'regulatory_sequence_activity_prediction/'
        
        # SEQUENCE_LENGTH = 196608
        # BIN_SIZE = 128
        # TARGET_LENGTH = 896

        fasta_path = human_fasta_path if organism == "human" else mouse_fasta_path

        train_dataset = BasenjiDataSet(data_path, organism, 'train', sequence_length, fasta_path)
        valid_dataset = BasenjiDataSet(data_path, organism, 'valid', sequence_length, fasta_path)
        test_dataset = BasenjiDataSet(data_path, organism, 'test', sequence_length, fasta_path)

        def custom_collate(batch):
            x, y = zip(*batch)
            x = torch.tensor(np.stack(x), dtype=torch.float32)
            y = torch.tensor(np.stack(y), dtype=torch.float32)
            return x, y
        
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, collate_fn=custom_collate)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, collate_fn=custom_collate)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=0, collate_fn=custom_collate)
        return train_loader, valid_loader, test_loader 

In [15]:
train_loader, valid_loader, test_loader = load_data(root=root, task_name = 'regulatory_sequence_activity', organism = 'mouse', cell_type=None, batch_size=16, sequence_length=196608)

In [16]:
for batch in train_loader: 
        x, y = batch
        print('x:',x.size())
        print('y:',y.size())
        break


2025-07-04 02:41:56.760347: E tensorflow/core/util/util.cc:131] oneDNN supports DT_HALF only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.


x: torch.Size([16, 196608, 4])
y: torch.Size([16, 896, 1643])


In [17]:
!nvidia-smi

Fri Jul  4 02:41:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:4F:00.0 Off |                    0 |
| N/A   28C    P0             44W /  270W |       4MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [18]:
def collate_fn(batch, tokenizer, max_length=450000):
    """
    Custom collate function for DNA data that converts one-hot encoded sequences to raw sequences
    and tokenizes them.
    
    Args:
        batch: List of tuples where each tuple is (x, y)
               x is one-hot encoded DNA sequence of shape (seq_len, 4)
               y is gene expression data of shape (10, seq_len)
        tokenizer: The GENERator tokenizer
        max_length: Maximum sequence length for tokenization
    
    Returns:
        Dictionary with tokenized inputs and original gene expression data
    """
    # Separate x and y from the batch
    x_batch, y_batch = zip(*batch)
    x_batch = torch.tensor(np.stack(x_batch), dtype=torch.float32)
    y_batch = torch.tensor(np.stack(y_batch), dtype=torch.float32)
    
    # Convert one-hot encoded sequences to raw sequences
    raw_sequences = []
    nucleotides = ['A', 'C', 'G', 'T']
    for one_hot_seq in x_batch:
     
        # Ensure one_hot_seq is a PyTorch tensor
        if not isinstance(one_hot_seq, torch.Tensor):
            one_hot_seq = torch.tensor(one_hot_seq)
        
        # Get indices of 1s in one-hot encoding (argmax along axis 1)
        indices = torch.argmax(one_hot_seq, dim=1).cpu().numpy()
        
        # Convert indices to nucleotides
        raw_seq = ''.join([nucleotides[idx] for idx in indices])
        raw_sequences.append(raw_seq)
    
    # Tokenize the raw sequences
    tokenizer.padding_side = "right"
    inputs = tokenizer(
        raw_sequences,
        add_special_tokens=True,
        return_tensors="pt",
        padding=False,
        truncation=True,
        # max_length=max_length
    )

    
    # Convert y arrays to tensors and stack them
    y_tensors = []
    for y in y_batch:
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float32)
        y_tensors.append(y)
    
    y_stacked = torch.stack(y_tensors)
    
    # Return tokenized inputs and original y
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "y": y_stacked
    }

In [19]:

# Set logging level for transformers
transformers.logging.set_verbosity_info()

# Define optimization direction for each metric (whether higher or lower is better)
METRICS_DIRECTION: Dict[str, str] = {
    "accuracy": "max",
    "f1_score": "max",
    "mcc": "max",
    "f1_max": "max",
    "auprc_micro": "max",
    "mse": "min",
    "mae": "min",
    "r2": "max",
    "pearson": "max",
}


def is_main_process() -> bool:
    """
    Check if current process is the main process (rank 0) in distributed training.

    Returns:
        bool: True if this is the main process, False otherwise
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def dist_print(*args, **kwargs) -> None:
    """
    Print only from the main process (rank 0) in distributed training.
    Prevents duplicate outputs in multi-GPU settings.

    Args:
        *args: Arguments to pass to print function
        **kwargs: Keyword arguments to pass to print function
    """
    if is_main_process():
        print(*args, **kwargs)


In [20]:
def setup_tokenizer(
    model_name: str, padding_and_truncation_side: str
) -> PreTrainedTokenizer:
    """
    Load and configure tokenizer for sequence understanding.

    Args:
        model_name: Name or path of the HuggingFace model
        padding_and_truncation_side: Side for padding and truncation (left or right)

    Returns:
        PreTrainedTokenizer: Configured tokenizer for the model
    """
    dist_print(f"🔤 Loading tokenizer from: {model_name}")
    start_time = time.time()

    # Load tokenizer with trust_remote_code to support custom models
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Configure padding and truncation settings
    tokenizer.padding_side = padding_and_truncation_side
    tokenizer.truncation_side = padding_and_truncation_side

    # Set pad_token to eos_token if not defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dist_print(
        f"⏱️ Tokenizer loading completed in {time.time() - start_time:.2f} seconds"
    )

    return tokenizer

In [21]:
tokenizer = setup_tokenizer("GenerTeam/GENERator-eukaryote-1.2b-base", 'right')

🔤 Loading tokenizer from: GenerTeam/GENERator-eukaryote-1.2b-base


loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /home/wenduoc/.cache/huggingface/hub/models--GenerTeam--GENERator-eukaryote-1.2b-base/snapshots/3be4abf390afbb7f4d8ccb3370f599338523f1cd/special_tokens_map.json
loading file tokenizer_config.json from cache at /home/wenduoc/.cache/huggingface/hub/models--GenerTeam--GENERator-eukaryote-1.2b-base/snapshots/3be4abf390afbb7f4d8ccb3370f599338523f1cd/tokenizer_config.json
loading file tokenizer.json from cache at None
loading file chat_template.jinja from cache at None


⏱️ Tokenizer loading completed in 0.25 seconds


In [22]:
tokenizer

DNAKmerTokenizer(name_or_path='GenerTeam/GENERator-eukaryote-1.2b-base', vocab_size=4128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<oov>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("<oov>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [23]:
train_loader2 = DataLoader(
        train_loader.dataset,
        batch_size=4,
        collate_fn=lambda b: collate_fn(b, tokenizer, max_length=196_608)
    )


In [24]:
for batch in train_loader2: 
        print(batch)
        break


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'input_ids': tensor([[   1,  633, 3120,  ...,  228, 2085,    2],
        [   1,  658, 1993,  ..., 1455, 1858,    2],
        [   1,   36,  151,  ..., 1972,  154,    2],
        [   1, 3006, 1625,  ...,  940, 2008,    2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'y': tensor([[[7.7438e-03, 6.2622e-02, 5.7983e-02,  ..., 1.1318e+00,
          0.0000e+00, 6.0352e-01],
         [2.8976e-02, 4.9072e-02, 1.2213e-01,  ..., 3.0289e-03,
          1.9922e+00, 9.5642e-02],
         [5.7098e-02, 8.1726e-02, 7.4951e-02,  ..., 9.8486e-01,
          0.0000e+00, 0.0000e+00],
         ...,
         [1.9958e-02, 3.9764e-02, 1.8021e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.2314e-02, 1.6327e-02, 1.8082e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [5.0621e-03, 6.2622e-02, 1.8860e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

     

In [25]:
batch['input_ids'].shape, batch['y'].shape

(torch.Size([4, 32770]), torch.Size([4, 896, 1643]))

In [26]:
!nvidia-smi

Fri Jul  4 02:41:59 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:4F:00.0 Off |                    0 |
| N/A   29C    P0             65W /  270W |       4MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [27]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class RegulatorySignalPredictor(nn.Module):
    def __init__(
        self,
        base_model_name: str,
        max_subsequence_length: int = 4096,
        num_bins: int = 896,
        num_targets: int = 5313,
        gradient_checkpointing: bool = True,
    ):
        super().__init__()
        # 1) load your Llama‐style decoder
        self.base_model = AutoModel.from_pretrained(
            base_model_name, trust_remote_code=True
        )
        if gradient_checkpointing:
            self.base_model.gradient_checkpointing_enable()

        self.chunk_size = max_subsequence_length
        hidden_size = self.base_model.config.hidden_size

        # 2) adaptive pool along the token dimension → exactly `num_bins`
        self.pool = nn.AdaptiveAvgPool1d(num_bins)
        # 3) final head: H → num_targets
        self.head = nn.Linear(hidden_size, num_targets)

    def forward(self,
        input_ids: torch.Tensor,          # [B, 32770]
        attention_mask: torch.Tensor = None
    ):
        B, L = input_ids.shape
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # --- chunk & encode ---
        hidden_chunks = []
        num_chunks = math.ceil(L / self.chunk_size)
        for i in range(num_chunks):
            start = i * self.chunk_size
            end   = min((i + 1) * self.chunk_size, L)

            sub_ids  = input_ids[:, start:end]
            sub_mask = attention_mask[:, start:end]

            # pad last chunk up to chunk_size so model accepts it
            if end - start < self.chunk_size:
                pad_len = self.chunk_size - (end - start)
                sub_ids  = F.pad(sub_ids,  (0, pad_len), value=0)
                sub_mask = F.pad(sub_mask, (0, pad_len), value=0)

            out = self.base_model(
                input_ids=sub_ids,
                attention_mask=sub_mask
            )
            hs = out.last_hidden_state  # → [B, chunk_size, H]

            # drop the padded tokens
            real_len = end - start
            hidden_chunks.append(hs[:, :real_len, :])  # [B, real_len, H]

        # --- reassemble full sequence hidden states ---
        x = torch.cat(hidden_chunks, dim=1)    # [B, L, H]

        # --- pool down to exactly `num_bins` positions ---
        x = x.transpose(1, 2)                 # [B, H, L]
        x = self.pool(x)                      # [B, H, num_bins]
        x = x.transpose(1, 2)                 # [B, num_bins, H]

        # --- project each bin to your 5,313 signals ---
        preds = self.head(x)                  # [B, 896, 5313]

        return {"logits": preds}


In [28]:
model = RegulatorySignalPredictor(
    base_model_name="GenerTeam/GENERator-eukaryote-1.2b-base",
    max_subsequence_length=4096,
    num_bins=896,
    num_targets = 1643
    # num_subsequences=8
)

loading configuration file config.json from cache at /home/wenduoc/.cache/huggingface/hub/models--GenerTeam--GENERator-eukaryote-1.2b-base/snapshots/3be4abf390afbb7f4d8ccb3370f599338523f1cd/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 26,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3",
  "use_cache": true,
  "vocab_size": 4128
}

loading weights file model.safetensors from cache at /home/wenduoc/.cache/huggingface/hub/models-

In [29]:
model=model.to(torch.bfloat16).to(device)

In [30]:
model

RegulatorySignalPredictor(
  (base_model): LlamaModel(
    (embed_tokens): Embedding(4128, 2048)
    (layers): ModuleList(
      (0-25): 26 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (

In [31]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable %:          {100 * trainable_params / total_params:.2f}%")


In [32]:
count_parameters(model)

Total parameters:     1,156,974,187
Trainable parameters: 1,156,974,187
Trainable %:          100.00%


In [33]:
for param in model.base_model.parameters():
    param.requires_grad = False

# Unfreeze the last 8 LLaMA decoder layers
for layer in model.base_model.layers[-6:]:
    for param in layer.parameters():
        param.requires_grad = True


In [34]:
count_parameters(model)

Total parameters:     1,156,974,187
Trainable parameters: 267,632,235
Trainable %:          23.13%


# Train

In [35]:
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import time
import os
from typing import Dict, Any, Optional
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizer


def evaluate_model_custom(
    model: PreTrainedModel,
    data_loader: DataLoader,
    device: str,
    criterion: nn.Module
) -> Dict[str, float]:
    """
    Evaluate regression model on a dataset.
    
    Args:
        model: Model to evaluate
        data_loader: DataLoader for evaluation
        device: Device to run evaluation on
        criterion: Loss function (MSE)
    
    Returns:
        Dictionary of evaluation metrics (MSE, MAE, R², PCC)
    """
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['y'].to(device)

            # input_ids = input_ids.to(model.dtype)
            # attention_mask = attention_mask.to(model.dtype)
            
            outputs = model(input_ids=input_ids,attention_mask=attention_mask)
            logits = outputs['logits']
            
            # Process labels for regression
            if labels.dim() > 1:
                labels = labels.view(-1).float()
            
            # Calculate MSE loss
            loss = criterion(logits.view(-1), labels)
            
            # Get predictions (no activation function needed for regression)
            predictions = logits.float().cpu().numpy()
     
            total_loss += loss.item()
            
            # Store for metrics
            all_predictions.extend(predictions.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())
                
            num_batches += 1
    
    # Compute average loss
    avg_loss = total_loss / num_batches
    
    # Convert lists to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    # Mean Squared Error
    mse = mean_squared_error(all_labels, all_predictions)
    
    # Mean Absolute Error
    mae = mean_absolute_error(all_labels, all_predictions)
    
    # R-squared score
    r2 = r2_score(all_labels, all_predictions)
    
    # Pearson Correlation Coefficient
    pcc, p_value = pearsonr(all_labels, all_predictions)
    
    # Root Mean Squared Error
    rmse = np.sqrt(mse)
    
    return {
        'loss': avg_loss,
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'r2': r2,
        'pcc': pcc,
        'pcc_p_value': p_value,
        'num_samples': len(all_predictions)
    }


def train_model_custom(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: Optional[DataLoader] = None,
    num_epochs: int = 10,
    learning_rate: float = 1e-5,
    weight_decay: float = 0.01,
    warmup_steps: int = 0,
    max_grad_norm: float = 1.0,
    save_dir: str = None,
    save_steps: int = 20,
    early_stopping_patience: int = 5,
    device: str = "cuda",
    use_wandb: bool = False,
    gradient_accumulation_steps: int = 1,

) -> Dict[str, Any]:
    """
    Custom training function for DNA sequence models with gene expression prediction (regression).
    Modified to evaluate after each epoch and save model based on lowest validation loss (MSE).
    """
    model = model.to(device)
    model.train()

    if save_dir is None:
        save_dir = f"/work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/RSAP/mouse"
    os.makedirs(save_dir, exist_ok=True)
    
    train_loader_custom = DataLoader(
        train_loader.dataset,
        batch_size=4,
        collate_fn=lambda b: collate_fn(b, tokenizer)
    )

    val_loader_custom = DataLoader(
        val_loader.dataset,
        batch_size=4,
        collate_fn=lambda b: collate_fn(b, tokenizer)
    )

    if test_loader is not None:
        test_loader_custom = DataLoader(
            test_loader.dataset,
            batch_size=4,
            collate_fn=lambda b: collate_fn(b, tokenizer)
        )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_steps
    )

    # Use MSE loss for regression
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_val_pcc = -1.0  # Track best PCC as well
    patience_counter = 0
    global_step = 0
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'val_pcc': [],
        'val_r2': [],
        'learning_rates': []
    }

    print(f"🚀 Starting regression training for {num_epochs} epochs...")
    print(f"🔧 Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"📊 Evaluation will occur after each epoch")
    print(f"💾 Model will be saved based on lowest validation MSE loss")

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"{'='*50}")

        model.train()
        epoch_train_loss = 0.0
        train_steps = 0

        progress_bar = tqdm(train_loader_custom, desc=f"Epoch {epoch + 1}")

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['y'].to(device)

            # input_ids = input_ids.to(model.dtype)
            # attention_mask = attention_mask.to(model.dtype)

            with torch.cuda.amp.autocast():
                outputs = model(input_ids=input_ids,attention_mask=attention_mask)
                logits = outputs['logits']

                if labels.dim() > 1:
                    labels = labels.view(-1).float()

                # MSE loss for regression
                loss = criterion(logits.view(-1), labels)
                loss = loss / gradient_accumulation_steps

            loss.backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            epoch_train_loss += loss.item() * gradient_accumulation_steps
            train_steps += 1

            progress_bar.set_postfix({
                'mse_loss': f"{(loss.item() * gradient_accumulation_steps):.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}"
            })

            if global_step > 0 and global_step % save_steps == 0:
                checkpoint_path = os.path.join(save_dir, f'checkpoint_step_{global_step}.pt')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'global_step': global_step
                }, checkpoint_path)
                print(f"💾 Intermediate checkpoint saved at step {global_step}")
         

        avg_train_loss = epoch_train_loss / train_steps
        training_history['train_loss'].append(avg_train_loss)
        training_history['learning_rates'].append(scheduler.get_last_lr()[0])

        print(f"\n🔄 Evaluating after epoch {epoch + 1}...")
        val_metrics = evaluate_model_custom(model, val_loader_custom, device, criterion)
        training_history['val_loss'].append(val_metrics['loss'])
        training_history['val_pcc'].append(val_metrics['pcc'])
        training_history['val_r2'].append(val_metrics['r2'])

        print(f"\n📈 Epoch {epoch + 1} Summary:")
        print(f"  Train MSE Loss: {avg_train_loss:.4f}")
        print(f"  Val MSE Loss: {val_metrics['loss']:.4f}")
        print(f"  Val PCC: {val_metrics['pcc']:.4f}")
        print(f"  Val R²: {val_metrics['r2']:.4f}")
        print(f"  Val RMSE: {val_metrics['rmse']:.4f}")
        print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.2e}")

        # Save model based on lowest validation loss
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_val_pcc = val_metrics['pcc']
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_pcc': best_val_pcc,
                'global_step': global_step
            }, os.path.join(save_dir, 'best_model.pt'))

            print(f"💾 New best model saved! Val MSE: {best_val_loss:.4f}, Val PCC: {best_val_pcc:.4f}")
        else:
            patience_counter += 1
            print(f"⏳ No improvement in MSE loss. Patience: {patience_counter}/{early_stopping_patience}")

        epoch_checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch + 1}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_metrics['loss'],
            'val_pcc': val_metrics['pcc'],
            'global_step': global_step
        }, epoch_checkpoint_path)
        print(f"💾 Epoch {epoch + 1} checkpoint saved")

        if patience_counter >= early_stopping_patience:
            print(f"🛑 Early stopping triggered after {patience_counter} epochs without MSE loss improvement")
            break

        model.train()

    total_time = time.time() - start_time
    print(f"\n✅ Training completed in {total_time/60:.2f} minutes")
    print(f"🏆 Best validation MSE loss achieved: {best_val_loss:.4f}")
    print(f"🏆 Best validation PCC achieved: {best_val_pcc:.4f}")

    final_metrics = {
        'training_history': training_history, 
        'best_val_loss': best_val_loss,
        'best_val_pcc': best_val_pcc
    }
    
    if test_loader is not None:
        print("\n🧪 Evaluating on test set...")
        test_metrics = evaluate_model_custom(model, test_loader_custom, device, criterion)
        final_metrics['test_metrics'] = test_metrics

        print("📊 Final Test Metrics:")
        print(f"  MSE: {test_metrics['mse']:.4f}")
        print(f"  RMSE: {test_metrics['rmse']:.4f}")
        print(f"  MAE: {test_metrics['mae']:.4f}")
        print(f"  R²: {test_metrics['r2']:.4f}")
        print(f"  PCC: {test_metrics['pcc']:.4f}")
        print(f"  PCC p-value: {test_metrics['pcc_p_value']:.6f}")

    return final_metrics

In [None]:


# Train the model
training_results = train_model_custom(
    model=model,
    tokenizer=tokenizer,
    train_loader=train_loader,
    val_loader=valid_loader,
    test_loader=test_loader,
    num_epochs=3,
    learning_rate=1e-5,
    # batch_size=1,  # Start small due to memory constraints

    device=device,
    use_wandb=False,  
    gradient_accumulation_steps=16,  # Effective batch size = 1 * 8 = 8
)



🚀 Starting regression training for 3 epochs...
🔧 Gradient accumulation steps: 16
📊 Evaluation will occur after each epoch
💾 Model will be saved based on lowest validation MSE loss

Epoch 1/3


  with torch.cuda.amp.autocast():
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Epoch 1: 320it [1:30:58, 17.99s/it, mse_loss=42.2334, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 321it [1:31:19, 18.73s/it, mse_loss=46.7410, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 322it [1:31:39, 19.28s/it, mse_loss=43.8595, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 323it [1:32:00, 19.68s/it, mse_loss=47.5384, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 324it [1:32:20, 19.95s/it, mse_loss=50.6042, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 325it [1:32:41, 20.15s/it, mse_loss=44.9177, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 326it [1:33:02, 20.26s/it, mse_loss=54.7780, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 327it [1:33:22, 20.36s/it, mse_loss=45.6844, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 328it [1:33:43, 20.41s/it, mse_loss=42.1722, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 329it [1:34:03, 20.44s/it, mse_loss=43.9855, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 330it [1:34:24, 20.43s/it, mse_loss=43.7697, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 331it [1:34:44, 20.42s/it, mse_loss=42.9588, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 332it [1:35:04, 20.40s/it, mse_loss=44.9308, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 333it [1:35:25, 20.41s/it, mse_loss=54.1894, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 334it [1:35:45, 20.41s/it, mse_loss=47.5563, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 335it [1:36:06, 20.46s/it, mse_loss=66.5913, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 20


Epoch 1: 640it [3:02:34, 17.98s/it, mse_loss=49.3795, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 641it [3:02:54, 18.75s/it, mse_loss=48.6392, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 642it [3:03:14, 19.24s/it, mse_loss=43.7136, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 643it [3:03:35, 19.63s/it, mse_loss=44.7634, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 644it [3:03:55, 19.84s/it, mse_loss=43.7545, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 645it [3:04:16, 20.04s/it, mse_loss=49.2070, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 646it [3:04:36, 20.17s/it, mse_loss=55.3301, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 647it [3:04:57, 20.21s/it, mse_loss=60.2585, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 648it [3:05:17, 20.28s/it, mse_loss=54.6707, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 649it [3:05:38, 20.36s/it, mse_loss=53.0219, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 650it [3:05:58, 20.38s/it, mse_loss=44.1935, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 651it [3:06:19, 20.40s/it, mse_loss=49.8721, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 652it [3:06:39, 20.40s/it, mse_loss=45.7334, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 653it [3:06:59, 20.41s/it, mse_loss=45.9145, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 654it [3:07:20, 20.43s/it, mse_loss=45.3318, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 655it [3:07:40, 20.37s/it, mse_loss=48.3278, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 40


Epoch 1: 960it [4:34:14, 18.05s/it, mse_loss=60.2607, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 961it [4:34:35, 18.84s/it, mse_loss=44.6147, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 962it [4:34:55, 19.39s/it, mse_loss=40.9417, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 963it [4:35:16, 19.74s/it, mse_loss=46.0247, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 964it [4:35:36, 19.99s/it, mse_loss=45.8553, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 965it [4:35:57, 20.17s/it, mse_loss=51.0840, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 966it [4:36:17, 20.27s/it, mse_loss=40.3653, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 967it [4:36:38, 20.34s/it, mse_loss=62.6900, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 968it [4:36:58, 20.36s/it, mse_loss=46.5550, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 969it [4:37:19, 20.42s/it, mse_loss=43.7741, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 970it [4:37:39, 20.43s/it, mse_loss=53.7671, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 971it [4:38:00, 20.45s/it, mse_loss=59.3719, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 972it [4:38:20, 20.48s/it, mse_loss=48.5702, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 973it [4:38:41, 20.50s/it, mse_loss=45.5905, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 974it [4:39:02, 20.51s/it, mse_loss=69.3925, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 975it [4:39:22, 20.52s/it, mse_loss=47.2092, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 60


Epoch 1: 1280it [6:06:10, 18.07s/it, mse_loss=55.5679, lr=1.00e-06] 

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1281it [6:06:31, 18.84s/it, mse_loss=60.0284, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1282it [6:06:51, 19.35s/it, mse_loss=50.2349, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1283it [6:07:12, 19.69s/it, mse_loss=66.6455, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1284it [6:07:32, 19.90s/it, mse_loss=43.5292, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1285it [6:07:52, 20.07s/it, mse_loss=63.4941, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1286it [6:08:13, 20.23s/it, mse_loss=48.4419, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1287it [6:08:34, 20.33s/it, mse_loss=52.8512, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1288it [6:08:54, 20.40s/it, mse_loss=41.4890, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1289it [6:09:15, 20.39s/it, mse_loss=50.9792, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1290it [6:09:35, 20.40s/it, mse_loss=57.5670, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1291it [6:09:56, 20.44s/it, mse_loss=47.1486, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1292it [6:10:16, 20.44s/it, mse_loss=56.2695, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1293it [6:10:37, 20.46s/it, mse_loss=46.8852, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1294it [6:10:57, 20.47s/it, mse_loss=50.8767, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1295it [6:11:17, 20.48s/it, mse_loss=44.5484, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 80


Epoch 1: 1600it [7:38:02, 18.05s/it, mse_loss=54.8029, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1601it [7:38:23, 18.79s/it, mse_loss=52.2228, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1602it [7:38:43, 19.28s/it, mse_loss=49.4183, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1603it [7:39:04, 19.63s/it, mse_loss=51.5972, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1604it [7:39:24, 19.91s/it, mse_loss=49.3073, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1605it [7:39:45, 20.12s/it, mse_loss=67.8089, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1606it [7:40:05, 20.20s/it, mse_loss=48.8472, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1607it [7:40:26, 20.28s/it, mse_loss=63.3004, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1608it [7:40:46, 20.31s/it, mse_loss=42.4995, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1609it [7:41:07, 20.36s/it, mse_loss=54.7005, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1610it [7:41:27, 20.37s/it, mse_loss=52.3720, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1611it [7:41:47, 20.37s/it, mse_loss=53.0417, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1612it [7:42:08, 20.43s/it, mse_loss=46.6277, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1613it [7:42:28, 20.41s/it, mse_loss=44.2451, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1614it [7:42:49, 20.45s/it, mse_loss=46.3794, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1615it [7:43:09, 20.46s/it, mse_loss=55.4617, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 100


Epoch 1: 1913it [9:07:37, 17.09s/it, mse_loss=43.0247, lr=1.00e-06]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 1: 4164it [19:54:43, 19.94s/it, mse_loss=46.3463, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4165it [19:55:03, 20.16s/it, mse_loss=56.5503, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4166it [19:55:24, 20.27s/it, mse_loss=45.0199, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4167it [19:55:44, 20.38s/it, mse_loss=47.4829, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4168it [19:56:05, 20.45s/it, mse_loss=45.5347, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4169it [19:56:26, 20.47s/it, mse_loss=43.2639, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4170it [19:56:46, 20.47s/it, mse_loss=41.8958, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4171it [19:57:07, 20.51s/it, mse_loss=40.7135, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4172it [19:57:27, 20.50s/it, mse_loss=49.3059, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4173it [19:57:48, 20.47s/it, mse_loss=54.0556, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4174it [19:58:08, 20.48s/it, mse_loss=43.3215, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4175it [19:58:29, 20.51s/it, mse_loss=48.0938, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 260


Epoch 1: 4415it [21:06:49, 17.12s/it, mse_loss=46.6707, lr=1.00e-06]

In [35]:
checkpoint = torch.load(f"/work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/RSAP/mouse/checkpoint_step_320.pt", map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [38]:
import math
import torch
import numpy as np
from tqdm import tqdm
from typing import Dict
from scipy.stats import t as t_dist

def evaluate_model_custom(
    model: PreTrainedModel,
    data_loader: DataLoader,
    device: str,
    criterion: torch.nn.Module
) -> Dict[str, float]:
    """
    Streaming evaluation: regression metrics (MSE, MAE, RMSE, R², PCC) 
    without storing all predictions/labels.
    """
    model.eval()

    # Streaming accumulators
    n = 0
    sum_loss     = 0.0
    sum_sq_err   = 0.0
    sum_abs_err  = 0.0
    sum_pred     = 0.0
    sum_lbl      = 0.0
    sum_pred_sq  = 0.0
    sum_lbl_sq   = 0.0
    sum_pred_lbl = 0.0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            # forward
            input_ids      = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels         = batch['y'].view(-1).float().to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds_t = outputs['logits'].view(-1)
            loss_b  = criterion(preds_t, labels).item()

            # to numpy
            preds = preds_t.float().cpu().numpy()
            lbls  = labels.cpu().numpy()

            # batch stats
            b = len(lbls)
            n += b
            sum_loss   += loss_b * b
            diff        = preds - lbls
            sum_sq_err += np.sum(diff**2)
            sum_abs_err+= np.sum(np.abs(diff))
            sum_pred   += np.sum(preds)
            sum_lbl    += np.sum(lbls)
            sum_pred_sq+= np.sum(preds**2)
            sum_lbl_sq += np.sum(lbls**2)
            sum_pred_lbl += np.sum(preds * lbls)

    # finalize metrics
    mse  = sum_sq_err   / n
    mae  = sum_abs_err  / n
    rmse = math.sqrt(mse)

    # R² = 1 - SS_res/SS_tot
    ss_tot = sum_lbl_sq - (sum_lbl**2) / n
    r2     = 1 - (sum_sq_err / ss_tot) if ss_tot > 0 else float('nan')

    # Pearson r
    num   = n * sum_pred_lbl - sum_pred * sum_lbl
    den   = math.sqrt((n * sum_pred_sq - sum_pred**2) * (n * sum_lbl_sq - sum_lbl**2))
    pcc   = num / den if den > 0 else 0.0

    # p-value (two-tailed) via t-distribution
    if n > 2:
        t_stat = pcc * math.sqrt((n - 2) / (1 - pcc**2))
        p_val  = 2 * (1 - t_dist.cdf(abs(t_stat), df=n-2))
    else:
        p_val = float('nan')

    return {
        'loss': sum_loss / n,
        'mse':  mse,
        'mae':  mae,
        'rmse': rmse,
        'r2':   r2,
        'pcc':  pcc,
        'pcc_p_value': p_val,
        'num_samples': n
    }


In [None]:
# Final evaluation on test set if provided
final_metrics = {}

test_loader_custom = DataLoader(
            test_loader.dataset,
            batch_size=2,
            collate_fn=lambda b: collate_fn(b, tokenizer)
        )

criterion = nn.MSELoss()

print("\n🧪 Evaluating on test set...")
test_metrics = evaluate_model_custom(model, test_loader_custom, device, criterion)
final_metrics['test_metrics'] = test_metrics

print("📊 Final Test Metrics:")
for key, value in test_metrics.items():
    print(f"  {key}: {value:.4f}")

print(final_metrics)


🧪 Evaluating on test set...


Evaluating: 0it [00:00, ?it/s]