In [1]:
! nvidia-smi

Mon Jun 30 00:34:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:4F:00.0 Off |                    0 |
| N/A   51C    P0             77W /  270W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import argparse
import os
import time
from typing import Dict, Tuple, Union, Optional, Callable, List, Any
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn as nn
import numpy as np
import torch
import torch.distributed as dist
import transformers
import yaml
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    IterableDatasetDict,
    IterableDataset,
)
from datasets import Dataset as HFDataset, DatasetDict
from sklearn.metrics import f1_score, matthews_corrcoef
from sklearn.model_selection import KFold
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding,
    PreTrainedModel,
    AutoConfig,
    AutoModel  
)




  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  warn(
2025-07-08 21:56:08.546377: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-08 21:56:09.478158: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-08 21:56:09.746711: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-08 21:56:09.849700: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory f

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device 

In [None]:
# Dataset root directory handed to dnalongbench's load_data() below.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
root = '/work/magroup/shared/DNA_LLM/DNALongBench/'

In [None]:
import torch
import dnalongbench
from dnalongbench.utils import load_data

In [None]:
# Build train/validation/test loaders for the eQTL prediction task on the
# Adipose_Subcutaneous cell type, one sample per batch.
train_loader, valid_loader, test_loader = load_data(root = root, task_name = 'eqtl_prediction', organism = None, cell_type = 'Adipose_Subcutaneous', batch_size = 1)

In [None]:
# Sanity check: print the tensor sizes of the first training batch only
# (the `break` stops after one iteration).
for batch in train_loader: 
        print('x_ref:', batch['x_ref'].size())
        print('x_alt', batch['x_alt'].size())
        print('y:',batch['y'].size())
        break

In [None]:
def collate_fn(batch, tokenizer, max_length=450_000):
    """
    Turn a batch of one-hot encoded reference/alternate sequences into
    tokenized tensors.

    Args:
        batch: Iterable of samples. Each sample is indexed with the keys
            'x_ref', 'x_alt' and 'y'. 'x_ref' / 'x_alt' are read as 2-D
            one-hot arrays with one row per position and columns in the
            order A, C, G, T.
        tokenizer: Callable invoked as
            ``tokenizer(list_of_str, max_length=..., truncation=True,
            padding=True, return_tensors='pt')`` whose result is indexed with
            "input_ids" and "attention_mask".
        max_length: Forwarded unchanged to the tokenizer as ``max_length``.

    Returns:
        dict with keys 'x_ref_input_ids', 'x_ref_attention_mask',
        'x_alt_input_ids', 'x_alt_attention_mask' (tokenizer outputs) and
        'y' (a tensor built from the per-sample labels).
    """
    # ASCII byte codes for the four bases. Working on uint8 codes (one byte
    # per base) is what makes the tobytes()/decode round trip below valid; a
    # 'U1' array stores 4 bytes per character and would interleave NUL
    # characters into the decoded sequence.
    nucleotide_codes = np.array(list(b'ACGT'), dtype=np.uint8)
    n_code = ord('N')

    def one_hot_to_sequence_ultra(one_hot_array):
        """Vectorised one-hot -> nucleotide string conversion."""
        one_hot_array = np.asarray(one_hot_array)

        # A proper one-hot row sums to 1. Rows whose sum deviates from 1
        # (for example all-zero rows) are emitted as 'N'.
        row_sums = one_hot_array.sum(axis=1)
        n_mask = np.abs(row_sums - 1.0) > 1e-6

        # Fancy indexing returns a fresh, writable array of byte codes.
        codes = nucleotide_codes[np.argmax(one_hot_array, axis=1)]
        codes[n_mask] = n_code

        return codes.tobytes().decode('ascii')

    x_ref_sequences = []
    x_alt_sequences = []
    y_values = []

    for item in batch:
        x_ref_sequences.append(one_hot_to_sequence_ultra(item['x_ref']))
        x_alt_sequences.append(one_hot_to_sequence_ultra(item['x_alt']))
        y_values.append(item['y'])

    # Reference and alternate sequences are tokenized in two separate calls,
    # so each side is padded to the longest sequence of its own side.
    x_ref_tokenized = tokenizer(
        x_ref_sequences,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    x_alt_tokenized = tokenizer(
        x_alt_sequences,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    # Labels may arrive as tensors / numpy scalars (anything with .item())
    # or as plain Python numbers.
    y_batch = torch.tensor([y.item() if hasattr(y, 'item') else y for y in y_values])

    return {
        'x_ref_input_ids': x_ref_tokenized["input_ids"],
        'x_ref_attention_mask': x_ref_tokenized["attention_mask"],
        'x_alt_input_ids': x_alt_tokenized["input_ids"],
        'x_alt_attention_mask': x_alt_tokenized["attention_mask"],
        'y': y_batch
    }

In [None]:

# Make HuggingFace transformers log at INFO level (this is what produces the
# config / weight-loading messages when the model is instantiated).
transformers.logging.set_verbosity_info()

# Optimisation direction per metric name: "max" means higher is better,
# "min" means lower is better.
# NOTE(review): nothing else in this notebook reads this mapping — confirm it
# is still needed.
METRICS_DIRECTION: Dict[str, str] = {
    "accuracy": "max",
    "f1_score": "max",
    "mcc": "max",
    "f1_max": "max",
    "auprc_micro": "max",
    "mse": "min",
    "mae": "min",
    "r2": "max",
    "pearson": "max",
}


def is_main_process() -> bool:
    """
    Tell whether the calling process is rank 0.

    When torch.distributed has not been initialised there is only one
    process, which counts as the main one.

    Returns:
        bool: True for rank 0 or for non-distributed runs, False otherwise
    """
    return (not dist.is_initialized()) or dist.get_rank() == 0


def dist_print(*args, **kwargs) -> None:
    """
    print() wrapper that stays silent on every process except the main one,
    so multi-GPU runs do not emit the same message once per rank.

    Args:
        *args: Positional arguments forwarded to print
        **kwargs: Keyword arguments forwarded to print
    """
    if not is_main_process():
        return
    print(*args, **kwargs)


In [None]:
def setup_tokenizer(
    model_name: str, padding_and_truncation_side: str
) -> PreTrainedTokenizer:
    """
    Build the tokenizer used to encode sequences for the model.

    The tokenizer is loaded with ``trust_remote_code=True``, pads and
    truncates on the requested side, and reuses its EOS token as the padding
    token when it has none of its own.

    Args:
        model_name: Name or path of the HuggingFace model
        padding_and_truncation_side: Side used for both padding and
            truncation (left or right)

    Returns:
        PreTrainedTokenizer: The configured tokenizer
    """
    dist_print(f"🔤 Loading tokenizer from: {model_name}")
    t0 = time.time()

    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Padding and truncation always happen on the same side.
    for side_attr in ("padding_side", "truncation_side"):
        setattr(tok, side_attr, padding_and_truncation_side)

    # Fall back to the EOS token when no padding token is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    dist_print(
        f"⏱️ Tokenizer loading completed in {time.time() - t0:.2f} seconds"
    )

    return tok

In [None]:
# Tokenizer of the GENERator checkpoint used as the model backbone below;
# padding and truncation are both applied on the right.
tokenizer = setup_tokenizer("GenerTeam/GENERator-eukaryote-1.2b-base", 'right')

In [None]:
tokenizer

In [None]:
# Re-wrap the training dataset so batches go through collate_fn (one-hot ->
# string -> token ids) instead of the loader's original collation.
# Used only for the inspection cell that follows; train_model_custom builds
# its own loaders.
train_loader2 = DataLoader(
        train_loader.dataset,
        batch_size=1,
        collate_fn=lambda b: collate_fn(b, tokenizer, max_length=450_000)
    )


In [None]:
for batch in train_loader2: 
        print(batch)
        break


In [15]:
# import torch
# import torch.nn as nn
# from transformers import AutoModel

# class LongSequenceClassificationModel(nn.Module):
#     def __init__(
#         self,
#         base_model_name: str,
#         num_labels: int = 2,
#         max_subsequence_length: int = 9375,
#         num_subsequences: int = 8,
#         gradient_checkpointing: bool = True
#     ):
#         super().__init__()
#         self.base_model = AutoModel.from_pretrained(
#             base_model_name,
#             trust_remote_code=True
#         )
#         if gradient_checkpointing:
#             self.base_model.gradient_checkpointing_enable()

#         self.max_subsequence_length = max_subsequence_length
#         self.num_subsequences = num_subsequences

#         # head projects concatenated [CLS] embeddings → logits
#         hidden_size = self.base_model.config.hidden_size * self.num_subsequences
#         self.classification_head = nn.Linear(hidden_size, num_labels, bias=False)

#     def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor):
#         """
#         Args:
#             input_ids:      [batch_size, seq_len]
#             attention_mask: [batch_size, seq_len]
#         Returns:
#             {
#               "logits": torch.FloatTensor [batch_size, num_labels],
#               "hidden_states": torch.FloatTensor [batch_size, num_subseqs * hidden_size]
#             }
#         """
#         batch_size = input_ids.size(0)
#         seq_states = []

#         # slice into chunks, run each through base_model, grab its [CLS] token
#         for i in range(self.num_subsequences):
#             start = i * self.max_subsequence_length
#             end   = (i + 1) * self.max_subsequence_length

#             chunk_ids   = input_ids[:, start:end]
#             chunk_mask  = attention_mask[:, start:end]

#             out = self.base_model(
#                 input_ids=chunk_ids,
#                 attention_mask=chunk_mask
#             )
#             # out.last_hidden_state: [B, chunk_len, hidden_size]
#             # take the final token’s embedding as “CLS”
#             cls_emb = out.last_hidden_state[:, -1, :]  # [B, hidden_size]
#             seq_states.append(cls_emb)

#         # concatenate all CLS embeddings: [B, num_subseqs * hidden_size]
#         combined_hidden = torch.cat(seq_states, dim=-1)

#         logits = self.classification_head(combined_hidden)  # [B, num_labels]

#         return {
#             "logits": logits,
#             "hidden_states": combined_hidden
#         }


In [16]:
import torch
import torch.nn as nn
from transformers import AutoModel

class EqtlSiameseModel(nn.Module):
    """
    Siamese classifier over an (alternate, reference) pair of token sequences.

    Both sequences go through the same encoder. Each sequence is cut into
    ``num_subsequences`` consecutive windows of ``max_subsequence_length``
    tokens; the hidden state at the last position of every window is kept and
    the per-window vectors are concatenated. The classifier sees
    ``[alt ; ref ; |alt - ref|]``.
    """

    def __init__(
        self,
        base_model_name: str,
        num_labels: int = 2,
        max_subsequence_length: int = 9375,
        num_subsequences: int = 8,
        gradient_checkpointing: bool = True
    ):
        """
        Args:
            base_model_name: Name or path given to AutoModel.from_pretrained
            num_labels: Number of output logits
            max_subsequence_length: Window length in tokens
            num_subsequences: Number of windows encoded per sequence
            gradient_checkpointing: Enable gradient checkpointing on the encoder
        """
        super().__init__()
        # One encoder instance, shared by both branches.
        self.encoder = AutoModel.from_pretrained(
            base_model_name, trust_remote_code=True
        )
        if gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        self.max_sub_len = max_subsequence_length
        self.num_subseqs = num_subsequences

        # Width of one encoded sequence: one hidden vector per window.
        per_sequence_width = self.encoder.config.hidden_size * self.num_subseqs

        # Three feature groups (alt, ref, |alt - ref|) feed the linear head.
        self.classification_head = nn.Linear(
            3 * per_sequence_width, num_labels, bias=False
        )

    def _encode(self, input_ids: torch.LongTensor):
        """
        Encode one sequence window by window.

        Returns the concatenation, over all windows, of the hidden state at
        each window's last position: ``[B, num_subseqs * hidden]``.
        """
        window_embeddings = []
        for window_index in range(self.num_subseqs):
            lo = window_index * self.max_sub_len
            hi = lo + self.max_sub_len
            window_ids = input_ids[:, lo:hi]

            # No mask is supplied by the caller, so every token is attended.
            encoded = self.encoder(
                input_ids=window_ids,
                attention_mask=torch.ones_like(window_ids),
            )
            # Last position stands in for a CLS token.
            window_embeddings.append(encoded.last_hidden_state[:, -1, :])

        return torch.cat(window_embeddings, dim=-1)

    def forward(
        self,
        x_alt: torch.LongTensor,   # alternate-allele token ids
        x_ref: torch.LongTensor,   # reference token ids
    ):
        """Return ``{"logits": Tensor}`` for a pair of token-id tensors."""
        emb_alt = self._encode(x_alt)
        emb_ref = self._encode(x_ref)

        pair_features = torch.cat(
            [emb_alt, emb_ref, torch.abs(emb_alt - emb_ref)], dim=-1
        )
        return {"logits": self.classification_head(pair_features)}


In [17]:
# model = EqtlSiameseModel(
#     base_model_name="GenerTeam/GENERator-eukaryote-1.2b-base",
#     num_labels=2,
#     max_subsequence_length=9375,
#     num_subsequences=8
# )

In [18]:
class LongSequenceClassificationModel(nn.Module):
    """
    Classifier for inputs longer than the backbone's context window.

    The token sequence is cut into ``num_subsequences`` consecutive chunks of
    ``max_subsequence_length`` tokens. Every chunk is run through the base
    model independently, one summary vector is taken per chunk, and the
    concatenated summaries feed a bias-free linear head.

    The summary vector of a chunk is the hidden state of its last *attended*
    token according to ``attention_mask``. For a chunk whose mask is all ones
    this is simply the last position, so unpadded inputs behave exactly as
    before; for padded chunks it avoids summarising the chunk with a padding
    position.

    NOTE: slicing does not pad, so inputs shorter than
    ``num_subsequences * max_subsequence_length`` tokens yield short or empty
    trailing chunks, and tokens beyond that length are ignored.
    """

    def __init__(self, base_model_name, num_labels=2, max_subsequence_length=9375, num_subsequences=8, gradient_checkpointing=True):
        """
        Args:
            base_model_name: Name or path given to AutoModel.from_pretrained.
            num_labels: Number of output logits.
            max_subsequence_length: Chunk length in tokens.
            num_subsequences: Number of chunks encoded per input.
            gradient_checkpointing: Enable gradient checkpointing on the
                base model.
        """
        super().__init__()
        self.base_model = AutoModel.from_pretrained(base_model_name, trust_remote_code=True)
        # One hidden vector per chunk is concatenated before the head.
        self.classification_head = nn.Linear(num_subsequences * self.base_model.config.hidden_size, num_labels, bias=False)
        if gradient_checkpointing:
            self.base_model.gradient_checkpointing_enable()
        self.max_subsequence_length = max_subsequence_length
        self.num_subsequences = num_subsequences

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids: Token ids, indexed as ``[batch, sequence]``.
            attention_mask: Mask with the same layout; non-zero marks tokens
                to attend to.

        Returns:
            dict: ``{"logits": Tensor}`` with one row of ``num_labels`` logits
            per batch element.
        """
        hidden_states = []

        for i in range(self.num_subsequences):
            start_idx = i * self.max_subsequence_length
            end_idx = start_idx + self.max_subsequence_length
            sub_input_ids = input_ids[:, start_idx:end_idx]
            sub_attention_mask = attention_mask[:, start_idx:end_idx]

            outputs = self.base_model(input_ids=sub_input_ids, attention_mask=sub_attention_mask)
            last_hidden_state = outputs.last_hidden_state

            # Position of the last attended token in each row. argmax returns
            # the first maximum, so on the reversed mask it finds the last 1;
            # all-ones and all-zeros masks both resolve to the final position.
            chunk_len = sub_attention_mask.size(1)
            reversed_mask = torch.flip((sub_attention_mask != 0).long(), dims=[1])
            last_attended = chunk_len - 1 - reversed_mask.argmax(dim=1)

            row_index = torch.arange(last_hidden_state.size(0), device=last_hidden_state.device)
            cls_embedding = last_hidden_state[row_index, last_attended.to(last_hidden_state.device)]
            hidden_states.append(cls_embedding)

        combined_hidden_states = torch.cat(hidden_states, dim=-1)
        logits = self.classification_head(combined_hidden_states)

        return {"logits": logits}


In [19]:
# Binary classifier on top of the GENERator backbone: each input is processed
# as 8 chunks of 9,375 tokens, i.e. the first 75,000 tokens are used.
# NOTE(review): presumably sized for 450 kb inputs at ~6 bp per token — confirm
# against the tokenizer.
model = LongSequenceClassificationModel(
    base_model_name="GenerTeam/GENERator-eukaryote-1.2b-base",
    num_labels=2,
    max_subsequence_length=9375,
    num_subsequences=8
)

loading configuration file config.json from cache at /home/wenduoc/.cache/huggingface/hub/models--GenerTeam--GENERator-eukaryote-1.2b-base/snapshots/3be4abf390afbb7f4d8ccb3370f599338523f1cd/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 26,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3",
  "use_cache": true,
  "vocab_size": 4128
}

loading weights file model.safetensors from cache at /home/wenduoc/.cache/huggingface/hub/models-

In [2]:
def count_parameters(model):
    """
    Print the total and trainable parameter counts of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with
            ``numel()`` and ``requires_grad`` (e.g. a torch.nn.Module).

    Returns:
        None. Three lines are printed: total count, trainable count and the
        trainable share in percent (reported as 0.00% for a model without
        parameters instead of dividing by zero).
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters for both counters.
    for p in model.parameters():
        n = p.numel()
        total_params += n
        if p.requires_grad:
            trainable_params += n

    trainable_pct = 100 * trainable_params / total_params if total_params else 0.0

    print(f"Total parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable %:          {trainable_pct:.2f}%")


In [21]:
# Cast the whole model (backbone and classification head) to bfloat16, then
# move it to the selected device.
model=model.to(torch.bfloat16).to(device)

In [22]:
model

LongSequenceClassificationModel(
  (base_model): LlamaModel(
    (embed_tokens): Embedding(4128, 2048)
    (layers): ModuleList(
      (0-25): 26 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)

In [1]:
count_parameters(model)

NameError: name 'count_parameters' is not defined

In [23]:
count_parameters(model)

Total parameters:     1,153,640,448
Trainable parameters: 1,153,640,448
Trainable %:          100.00%


In [24]:
# Freeze the entire backbone first ...
for param in model.base_model.parameters():
    param.requires_grad = False

# ... then unfreeze only the last 6 LLaMA decoder layers.
# The classification head is not part of base_model, so it stays trainable.
for layer in model.base_model.layers[-6:]:
    for param in layer.parameters():
        param.requires_grad = True


In [25]:
count_parameters(model)

Total parameters:     1,153,640,448
Trainable parameters: 264,298,496
Trainable %:          22.91%


In [26]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, Any, Optional, Callable
import time
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import os

In [27]:
import os
import glob
import time
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score
)

def _latest_checkpoint(save_dir: str):
    """Return the path of the checkpoint-step-<N>.pt file with the largest N, or None."""
    prefix, suffix = "checkpoint-step-", ".pt"
    best_step, best_path = -1, None
    for path in glob.glob(os.path.join(save_dir, f"{prefix}*{suffix}")):
        step_text = os.path.basename(path)[len(prefix):-len(suffix)]
        # Compare numerically: a plain string sort ranks "step-80" after "step-120".
        if step_text.isdigit() and int(step_text) > best_step:
            best_step, best_path = int(step_text), path
    return best_path


def train_model_custom(
    model:       torch.nn.Module,
    tokenizer,  # forwarded to collate_fn when the loaders are rebuilt
    train_loader,
    val_loader,
    test_loader=None,
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    weight_decay:  float = 0.01,
    warmup_steps:  int = 0,
    max_grad_norm: float = 1.0,
    save_dir:     str = "/work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/EQTL/altseq",
    eval_steps:   int = 40,
    device:       str = "cuda",
    gradient_accumulation_steps: int = 1,
) -> dict:
    """
    Fine-tune ``model`` on the alternate-allele sequences with cross-entropy
    loss, evaluating and checkpointing along the way.

    Only the ``.dataset`` of each incoming loader is used: the loaders are
    rebuilt with ``batch_size=1`` and ``collate_fn`` so that batches contain
    token ids. Training resumes automatically from the highest-numbered
    ``checkpoint-step-<N>.pt`` found in ``save_dir``.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)`` and
            expected to return a mapping with a "logits" entry.
        tokenizer: Tokenizer handed to ``collate_fn``.
        train_loader, val_loader, test_loader: Loaders whose ``.dataset`` is
            re-wrapped; ``test_loader`` may be None to skip the final test.
        num_epochs: Total number of epochs (including epochs already covered
            by a resumed checkpoint).
        learning_rate, weight_decay: AdamW hyper-parameters.
        warmup_steps: Optimizer steps over which the learning rate ramps
            linearly from 10% to 100%. With 0 there is no warm-up and the full
            learning rate is used from the first step.
        max_grad_norm: Gradient-norm clipping threshold.
        save_dir: Directory for ``checkpoint-step-<N>.pt`` and ``best_model.pt``.
        eval_steps: Evaluate and checkpoint every this many optimizer steps
            (falsy disables step-level evaluation).
        device: Device the model and batches are moved to.
        gradient_accumulation_steps: Micro-batches per optimizer step.

    Returns:
        dict with 'training_history', 'best_val_auroc' and, when a test
        loader is given, 'test_metrics' (computed with the final weights, not
        the best checkpoint).
    """
    os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)

    # Rebuild the loaders so every batch goes through collate_fn.
    collate = lambda b: collate_fn(b, tokenizer)
    train_loader = DataLoader(train_loader.dataset, batch_size=1, collate_fn=collate)
    val_loader   = DataLoader(val_loader.dataset,   batch_size=1, collate_fn=collate)
    if test_loader is not None:
        test_loader = DataLoader(test_loader.dataset, batch_size=1, collate_fn=collate)

    # ——— set up optimizer, scheduler, loss fn
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    if warmup_steps > 0:
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps
        )
    else:
        # LinearLR applies start_factor at construction and, with
        # total_iters=0, never scales the rate back up; that would pin the
        # learning rate at 10% of the requested value. Use a factor of 1.0
        # so "no warm-up" really means the full learning rate.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=1.0, total_iters=1
        )
    criterion = nn.CrossEntropyLoss()

    # ——— resume if possible
    best_auroc = 0.0
    latest = _latest_checkpoint(save_dir)
    if latest is not None:
        print(f"⏳ Resuming from {latest}")
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # written by this function.
        chk = torch.load(latest, map_location=device)
        model.load_state_dict(chk["model_state_dict"])
        optimizer.load_state_dict(chk["optimizer_state_dict"])
        scheduler.load_state_dict(chk["scheduler_state_dict"])
        start_epoch = chk["epoch"]   # that epoch is restarted from its first batch
        global_step = chk["step"]
        # Older checkpoints do not carry the best score; fall back to 0.
        best_auroc = chk.get("best_auroc", 0.0)
    else:
        start_epoch = 0
        global_step = 0

    history = {
        'train_loss': [], 'val_loss_steps': [], 'val_auroc_steps': [],
        'epoch_val_loss': [], 'epoch_val_auroc': [], 'learning_rates': []
    }

    print(f"🚀 Training for {num_epochs} epochs (resume at epoch {start_epoch+1}), step‐eval every {eval_steps} steps.")

    start_time = time.time()
    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")
        # NOTE: when the number of batches is not a multiple of
        # gradient_accumulation_steps, the trailing micro-batch gradients stay
        # accumulated and are applied with the next epoch's first update.
        for step, batch in enumerate(tqdm(train_loader, desc="train"), start=1):

            input_ids      = batch['x_alt_input_ids'].to(device)
            attention_mask = batch['x_alt_attention_mask'].to(device)
            labels         = batch['y'].view(-1).long().to(device)

            # Same behaviour as the deprecated torch.cuda.amp.autocast().
            with torch.autocast(device_type="cuda"):
                logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
                loss   = criterion(logits, labels) / gradient_accumulation_steps

            loss.backward()
            if step % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # ——— Step‐level eval & checkpoint
                if eval_steps and global_step % eval_steps == 0:
                    print(f"\n🔄 Step {global_step} eval…")
                    vm = evaluate_model_custom(model, val_loader, device)
                    # Evaluation switches the model to eval mode; switch back
                    # so the rest of the epoch trains in train mode.
                    model.train()
                    loss_s, auroc_s = vm['loss'], vm['auroc']
                    history['val_loss_steps'].append(loss_s)
                    history['val_auroc_steps'].append(auroc_s)
                    print(f"  AUROC {auroc_s:.4f} | Loss {loss_s:.4f}")

                    # save regular checkpoint
                    ckpt_path = os.path.join(save_dir, f"checkpoint-step-{global_step}.pt")
                    torch.save({
                        "epoch": epoch,
                        "step": global_step,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        # stored so a resumed run does not overwrite a better best_model.pt
                        "best_auroc": max(best_auroc, auroc_s),
                    }, ckpt_path)
                    print(f"💾 Saved checkpoint: {ckpt_path}")

                    # update best
                    if auroc_s > best_auroc:
                        best_auroc = auroc_s
                        best_path = os.path.join(save_dir, "best_model.pt")
                        torch.save(model.state_dict(), best_path)
                        print(f"🏆 New best at step {global_step}: {best_path}")

            # undo the accumulation scaling so the logged loss is per-sample
            epoch_loss += loss.item() * gradient_accumulation_steps
            num_batches += 1

        # ——— record train stats
        history['train_loss'].append(epoch_loss / num_batches)
        history['learning_rates'].append(scheduler.get_last_lr()[0])

        # ——— Epoch‐level eval & checkpoint
        print(f"\n🔄 Epoch {epoch+1} eval…")
        vm = evaluate_model_custom(model, val_loader, device)
        loss_e, auroc_e = vm['loss'], vm['auroc']
        history['epoch_val_loss'].append(loss_e)
        history['epoch_val_auroc'].append(auroc_e)
        print(f"  AUROC {auroc_e:.4f} | Loss {loss_e:.4f}")

        if auroc_e > best_auroc:
            best_auroc = auroc_e
            best_path = os.path.join(save_dir, "best_model.pt")
            torch.save(model.state_dict(), best_path)
            print(f"🏆 New best at epoch {epoch+1}: {best_path}")

    elapsed = (time.time() - start_time) / 60
    print(f"\n✅ Done in {elapsed:.2f} min – best AUROC {best_auroc:.4f}")

    results = {'training_history': history, 'best_val_auroc': best_auroc}

    if test_loader is not None:
        print("\n🧪 Final test eval…")
        tm = evaluate_model_custom(model, test_loader, device)
        results['test_metrics'] = tm
        print(f"  Test AUROC {tm['auroc']:.4f} | Loss {tm['loss']:.4f}")

    return results


def evaluate_model_custom(model, data_loader, device: str) -> dict:
    """
    Score ``model`` on every batch of ``data_loader`` using the
    alternate-allele inputs.

    The model is switched to eval mode (and left there). Per-batch
    cross-entropy losses are averaged over batches; class-1 softmax
    probabilities are thresholded at 0.5 for the hard predictions.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; its
            result is indexed with 'logits'.
        data_loader: Yields dicts with 'x_alt_input_ids',
            'x_alt_attention_mask' and 'y'.
        device: Device the batch tensors are moved to.

    Returns:
        dict with 'loss', 'accuracy', 'precision', 'recall', 'f1', 'auroc',
        'auprc' and 'num_samples'.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    running_loss = 0.0
    n_batches = 0
    y_true, y_pred, y_score = [], [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            ids     = batch['x_alt_input_ids'].to(device)
            mask    = batch['x_alt_attention_mask'].to(device)
            targets = batch['y'].view(-1).long().to(device)

            logits = model(input_ids=ids, attention_mask=mask)['logits']
            running_loss += criterion(logits, targets).item()

            # probability of the positive class, computed in float32
            positive_prob = torch.softmax(logits.float(), dim=-1)[:, 1].cpu().numpy()

            y_true.extend(targets.cpu().numpy())
            y_pred.extend((positive_prob > 0.5).astype(int))
            y_score.extend(positive_prob)
            n_batches += 1

    y_true = np.array(y_true)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=0
    )

    metrics = {'loss': running_loss / n_batches}
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision'] = precision
    metrics['recall'] = recall
    metrics['f1'] = f1
    metrics['auroc'] = roc_auc_score(y_true, y_score)
    metrics['auprc'] = average_precision_score(y_true, y_score)
    metrics['num_samples'] = len(y_true)
    return metrics


In [28]:
# def train_model_custom(
#     model: nn.Module,
#     tokenizer: PreTrainedTokenizer,       # still used by collate_fn
#     train_loader: DataLoader,
#     val_loader: DataLoader,
#     test_loader: Optional[DataLoader] = None,
#     num_epochs: int = 10,
#     learning_rate: float = 1e-4,
#     weight_decay: float = 0.01,
#     warmup_steps: int = 0,
#     max_grad_norm: float = 1.0,
#     save_dir: str = "/work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/EQTL/altseq",
#     eval_steps: int = 10,
#     device: str = "cuda",
#     gradient_accumulation_steps: int = 1,
# ) -> Dict[str, Any]:
#     model = model.to(device)
#     model.train()

#     # wrap datasets with your collate_fn (returns x_alt, x_ref, y)
#     train_loader = DataLoader(train_loader.dataset, batch_size=1,
#                               collate_fn=lambda b: collate_fn(b, tokenizer))
#     val_loader   = DataLoader(val_loader.dataset,   batch_size=1,
#                               collate_fn=lambda b: collate_fn(b, tokenizer))
#     if test_loader is not None:
#         test_loader = DataLoader(test_loader.dataset, batch_size=1,
#                                   collate_fn=lambda b: collate_fn(b, tokenizer))

#     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
#     scheduler = torch.optim.lr_scheduler.LinearLR(
#         optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps
#     )
#     criterion = nn.CrossEntropyLoss()

#     best_auroc   = 0.0
#     global_step  = 0
#     best_ckpt    = os.path.join(save_dir, "best_model.pt")
#     history = {
#         'train_loss': [],
#         'val_loss_steps': [],
#         'val_auroc_steps': [],
#         'epoch_val_loss': [],
#         'epoch_val_auroc': [],
#         'learning_rates': [],
#     }

#     print(f"🚀 Training for {num_epochs} epochs, step‐eval every {eval_steps} steps, epoch‐eval each epoch.")
#     start_time = time.time()

#     for epoch in range(num_epochs):
#         print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")
#         model.train()
#         epoch_loss   = 0.0
#         num_batches  = 0

#         for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
#             alt_ids = batch['x_alt_input_ids'].to(device)
#             ref_ids = batch['x_ref_input_ids'].to(device)
#             labels  = batch['y'].long().to(device)

#             with torch.cuda.amp.autocast():
#                 outputs = model(alt_ids, ref_ids)
#                 logits  = outputs['logits']
#                 loss    = criterion(logits, labels) / gradient_accumulation_steps

#             loss.backward()
#             if (step + 1) % gradient_accumulation_steps == 0:
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#                 optimizer.step()
#                 scheduler.step()
#                 optimizer.zero_grad()
#                 global_step += 1

#                 # — Step‐level eval & checkpoint —
#                 if eval_steps and global_step % eval_steps == 0:
#                     print(f"\n🔄 Step {global_step} eval…")
#                     vm = evaluate_eqtl_model(model, val_loader, device, criterion)
#                     auroc_s, loss_s = vm['auc'], vm['loss']
#                     history['val_auroc_steps'].append(auroc_s)
#                     history['val_loss_steps'].append(loss_s)
#                     print(f"  AUROC {auroc_s:.4f} | Loss {loss_s:.4f}")

#                     if auroc_s > best_auroc:
#                         best_auroc = auroc_s
#                         torch.save(model.state_dict(), best_ckpt)
#                         print(f"🏆 New best at step {global_step}: {best_ckpt}")

#             epoch_loss  += loss.item() * gradient_accumulation_steps
#             num_batches += 1

#         # record train stats
#         history['train_loss'].append(epoch_loss / num_batches)
#         history['learning_rates'].append(scheduler.get_last_lr()[0])

#         # — Epoch‐level eval & checkpoint —
#         print(f"\n🔄 Epoch {epoch+1} eval…")
#         vm = evaluate_eqtl_model(model, val_loader, device, criterion)
#         auroc_e, loss_e = vm['auc'], vm['loss']
#         history['epoch_val_auroc'].append(auroc_e)
#         history['epoch_val_loss'].append(loss_e)
#         print(f"  AUROC {auroc_e:.4f} | Loss {loss_e:.4f}")

#         if auroc_e > best_auroc:
#             best_auroc = auroc_e
#             torch.save(model.state_dict(), best_ckpt)
#             print(f"🏆 New best at epoch {epoch+1}: {best_ckpt}")

#     elapsed = (time.time() - start_time) / 60
#     print(f"\n✅ Done in {elapsed:.2f}min – best AUROC {best_auroc:.4f}")

#     results = {
#         'training_history': history,
#         'best_val_auroc': best_auroc,
#     }

#     if test_loader is not None:
#         print("\n🧪 Final test eval…")
#         tm = evaluate_eqtl_model(model, test_loader, device)
#         results['test_metrics'] = tm
#         print(f"  Test AUROC {tm['auc']:.4f} | Loss {tm['loss']:.4f}")

#     return results






In [29]:
# from tqdm import tqdm
# import torch
# import numpy as np
# from sklearn.metrics import (
#     accuracy_score,
#     precision_recall_fscore_support,
#     roc_auc_score,
#     average_precision_score
# )

# def evaluate_eqtl_model(
#     model: nn.Module,
#     data_loader: DataLoader,
#     device: str,
#     criterion: nn.Module
# ) -> Dict[str, float]:
#     """
#     Evaluate a Siamese eQTL classification model on a dataset.
#     Assumes model(batch['x_alt'], batch['x_ref']) → {'logits': Tensor[B,2]}.
    
#     Args:
#         model:        Siamese eQTL model
#         data_loader:  DataLoader yielding {'x_alt', 'x_ref', 'y'}
#         device:       torch device
#         criterion:    loss function (e.g. CrossEntropyLoss())
    
#     Returns:
#         dict with loss, accuracy, precision, recall, f1, auc, auprc, num_samples
#     """
#     model.eval()
#     total_loss = 0.0
#     all_preds = []
#     all_labels = []
#     all_probs = []
#     batches = 0

#     with torch.no_grad():
#         for batch in tqdm(data_loader, desc="Evaluating"):
#             x_alt = batch['x_alt_input_ids'].to(device)
#             x_ref = batch['x_ref_input_ids'].to(device)
#             labels = batch['y'].long().to(device)

#             # forward pass
#             outputs = model(x_alt, x_ref)
#             logits = outputs['logits']      # [B, 2]
#             loss = criterion(logits, labels)
#             total_loss += loss.item()

#             # probabilities for class=1
#             probs = torch.softmax(logits.float(), dim=1)[:, 1].cpu().numpy()
#             preds = (probs > 0.5).astype(int)

#             all_labels.extend(labels.cpu().numpy())
#             all_preds.extend(preds)
#             all_probs.extend(probs)
#             batches += 1

#     # average loss
#     avg_loss = total_loss / batches

#     # convert to numpy arrays
#     all_labels = np.array(all_labels)
#     all_preds  = np.array(all_preds)
#     all_probs  = np.array(all_probs)

#     # compute metrics
#     accuracy = accuracy_score(all_labels, all_preds)
#     precision, recall, f1, _ = precision_recall_fscore_support(
#         all_labels, all_preds, average='binary', zero_division=0
#     )
#     auc   = roc_auc_score(all_labels, all_probs)
#     auprc = average_precision_score(all_labels, all_probs)

#     return {
#         'loss': avg_loss,
#         'accuracy': accuracy,
#         'precision': precision,
#         'recall': recall,
#         'f1': f1,
#         'auc': auc,
#         'auprc': auprc,
#         'num_samples': len(all_labels)
#     }


In [None]:


# Train the model.
# train_model_custom rebuilds the loaders with batch_size=1, so the effective
# batch size is set by gradient_accumulation_steps alone.
training_results = train_model_custom(
    model=model,
    tokenizer=tokenizer,
    train_loader=train_loader,
    val_loader=valid_loader,
    test_loader=test_loader,
    num_epochs=3,
    learning_rate=1e-5,
    # batch_size=1,  # not a parameter of train_model_custom; loaders are fixed at batch_size=1
    device=device,
    gradient_accumulation_steps=16,  # Effective batch size = 1 * 16 = 16
)



⏳ Resuming from /work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/EQTL/altseq/checkpoint-step-40.pt
🚀 Training for 5 epochs (resume at epoch 1), step‐eval every 40 steps.

===== Epoch 1/5 =====


  with torch.cuda.amp.autocast():
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
train: 393it [2:38:01, 24.05s/it]

In [None]:
# Final evaluation of the current model weights on the test set.
final_metrics = {}

# Re-wrap the test dataset so batches are tokenized by collate_fn.
test_loader_custom = DataLoader(
            test_loader.dataset,
            batch_size=1,
            collate_fn=lambda b: collate_fn(b, tokenizer)
        )

criterion = nn.CrossEntropyLoss()

print("\n🧪 Evaluating on test set...")
# evaluate_model_custom matches the single-sequence model trained above
# (model(input_ids=..., attention_mask=...)) and builds its own loss function.
# evaluate_eqtl_model only exists as commented-out code for the Siamese
# variant, so calling it here raised NameError.
test_metrics = evaluate_model_custom(model, test_loader_custom, device)
final_metrics['test_metrics'] = test_metrics

print("📊 Final Test Metrics:")
for key, value in test_metrics.items():
    print(f"  {key}: {value:.4f}")

print(final_metrics)