In [1]:
! nvidia-smi

Wed Jun 11 16:49:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:52:00.0 Off |                    0 |
| N/A   29C    P0             45W /  270W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import argparse
import os
import time
from typing import Dict, Tuple, Union, Optional, Callable, List, Any
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import torch
import torch.distributed as dist
import transformers
import yaml
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    IterableDatasetDict,
    IterableDataset,
)
from datasets import Dataset as HFDataset, DatasetDict
from sklearn.metrics import f1_score, matthews_corrcoef
from sklearn.model_selection import KFold
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding,
    PreTrainedModel,
    AutoConfig,
)


import dnalongbench
from dnalongbench.utils import load_data

  warn(
2025-06-11 16:49:12.963285: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-11 16:49:12.975864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-11 16:49:12.991997: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-11 16:49:12.996764: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-11 16:49:13.008016: I tensorflow/core/platform/cpu_feat

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device 

device(type='cuda')

# Data

In [4]:
root = '/work/magroup/shared/DNA_LLM/DNALongBench/'

In [None]:
train_loader, valid_loader, test_loader = load_data(root = root, task_name = 'enhancer_target_gene_prediction', organism = None, cell_type = None, batch_size = 1)

> load config done
> init fasta extractor done
> Start parsing EPI records to build the dataset train


100%|██████████| 2602/2602 [00:18<00:00, 142.43it/s]


# Finish parsing EPI records
# Total records:  2602
# Skipped records due to different chromosomes:  0
# Skipped records due to distance cutoff:  0
# Skipped records due to unknown strand:  0
# Select records 2066 with subset train 
> load config done
> init fasta extractor done
> Start parsing EPI records to build the dataset valid


 14%|█▍        | 365/2602 [00:00<00:02, 828.56it/s] 

In [None]:
for batch in train_loader: 
        x, y = batch
        print('x:',x.size())
        print('y:',y.size())
        break


In [None]:
def collate_fn(batch, tokenizer, max_length=450000):
    """
    Custom collate function for DNA data that converts one-hot encoded sequences to raw sequences
    and tokenizes them.
    
    Args:
        batch: List of tuples where each tuple is (x, y)
               x is one-hot encoded DNA sequence of shape (seq_len, 4)
               y is gene expression data of shape (10, seq_len)
        tokenizer: The GENERator tokenizer
        max_length: Maximum sequence length for tokenization
    
    Returns:
        Dictionary with tokenized inputs and original gene expression data
    """
    # Separate x and y from the batch
    x_batch, y_batch = zip(*batch)
    
    # Convert one-hot encoded sequences to raw sequences
    raw_sequences = []
    nucleotides = ['A', 'C', 'G', 'T']
    for one_hot_seq in x_batch:
        # Ensure one_hot_seq is a PyTorch tensor
        if not isinstance(one_hot_seq, torch.Tensor):
            one_hot_seq = torch.tensor(one_hot_seq)
        
        # Get indices of 1s in one-hot encoding (argmax along axis 1)
        indices = torch.argmax(one_hot_seq, dim=1).cpu().numpy()
        
        # Convert indices to nucleotides
        raw_seq = ''.join([nucleotides[idx] for idx in indices])
        raw_sequences.append(raw_seq)
    
    # Tokenize the raw sequences
    tokenizer.padding_side = "right"
    inputs = tokenizer(
        raw_sequences,
        add_special_tokens=True,
        return_tensors="pt",
        padding=False,
        truncation=True,
        # max_length=max_length
    )

    
    # Convert y arrays to tensors and stack them
    y_tensors = []
    for y in y_batch:
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float32)
        y_tensors.append(y)
    
    y_stacked = torch.stack(y_tensors)
    
    # Return tokenized inputs and original y
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "y": y_stacked
    }

In [None]:

# Set logging level for transformers
transformers.logging.set_verbosity_info()

# Define optimization direction for each metric (whether higher or lower is better)
METRICS_DIRECTION: Dict[str, str] = {
    "accuracy": "max",
    "f1_score": "max",
    "mcc": "max",
    "f1_max": "max",
    "auprc_micro": "max",
    "mse": "min",
    "mae": "min",
    "r2": "max",
    "pearson": "max",
}


def is_main_process() -> bool:
    """
    Check if current process is the main process (rank 0) in distributed training.

    Returns:
        bool: True if this is the main process, False otherwise
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def dist_print(*args, **kwargs) -> None:
    """
    Print only from the main process (rank 0) in distributed training.
    Prevents duplicate outputs in multi-GPU settings.

    Args:
        *args: Arguments to pass to print function
        **kwargs: Keyword arguments to pass to print function
    """
    if is_main_process():
        print(*args, **kwargs)


In [None]:
def setup_tokenizer(
    model_name: str, padding_and_truncation_side: str
) -> PreTrainedTokenizer:
    """
    Load and configure tokenizer for sequence understanding.

    Args:
        model_name: Name or path of the HuggingFace model
        padding_and_truncation_side: Side for padding and truncation (left or right)

    Returns:
        PreTrainedTokenizer: Configured tokenizer for the model
    """
    dist_print(f"🔤 Loading tokenizer from: {model_name}")
    start_time = time.time()

    # Load tokenizer with trust_remote_code to support custom models
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Configure padding and truncation settings
    tokenizer.padding_side = padding_and_truncation_side
    tokenizer.truncation_side = padding_and_truncation_side

    # Set pad_token to eos_token if not defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dist_print(
        f"⏱️ Tokenizer loading completed in {time.time() - start_time:.2f} seconds"
    )

    return tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained("GenerTeam/GENERator-eukaryote-1.2b-base", trust_remote_code=True) # "GenerTeam/GENERator-eukaryote-3b-base"

In [None]:
tokenizer

In [None]:
train_loader2 = DataLoader(
        train_loader.dataset,
        batch_size=1,
        collate_fn=lambda b: collate_fn(b, tokenizer, max_length=450_000)
    )


In [None]:
for batch in train_loader2: 
        print(batch)
        break


In [None]:
batch['input_ids'].shape, batch['attention_mask'].shape

# Model

In [None]:
def setup_model(
    model_name: str,
    problem_type: str,
    num_labels: int,
    max_length: Optional[int] = 16384,
    length_extension_mode: Optional[str] = None,
) -> PreTrainedModel:
    """
    Load and configure model for sequence understanding.

    Args:
        model_name: Name or path of the HuggingFace model
        problem_type: Type of problem
        num_labels: Number of labels for the task
        length_extension_mode: Mode for handling sequences longer than 16384 * 1.05 (if applicable)

    Returns:
        PreTrainedModel: Configured pre-trained model for sequence classification
    """
    dist_print(
        f"🤗 Loading AutoModelForSequenceClassification from: {model_name} with {num_labels} labels"
    )
    start_time = time.time()

    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type=problem_type,
        trust_remote_code=True,
    )
    attn_implementation = "sdpa"

    # Apply length extension configurations if max_length > 16384
    original_model_max_length_for_scaling = 16384.0  # Using float for division

    if max_length > original_model_max_length_for_scaling * 1.05:
        dist_print(
            f"⚡️ Max_length ({max_length}) > {int(original_model_max_length_for_scaling)}. Enabling length extension mode: {length_extension_mode}"
        )

        if (
            hasattr(config, "max_position_embeddings")
            and config.max_position_embeddings < max_length
        ):
            dist_print(
                f"   Updating model config's max_position_embeddings from {config.max_position_embeddings} to {max_length}"
            )
            config.max_position_embeddings = max_length

        if length_extension_mode == "yarn_rope_scaling":
            # Calculate rope_scaling_factor based on args.max_length and the fixed original_model_max_length_for_scaling
            rope_scaling_factor = max_length / original_model_max_length_for_scaling
            # original_max_position_embeddings for YaRN config is fixed to 16384
            yarn_original_max_pos_embed = int(original_model_max_length_for_scaling)

            rope_config = {
                "type": "yarn",
                "factor": rope_scaling_factor,
                "original_max_position_embeddings": yarn_original_max_pos_embed,
            }
            config.rope_scaling = rope_config
            dist_print(
                f"✅ Applied YaRN RoPE Scaling with calculated factor: {rope_scaling_factor:.4f}, "
                f"original_max_position_embeddings: {yarn_original_max_pos_embed}"
            )

        elif length_extension_mode == "sliding_window":
            # Check if config already had sliding_window before our patch
            had_sliding_before = hasattr(config, "sliding_window")
            # sliding_window_size is fixed to 16384
            config.sliding_window = 10000 # int(original_model_max_length_for_scaling)

            # Llama-specific monkey-patch
            if getattr(config, "model_type", None) == "llama":
                import transformers
                from liger_kernel.transformers import apply_liger_kernel_to_llama
                from transformers.models.llama.modeling_llama import LlamaAttention

                apply_liger_kernel_to_llama()
                _orig_forward = LlamaAttention.forward

                def _sliding_llama_forward(
                    self,
                    hidden_states,
                    position_embeddings,
                    attention_mask=None,
                    past_key_value=None,
                    cache_position=None,
                    **kwargs,
                ):
                    # inject sliding_window into attention kwargs
                    kwargs["sliding_window"] = self.config.sliding_window
                    return _orig_forward(
                        self,
                        hidden_states,
                        position_embeddings,
                        attention_mask,
                        past_key_value,
                        cache_position,
                        **kwargs,
                    )

                LlamaAttention.forward = _sliding_llama_forward
                dist_print(
                    "🪄 Monkey-patched LlamaAttention to support sliding windows"
                )

            else:
                # for other models, warn if they did not declare sliding_window originally
                if not had_sliding_before:
                    dist_print(
                        f"⚠️ Model type '{getattr(config, 'model_type', 'unknown')}' "
                        "did not originally have `sliding_window` support in its config. "
                        "Please verify that its attention implementation can handle sliding windows."
                    )

            # Set the attention implementation to flash_attention_2 to ensure compatibility with sliding windows
            attn_implementation = "flash_attention_2"
            dist_print(f"✅ Applied Sliding Windows with size: {config.sliding_window}")

        elif length_extension_mode == "none":
            dist_print(
                "   Length extension mode is 'none'. No specific scaling or windowing technique applied from script beyond setting max_length."
            )

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        trust_remote_code=True,
        attn_implementation=attn_implementation,
    )

    # Ensure pad_token_id is set
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    # Report model size for reference
    total_params = sum(p.numel() for p in model.parameters())
    dist_print(f"📊 Model size: {total_params / 1e6:.1f}M parameters")
    dist_print(f"⏱️ Model loading completed in {time.time() - start_time:.2f} seconds")

    return model


In [None]:
model = setup_model(model_name="GenerTeam/GENERator-eukaryote-1.2b-base",problem_type='single_label_classification',num_labels=1, max_length=75002, length_extension_mode="sliding_window")

In [None]:
model=model.to(torch.bfloat16).to(device)

In [None]:
model.gradient_checkpointing_enable()

In [None]:
model

In [None]:
# model.eval()
# with torch.no_grad():
#     output = model(batch['input_ids'].to(device))
    

# Train

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, Any, Optional, Callable
import time
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import os
from tqdm import tqdm

In [None]:
def train_model_custom(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: Optional[DataLoader] = None,
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 0,
    max_grad_norm: float = 1.0,
    save_dir: str = "/work/magroup/wenduoc/DNALongBench/experiments/GENERator/results/ETGP/v2",
    save_steps: int = 1000,
    early_stopping_patience: int = 5,
    device: str = "cuda",
    use_wandb: bool = False,
    gradient_accumulation_steps: int = 1,
    max_length: int = 75002,
) -> Dict[str, Any]:
 
    model = model.to(device)
    model.train()

    train_loader_custom = DataLoader(
        train_loader.dataset,
        batch_size=1,
        collate_fn=lambda b: collate_fn(b, tokenizer)
    )

    val_loader_custom = DataLoader(
        val_loader.dataset,
        batch_size=1,
        collate_fn=lambda b: collate_fn(b, tokenizer)
    )

    if test_loader is not None:
        test_loader_custom = DataLoader(
            test_loader.dataset,
            batch_size=1,
            collate_fn=lambda b: collate_fn(b, tokenizer)
        )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_steps
    )

    criterion = nn.BCEWithLogitsLoss()

    best_val_loss = float('inf')
    patience_counter = 0
    global_step = 0
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rates': []
    }

    print(f"🚀 Starting training for {num_epochs} epochs...")
    print(f"🔧 Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"📊 Evaluation will occur after each epoch")
    print(f"💾 Model will be saved based on lowest validation loss")

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"{'='*50}")

        model.train()
        epoch_train_loss = 0.0
        train_steps = 0

        progress_bar = tqdm(train_loader_custom, desc=f"Epoch {epoch + 1}")

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['y'].to(device)

            with torch.cuda.amp.autocast():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                if labels.dim() > 1:
                    labels = labels.view(-1).float()

                loss = criterion(logits.view(-1), labels)
                loss = loss / gradient_accumulation_steps

            loss.backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            epoch_train_loss += loss.item() * gradient_accumulation_steps
            train_steps += 1

            progress_bar.set_postfix({
                'loss': f"{(loss.item() * gradient_accumulation_steps):.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}"
            })

            if global_step % save_steps == 0:
                checkpoint_path = os.path.join(save_dir, f'checkpoint_step_{global_step}.pt')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'global_step': global_step
                }, checkpoint_path)
                print(f"💾 Intermediate checkpoint saved at step {global_step}")

        avg_train_loss = epoch_train_loss / train_steps
        training_history['train_loss'].append(avg_train_loss)
        training_history['learning_rates'].append(scheduler.get_last_lr()[0])

        print(f"\n🔄 Evaluating after epoch {epoch + 1}...")
        val_metrics = evaluate_model_custom(model, val_loader_custom, device, criterion)
        training_history['val_loss'].append(val_metrics['loss'])

        print(f"\n📈 Epoch {epoch + 1} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f}")
        print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.2e}")

        # Save model based on lowest validation loss
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'global_step': global_step
            }, os.path.join(save_dir, 'best_model.pt'))

            print(f"💾 New best model saved! Val loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            print(f"⏳ No improvement in loss. Patience: {patience_counter}/{early_stopping_patience}")

        epoch_checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch + 1}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_metrics['loss'],
            'global_step': global_step
        }, epoch_checkpoint_path)
        print(f"💾 Epoch {epoch + 1} checkpoint saved")

        if patience_counter >= early_stopping_patience:
            print(f"🛑 Early stopping triggered after {patience_counter} epochs without loss improvement")
            break

        model.train()

    total_time = time.time() - start_time
    print(f"\n✅ Training completed in {total_time/60:.2f} minutes")
    print(f"🏆 Best validation loss achieved: {best_val_loss:.4f}")

    final_metrics = {'training_history': training_history, 'best_val_loss': best_val_loss}
    if test_loader is not None:
        print("\n🧪 Evaluating on test set...")
        test_metrics = evaluate_model_custom(model, test_loader_custom, device, criterion)
        final_metrics['test_metrics'] = test_metrics

        print("📊 Final Test Metrics:")
        for key, value in test_metrics.items():
            print(f"  {key}: {value:.4f}")

    return final_metrics


In [None]:
from tqdm import tqdm
import torch
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score
)

def evaluate_model_custom(
    model: PreTrainedModel,
    data_loader: DataLoader,
    device: str,
    criterion: nn.Module
) -> Dict[str, float]:
    """
    Evaluate classification model on a dataset.
    
    Args:
        model: Model to evaluate
        data_loader: DataLoader for evaluation
        device: Device to run evaluation on
        criterion: Loss function
    
    Returns:
        Dictionary of evaluation metrics (including AUPRC)
    """
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []
    all_probabilities = []
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['y'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            
            # Process labels as in training
            if labels.dim() > 1:
                if model.config.num_labels == 1:
                    labels = labels.view(-1).float()
                else:
                    labels = labels.view(-1).long()
            
            # Calculate loss
            loss = criterion(logits.view(-1), labels)
            
            # Get probabilities (for binary classification)
            probabilities = torch.sigmoid(logits).float().cpu().numpy()
            predictions = (probabilities > 0.5).astype(int)
     
            total_loss += loss.item()
            
            # Store for metrics
            all_predictions.extend(predictions.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probabilities.extend(probabilities.flatten() if model.config.num_labels == 1 else probabilities)
                
            num_batches += 1

    # Compute average loss
    avg_loss = total_loss / num_batches
    
    # Convert lists to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)
    
    # Accuracy
    accuracy = accuracy_score(all_labels, all_predictions)
    
    # Precision, recall, F1 (binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='binary', zero_division=0
    )
    
    # AUROC
    auc = roc_auc_score(all_labels, all_probabilities)
    
    # AUPRC 
    auprc = average_precision_score(all_labels, all_probabilities)

    
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'auprc': auprc,
        'num_samples': len(all_predictions)
    }


In [24]:
# Train the model
training_results = train_model_custom(
    model=model,
    tokenizer=tokenizer,
    train_loader=train_loader,
    val_loader=valid_loader,
    test_loader=test_loader,
    num_epochs=5,
    learning_rate=1e-5,
    # batch_size=1,  # Start small due to memory constraints
    max_length=75000,
    device=device,
    use_wandb=False,  
    gradient_accumulation_steps=16,  # Effective batch size = 8
)



Epoch 1: 7it [01:30, 12.90s/it, loss=0.0338, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 8it [01:43, 12.85s/it, loss=0.0338, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 9it [01:56, 12.87s/it, loss=0.0338, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 10it [02:09, 12.88s/it, loss=0.0338, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 11it [02:22, 12.90s/it, loss=0.0339, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 12it [02:35, 12.85s/it, loss=0.0339, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 13it [02:48, 12.87s/it, loss=0.0339, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 14it [03:01, 12.91s/it, loss=0.0339, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 15it [03:13, 12.84s/it, loss=0.0338, lr=1.00e-06]

💾 Intermediate checkpoint saved at step 0


Epoch 1: 2066it [5:32:07,  9.65s/it, loss=3.5620, lr=1.00e-06]



🔄 Evaluating after epoch 1...


Evaluating: 266it [11:27,  2.59s/it]



📈 Epoch 1 Summary:
  Train Loss: 1.2364
  Val Loss: 1.3774
  Learning Rate: 1.00e-06
💾 New best model saved! Val loss: 1.3774
💾 Epoch 1 checkpoint saved

Epoch 2/5


  with torch.cuda.amp.autocast():
Epoch 2: 2066it [5:29:54,  9.58s/it, loss=3.4862, lr=1.00e-06]



🔄 Evaluating after epoch 2...


Evaluating: 266it [11:25,  2.58s/it]



📈 Epoch 2 Summary:
  Train Loss: 1.1919
  Val Loss: 1.3456
  Learning Rate: 1.00e-06
💾 New best model saved! Val loss: 1.3456
💾 Epoch 2 checkpoint saved

Epoch 3/5


  with torch.cuda.amp.autocast():
Epoch 3: 2066it [5:30:42,  9.60s/it, loss=3.4181, lr=1.00e-06]



🔄 Evaluating after epoch 3...


Evaluating: 266it [11:28,  2.59s/it]



📈 Epoch 3 Summary:
  Train Loss: 1.1582
  Val Loss: 1.3132
  Learning Rate: 1.00e-06
💾 New best model saved! Val loss: 1.3132
💾 Epoch 3 checkpoint saved

Epoch 4/5


  with torch.cuda.amp.autocast():
Epoch 4: 2066it [5:31:14,  9.62s/it, loss=3.3426, lr=1.00e-06]



🔄 Evaluating after epoch 4...


Evaluating: 266it [11:29,  2.59s/it]



📈 Epoch 4 Summary:
  Train Loss: 1.1304
  Val Loss: 1.3009
  Learning Rate: 1.00e-06
💾 New best model saved! Val loss: 1.3009
💾 Epoch 4 checkpoint saved

Epoch 5/5


  with torch.cuda.amp.autocast():
Epoch 5: 2066it [5:30:58,  9.61s/it, loss=3.2937, lr=1.00e-06]



🔄 Evaluating after epoch 5...


Evaluating: 266it [11:29,  2.59s/it]



📈 Epoch 5 Summary:
  Train Loss: 1.1106
  Val Loss: 1.2814
  Learning Rate: 1.00e-06
💾 New best model saved! Val loss: 1.2814
💾 Epoch 5 checkpoint saved

✅ Training completed in 1716.65 minutes
🏆 Best validation loss achieved: 1.2814

🧪 Evaluating on test set...


Evaluating: 62it [02:39,  2.58s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [25]:
# Final evaluation on test set 
final_metrics = {}

test_loader_custom = DataLoader(
            test_loader.dataset,
            batch_size=1,
            collate_fn=lambda b: collate_fn(b, tokenizer)
        )

criterion = nn.BCEWithLogitsLoss()

print("\n🧪 Evaluating on test set...")
test_metrics = evaluate_model_custom(model, test_loader_custom, device, criterion)
final_metrics['test_metrics'] = test_metrics

print("📊 Final Test Metrics:")
for key, value in test_metrics.items():
    print(f"  {key}: {value:.4f}")

print(final_metrics)


🧪 Evaluating on test set...


Evaluating: 270it [11:42,  2.60s/it]

📊 Final Test Metrics:
  loss: 1.1334
  accuracy: 0.6815
  precision: 0.0250
  recall: 0.2000
  f1: 0.0444
  auc: 0.3973
  auprc: 0.0362
  num_samples: 270.0000
{'test_metrics': {'loss': 1.1334112096715856, 'accuracy': 0.6814814814814815, 'precision': 0.025, 'recall': 0.2, 'f1': 0.044444444444444446, 'auc': 0.39730769230769225, 'auprc': 0.03618277720444903, 'num_samples': 270}}



