# Fine-Tune Flan-T5-base on BillSum

Author: Gourab S. (heygourab)
GitHub: https://github.com/heygourab

This notebook fine-tunes Flan-T5-base with LoRA on the BillSum dataset (~800 samples) for legal document summarization. Outputs a LoRA adapter (`lora_billsum`) and a `training_report.json` via RogerReportCallback.

The model is trained for 3 epochs with a batch size of 16 and a learning rate of 2e-4.
The model is saved in the `lora_billsum` directory.


## For macOS M1/Apple Silicon

Drop this cell into your Jupyter Notebook (assuming venv is active and you're not using Colab):


In [None]:
%pip install -r ../requirements/mac.txt -q

## Import necessary libraries


In [None]:
import os
import json
import logging
import sys
import tqdm
import traceback
from datetime import datetime
from typing import Dict
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq, TrainerCallback,  logging
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from accelerate import Accelerator
import evaluate
from omegaconf import OmegaConf
import pandas as pd
import matplotlib.pyplot as plt
import psutil
from collections import defaultdict

In [None]:
# Set the logging level for the transformers library.
# NOTE: `logging` here is `transformers.logging`, not the stdlib module: the
# `from transformers import (..., logging)` import above rebinds the name.
# A later cell re-imports the stdlib `logging`, so re-running this cell after
# that one will presumably fail with AttributeError — re-run the imports first.
logging.set_verbosity_info()

### Setup logger

Sets up logging to track training progress and debugging information. The logger is configured to:

- Display timestamp, logger name, log level, and message
- Output logs to standard output (stdout) and to a log file (by default a timestamped file under `logs/`)
- Use INFO level logging
- Create a logger instance named `train_logger`


In [None]:
import sys 
import os
import logging
from datetime import datetime
def setup_logger(name="train_logger", level=logging.INFO, log_file=None):
    """
    Set up a logger with both console and file output.

    Safe to call repeatedly (e.g. when a notebook cell is re-run): handlers
    attached by an earlier call are closed and detached first, so messages are
    not duplicated and previously opened log files are not left open.

    Args:
        name (str): Logger name
        level (int): Logging level (e.g., logging.INFO)
        log_file (str or None): Custom log file path. If None, auto-generates
            a timestamped file under ``<cwd>/logs``.

    Returns:
        logging.Logger: Configured logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # Avoid duplicate logs

    # 🧼 Close and detach handlers left over from a previous call.
    # Iterate over a copy because removeHandler() mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # 📦 Formatter
    formatter = logging.Formatter(
        fmt='%(asctime)s — %(name)s — %(levelname)s — %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # 🖥️ Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # 📁 File handler setup
    if log_file is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_dir = os.path.join(os.getcwd(), 'logs')  # 👈 Safe fallback to current dir
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f'training_{timestamp}.log')
    else:
        log_dir = os.path.dirname(log_file)
        if log_dir:  # empty when log_file is a bare file name
            os.makedirs(log_dir, exist_ok=True)

    # Explicit UTF-8: the format string and log messages contain non-ASCII
    # characters, which a platform-default encoding may not be able to write.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log header
    logger.info(f"Logger initialized: {name}")
    logger.info(f"Log file created at: {os.path.abspath(log_file)}")
    logger.info(f"Python version: {sys.version}")

    return logger

# Create the shared `logger` object that the functions and callbacks defined
# later in this notebook reference by name.
logger = setup_logger("train_logger", logging.INFO)

## Memory Usage Monitoring

The `print_memory_usage()` function monitors system resource utilization during model training:

- Tracks RAM usage by getting the Resident Set Size (RSS) of current process in GB
- For GPU-enabled systems:
  - Reports allocated GPU memory
  - Shows total available GPU memory
  - Calculates percentage of GPU memory utilization
  - Resets peak memory tracking statistics

This helps identify potential memory bottlenecks and optimize resource usage during training.


In [None]:
def print_memory_usage():
    """
    Log memory statistics for the current process.

    Always logs the process RSS and the total system RAM. When CUDA is
    available it additionally logs allocated, total (device 0) and peak GPU
    memory, then resets the peak-memory counters.
    """
    bytes_per_gb = 1e9

    rss_gb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_gb
    system_gb = psutil.virtual_memory().total / bytes_per_gb

    logger.info(f"RAM usage: {rss_gb:.2f} GB")
    logger.info(f"Total system RAM: {system_gb:.2f} GB")

    # Nothing more to report on CPU-only machines.
    if not torch.cuda.is_available():
        return

    # Wait for pending kernels before reading the counters.
    torch.cuda.synchronize()
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    capacity_gb = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb

    logger.info(f"GPU memory usage: {allocated_gb:.2f}/{capacity_gb:.2f} GB ({allocated_gb/capacity_gb*100:.1f}%)")
    logger.info(f"Peak GPU memory: {peak_gb:.2f} GB")

    # Start a fresh peak measurement for the next call.
    torch.cuda.reset_peak_memory_stats()


In [None]:
# Smoke test: log the current memory usage once to confirm the helper works.
print_memory_usage()

## MetricsTrackingCallback

The `MetricsTrackingCallback` is a custom callback for tracking and logging training metrics during the training process. It is designed to work with the Hugging Face Trainer API and provides functionality to log various metrics at specified intervals.

### This callback is particularly useful because it:

- Provides real-time monitoring of training progress
- Creates visualizations to help understand model performance
- Saves metrics for later analysis
- Helps identify potential issues during training (like overfitting or unstable training)


In [None]:
class MetricsTrackingCallback(TrainerCallback):
    """
    A callback to track and plot training metrics during training.

    Training loss is collected in ``on_log`` and numeric evaluation metrics in
    ``on_evaluate``. A plot is written to
    ``<output_dir>/plots/training_metrics.png`` every ``plot_every_n``
    evaluations and once more when training ends; at that point the collected
    series are also exported as CSV files into ``output_dir``.

    Args:
        output_dir (str): Directory that receives the plots and CSV files.
        plot_every_n (int): Regenerate the plot after every N evaluations.
            ``0`` disables periodic plotting (the final plot is still made).
    """

    # Keys under which the ROUGE-L F1 score may arrive. Trainer.evaluate()
    # prefixes metric names with "eval_" by default, so the bare name alone
    # would never match metrics reported through the Trainer.
    _ROUGE_L_KEYS = ('rougeL_f1', 'eval_rougeL_f1')

    def __init__(self, output_dir: str, plot_every_n: int = 5):
        self.output_dir = output_dir
        self.plot_every_n = plot_every_n
        self.training_loss = []  # (step, loss)
        self.eval_metrics_by_key = defaultdict(list)  # key -> [(step, value)]
        self.eval_count = 0
        logger.info("Initialized MetricsTrackingCallback")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record every numeric evaluation metric and plot periodically."""
        if not metrics:
            logger.debug("No metrics provided in on_evaluate")
            return
        step = state.global_step
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                self.eval_metrics_by_key[key].append((step, value))
                logger.info(f"Tracked metric: {key}={value} at step {step}")
        self.eval_count += 1
        # Guard against plot_every_n == 0, which would otherwise raise
        # ZeroDivisionError in the modulo below.
        if self.plot_every_n and self.eval_count % self.plot_every_n == 0:
            self._generate_plots()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Track training loss only"""
        if not logs:
            logger.debug("No logs provided")
            return
        step = state.global_step
        if 'loss' in logs:
            self.training_loss.append((step, logs['loss']))
            logger.info(f"Logged loss: {logs['loss']} at step {step}")

    def on_train_end(self, args, state, control, **kwargs):
        """Generate final plots and save metrics"""
        try:
            self._generate_plots()
            self._save_metrics_data()
            logger.info(f"Training completed. Metrics saved to {self.output_dir}")
        except Exception as e:
            # Best effort: reporting problems must not fail a finished run.
            logger.error(f"Failed to save metrics on train end: {e}")

    def _generate_plots(self):
        """Generate plots for loss and selected metrics"""
        plot_dir = os.path.join(self.output_dir, 'plots')
        os.makedirs(plot_dir, exist_ok=True)
        fig = plt.figure(figsize=(12, 8))

        try:
            # Plot training loss
            if self.training_loss:
                steps, losses = zip(*self.training_loss)
                plt.subplot(2, 1, 1)
                plt.plot(steps, losses, label='Training Loss', color='blue')
                plt.xlabel('Steps')
                plt.ylabel('Loss')
                plt.title('Training Loss')
                plt.legend()

            # Plot ROUGE-L F1; use .get() so the defaultdict does not grow
            # empty entries for keys that were never reported.
            rouge_points = None
            for rouge_key in self._ROUGE_L_KEYS:
                rouge_points = self.eval_metrics_by_key.get(rouge_key)
                if rouge_points:
                    break
            if rouge_points:
                steps, metrics = zip(*rouge_points)
                plt.subplot(2, 1, 2)
                plt.plot(steps, metrics, label='ROUGE-L F1', color='red')
                plt.xlabel('Steps')
                plt.ylabel('ROUGE-L F1')
                plt.title('Evaluation Metric')
                plt.legend()

            try:
                fig.tight_layout()
                fig.savefig(os.path.join(plot_dir, 'training_metrics.png'))
                logger.info(f"Saved plot to {plot_dir}/training_metrics.png")
            except Exception as e:
                logger.error(f"Failed to save plot: {e}")
        finally:
            # Always release the figure, including when plotting or saving
            # failed; otherwise figures accumulate across evaluations.
            plt.close(fig)

    def _save_metrics_data(self):
        """Save metrics to CSVs"""
        try:
            if self.training_loss:
                pd.DataFrame(self.training_loss, columns=['step', 'loss'])\
                  .to_csv(os.path.join(self.output_dir, 'training_loss.csv'), index=False)
                logger.info(f"Saved training_loss.csv")
            for key, metrics in self.eval_metrics_by_key.items():
                pd.DataFrame(metrics, columns=['step', key])\
                  .to_csv(os.path.join(self.output_dir, f'eval_{key}.csv'), index=False)
                logger.info(f"Saved eval_{key}.csv")
        except Exception as e:
            logger.error(f"Failed to save metrics CSVs: {e}")

## RogerReportCallback class:

1. Purpose:

- Creates a detailed JSON report of the training process
- Captures timing information
- Records model and dataset details
- Stores training metrics and system information


In [None]:
# Roger report callback
class RogerReportCallback(TrainerCallback):
    """
    Write a JSON training report (``training_report.json``) when training ends.

    The report combines a short summary (model, dataset, timing, best metric),
    the primitive-valued fields of the trainer state and training arguments,
    the full config, and basic torch/CUDA information.
    """

    _MISSING = 'N/A'  # placeholder for attributes that are not available

    def __init__(self, output_dir: str, config):
        # Where the report is written, and the config that is embedded in it.
        self.output_dir = output_dir
        self.config = config
        # The clock starts when the callback is constructed.
        self.start_time = datetime.now()
        logger.info(f"Training started at {self.start_time.isoformat()}")

    @staticmethod
    def _primitive_fields(obj):
        """Return the instance attributes of *obj* whose values are int/float/str/bool."""
        simple_types = (int, float, str, bool)
        return {
            name: value
            for name, value in obj.__dict__.items()
            if isinstance(value, simple_types)
        }

    def on_train_end(self, args, state, control, **kwargs):
        """Assemble the report and dump it to ``<output_dir>/training_report.json``."""
        finished_at = datetime.now()
        elapsed_minutes = (finished_at - self.start_time).total_seconds() / 60

        missing = self._MISSING
        summary = {
            'model': self.config.model.name,
            'dataset': self.config.dataset.name,
            'start_time': self.start_time.isoformat(),
            'end_time': finished_at.isoformat(),
            'duration_minutes': elapsed_minutes,
            'epochs': getattr(self.config.training, 'epochs', missing),
            'train_examples': getattr(state, 'num_train_examples', missing),
            'eval_examples': getattr(state, 'num_eval_examples', missing),
            'best_metric': getattr(state, 'best_metric', missing),
            'best_model_checkpoint': getattr(state, 'best_model_checkpoint', missing),
        }
        state_fields = self._primitive_fields(state)
        arg_fields = self._primitive_fields(args)
        config_dump = OmegaConf.to_container(self.config)

        has_cuda = torch.cuda.is_available()
        system_info = {
            'torch_version': torch.__version__,
            'cuda_available': has_cuda,
            'cuda_device_count': torch.cuda.device_count() if has_cuda else 0,
            'cuda_device_name': torch.cuda.get_device_name(0) if has_cuda else 'CPU',
        }

        report = {
            'training_summary': summary,
            'state': state_fields,
            'training_args': arg_fields,
            'config': config_dump,
            'system_info': system_info,
        }

        # Make sure the target directory exists before writing.
        os.makedirs(self.output_dir, exist_ok=True)
        report_path = os.path.join(self.output_dir, 'training_report.json')
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        logger.info(f"Training report saved to {report_path}")

## load_and_prepare_dataset function:

- Loads the BillSum dataset from Hugging Face
- Filters out empty examples and, optionally, examples that exceed the configured token limits
- Shuffles the dataset and splits it into training and evaluation sets
- Returns the `train` and `eval` splits (tokenization for training happens later, in `preprocess_datasets`)

### This function is crucial for the fine-tuning process because it:

1. Ensures data quality by removing invalid examples
2. Prevents memory issues by filtering too-long sequences
3. Provides consistent train/eval splits
4. Saves processing time through caching
5. Maintains reproducibility through fixed seeds
6. Gives visibility into the data preparation process


In [None]:
def filter_non_empty(example: dict) -> bool:
    """
    Safely filter out dataset examples whose document or summary is missing, empty, or too short.

    Both column schemas used in this notebook are understood: the raw
    'text'/'summary' names and the renamed 'article'/'highlights' names.
    The raw names take precedence when both are present. This matters because
    the filter is applied after the columns have been renamed; reading only
    'text'/'summary' would reject every example.

    An example is kept when the document has at least 3 whitespace-separated
    words and the summary has at least 1.

    Args:
        example (dict): A mapping with 'text' and 'summary' keys, or with
            'article' and 'highlights' keys.

    Returns:
        bool: True if valid, False otherwise.
    """
    try:
        text = example.get('text', example.get('article', ''))
        summary = example.get('summary', example.get('highlights', ''))

        if not isinstance(text, str) or not isinstance(summary, str):
            logger.debug(f"Invalid types - text: {type(text)}, summary: {type(summary)}")
            return False

        text = text.strip()
        summary = summary.strip()

        text_valid = bool(text) and len(text.split()) >= 3
        summary_valid = bool(summary) and len(summary.split()) >= 1

        if not (text_valid and summary_valid):
            logger.debug(
                f"Filtered example - text_len={len(text)} chars ({len(text.split())} words), "
                f"summary_len={len(summary)} chars ({len(summary.split())} words)"
            )
            return False

        return True

    except Exception as e:
        # Treat anything unexpected (e.g. a non-mapping example) as invalid
        # rather than aborting the whole filtering pass.
        logger.warning(f"Exception during filtering: {e}")
        return False


In [None]:
def load_and_prepare_dataset(cfg, tokenizer=None):
    """
    Load the configured dataset, clean it, and split it into train/eval parts.

    Steps: load the 'train' split (optionally only the first
    ``cfg.dataset.sample_size`` rows), rename the text/summary columns to
    'article'/'highlights', drop examples rejected by ``filter_non_empty``,
    optionally drop examples that reach the token limits, then shuffle with
    ``cfg.seed`` and split by ``cfg.split.train_frac``.

    Args:
        cfg (DictConfig): Hydra/OmegaConf config object.
        tokenizer (PreTrainedTokenizer): Optional tokenizer; token-length
            filtering only runs when it is given and
            ``cfg.dataset.filter_by_length`` is truthy.

    Returns:
        dict: Dict with 'train' and 'eval' splits.

    Raises:
        ValueError: If a configured column is not present in the dataset.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        logger.info(f"Loading dataset: {cfg.dataset.name}")

        # Either a leading slice of the train split or all of it.
        if cfg.dataset.sample_size:
            split_spec = f"train[:{cfg.dataset.sample_size}]"
        else:
            split_spec = "train"
        ds = load_dataset(cfg.dataset.name, split=split_spec, verification_mode="no_checks")

        logger.info(f"Dataset loaded with {len(ds)} examples")

        # Fail early when the configured columns do not exist.
        source_columns = (cfg.dataset.text_col, cfg.dataset.summary_col)
        for col in source_columns:
            if col not in ds.column_names:
                raise ValueError(f"Column '{col}' not found in dataset")

        # Normalise column names for the rest of the pipeline.
        column_mapping = {
            cfg.dataset.text_col: "article",
            cfg.dataset.summary_col: "highlights",
        }
        ds = ds.rename_columns(column_mapping)
        logger.info(f"Dataset columns after renaming: {ds.column_names}")

        # Drop empty / malformed examples.
        ds = ds.filter(filter_non_empty, num_proc=cfg.preprocessing.num_proc)
        logger.info(f"After non-empty filtering: {len(ds)} examples remain")

        # Optional length-based filtering (strictly below both limits).
        if tokenizer and cfg.dataset.filter_by_length:
            def within_token_limit(example):
                if len(tokenizer.encode(example["article"])) >= cfg.dataset.max_input_tokens:
                    return False
                return len(tokenizer.encode(example["highlights"])) < cfg.dataset.max_target_tokens

            ds = ds.filter(within_token_limit, num_proc=cfg.preprocessing.num_proc)
            logger.info(f"After token-length filtering: {len(ds)} examples remain")

        # Deterministic shuffle, then a contiguous train/eval split.
        ds = ds.shuffle(seed=cfg.seed)
        train_size = int(cfg.split.train_frac * len(ds))

        logger.info(f"Train size: {train_size}, Eval size: {len(ds) - train_size}")

        train_split = ds.select(range(train_size))
        eval_split = ds.select(range(train_size, len(ds)))
        return {"train": train_split, "eval": eval_split}

    except Exception as e:
        logger.error(f"Dataset loading/preprocessing failed: {e}")
        raise

## setup_model_and_tokenizer function:

- Loads the Flan-T5-base model and tokenizer from Hugging Face
- Configures the model for LoRA training
- Returns the tokenizer and the LoRA-wrapped model, in that order

### This function is important because it:

1. Minimizes memory usage through LoRA
2. Provides extensive configuration options
3. Handles common edge cases
4. Gives clear visibility into the model setup process


In [None]:
def setup_model_and_tokenizer(cfg):
    """
    Load the tokenizer and base seq2seq model named in ``cfg`` and wrap the model with LoRA.

    Args:
        cfg: Config object providing ``model.name``, ``model.loading_args``,
            ``tokenizer``, ``training.fp16`` and the ``lora`` settings.

    Returns:
        tuple: ``(tokenizer, peft_model)``.
    """
    model_name = cfg.model.name

    logger.info(f"Loading tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = cfg.tokenizer.get("padding_side", "right")

    # Fall back to EOS as the padding token when none is defined.
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer has no pad token. Setting pad_token = eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading base model from {model_name}")
    loading_args = cfg.model.loading_args
    low_cpu_mem = getattr(loading_args, "low_cpu_mem_usage", True)
    device_map = getattr(loading_args, "device_map", "auto")

    # Half precision only when fp16 training is requested.
    if cfg.training.fp16:
        weights_dtype = torch.float16
    else:
        weights_dtype = None

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=weights_dtype,
        low_cpu_mem_usage=low_cpu_mem,
        device_map=device_map,
    )

    logger.info(f"Model class: {model.__class__.__name__}")

    # LoRA hyper-parameters come straight from the config.
    lora_settings = cfg.lora
    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=lora_settings.r,
        lora_alpha=lora_settings.alpha,
        target_modules=lora_settings.target_modules,
        lora_dropout=lora_settings.dropout,
        bias=lora_settings.bias,
    )

    logger.info(f"LoRA config: {lora_cfg}")
    logger.info(f"LoRA targeting modules: {lora_cfg.target_modules}")

    logger.info("Applying LoRA to the model...")
    peft_model = get_peft_model(model, lora_cfg)
    logger.info("LoRA applied successfully.")

    # Report how small the trainable fraction is.
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    trainable_params = 0
    for param in peft_model.parameters():
        if param.requires_grad:
            trainable_params += param.numel()

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,} ({trainable_params / total_params * 100:.2f}%)")

    return tokenizer, peft_model

## Preprocessing Function

The `preprocess_datasets` function handles:

- Tokenization of input texts and summaries
- Truncation to the configured maximum lengths (padding is left to the data collator)
- Handling of special tokens
- Batch processing for efficiency


In [None]:
def chunk_with_stride(text, tokenizer, max_length=512, stride=128):
    """
    Split *text* into token windows and return each window decoded back to a string.

    The tokenizer is called with truncation and ``return_overflowing_tokens``
    so that text beyond ``max_length`` comes back as additional windows
    (presumably overlapping by ``stride`` tokens for Hugging Face tokenizers —
    confirm against the tokenizer in use).

    Args:
        text (str): Document to split.
        tokenizer: Callable tokenizer that also provides ``decode``.
        max_length (int): Maximum number of tokens per window.
        stride (int): Stride value forwarded to the tokenizer.

    Returns:
        list[str]: One decoded string per window, special tokens removed.
    """
    encoding = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        stride=stride,
        return_overflowing_tokens=True,
        return_length=True,
    )
    window_ids = encoding["input_ids"]
    window_lengths = encoding["length"]
    # Decode only the first `n_tokens` ids of every window.
    return [
        tokenizer.decode(token_ids[:n_tokens], skip_special_tokens=True)
        for token_ids, n_tokens in zip(window_ids, window_lengths)
    ]

In [None]:
def preprocess_datasets(tokenizer, datasets: Dict[str, Dataset], cfg) -> Dict[str, Dataset]:
    """
    Preprocess datasets by tokenizing inputs and targets for seq2seq training.

    For every document only the first chunk returned by ``chunk_with_stride``
    is kept; the prompt prefix is prepended and the result is tokenized with
    truncation. No padding is applied here (``padding=False``).

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer for encoding text (e.g., Flan-T5-base).
        datasets (dict): Dictionary of split name -> dataset with 'article' and 'highlights' columns.
        cfg: Config object with dataset.max_input_tokens, dataset.max_target_tokens,
            prompt.prefix and preprocessing.num_proc.

    Returns:
        dict: Processed datasets with tokenized inputs and labels, keyed like ``datasets``.

    Raises:
        ValueError: If required columns are missing.
        Exception: If tokenization or mapping fails.
    """

    # The datasets are expected to be already loaded and passed as a dictionary
    # keyed by split name (e.g. 'train' and 'eval').

    logger.info("Preprocessing datasets...")

    def preprocess_function(examples):
        """Tokenize one batch of examples into model inputs and labels."""
        try:
            # Validate inputs
            inputs = []
            for doc in examples['article']:
                if not isinstance(doc, str):
                    logger.warning(f"Non-string article: {type(doc)}")
                    doc = ""

                # Chunk the document if it's too long
                chunks = chunk_with_stride(doc, tokenizer=tokenizer, max_length=cfg.dataset.max_input_tokens)

                # Add prefix to the first chunk
                # prompt prefix -- Summarize this legal document:\n
                inputs.append(cfg.prompt.prefix + (chunks[0] if chunks else doc[:cfg.dataset.max_input_tokens]))

            # Tokenize inputs
            model_inputs = tokenizer(
                inputs,
                max_length=cfg.dataset.max_input_tokens,
                padding=False,
                truncation=True
            )

            # Tokenize targets. Check the raw values BEFORE coercing them:
            # after the coercion every entry is a str, so the warning could
            # otherwise never fire.
            raw_highlights = examples["highlights"]
            if not all(isinstance(h, str) for h in raw_highlights):
                logger.warning("Non-string highlights detected")
            highlights = [h if isinstance(h, str) else "" for h in raw_highlights]
            labels = tokenizer(
                highlights,
                max_length=cfg.dataset.max_target_tokens,
                padding=False,
                truncation=True
            )

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        except Exception as e:
            logger.error(f"Preprocessing failed for batch: {str(e)}")
            raise

    processed_datasets = {}

    for split, dataset in datasets.items():
        logger.info(f"Processing {split} split...")
        try:
            # Validate columns
            if not all(col in dataset.column_names for col in ['article', 'highlights']):
                raise ValueError(f"Missing required columns in {split} split: {dataset.column_names}")

            processed = dataset.map(
                preprocess_function,
                batched=True,
                batch_size=1000,
                remove_columns=dataset.column_names,  # keep only the tokenized fields
                desc=f"Preprocessing {split} dataset",
                num_proc=max(1, cfg.preprocessing.num_proc // 2),  # half the workers, at least one
                load_from_cache_file=True
            )
            processed_datasets[split] = processed
            logger.info(f"Processed {len(processed)} examples in {split} split")
        except Exception as e:
            logger.error(f"Error processing {split} split: {str(e)}")
            logger.error(f"Dataset columns: {dataset.column_names}")
            raise

    return processed_datasets

## Setup Training

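- Loads the base model and applies the LoRA configuration
- Defines the training arguments
- Builds the data collator used for batching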


In [None]:
def setup_training(tokenizer, cfg):
    """
    Load the base model, wrap it with LoRA, and build the training arguments and data collator.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer handed to the data collator.
        cfg: Config object providing ``model``, ``lora``, ``training``,
            ``dataset``, ``generation`` and ``output_dir`` settings.
            ``training.optim`` is optional and defaults to
            ``"paged_adamw_32bit"``.

    Returns:
        tuple: ``(model, training_args, data_collator)``.
    """
    logger.info("Setting up training configuration...")

    # Set compute dtype
    compute_dtype = torch.float16 if cfg.training.fp16 else torch.float32
    logger.info(f"Compute dtype: {compute_dtype}")

    logger.info(f"Loading model from {cfg.model.name}")
    # Load base model with proper memory settings
    model = AutoModelForSeq2SeqLM.from_pretrained(
        cfg.model.name,
        torch_dtype=compute_dtype,
        low_cpu_mem_usage=cfg.model.loading_args.low_cpu_mem_usage,
        device_map=cfg.model.loading_args.device_map
    )
    logger.info(f"Model loaded successfully.")
    logger.info(f"Model class: {model.__class__.__name__}")
    logger.info(f"Model config: {model.config}")
    logger.info(f"Model parameters: {model.num_parameters():,}")

    # Prepare model for LoRA training
    model = prepare_model_for_kbit_training(model)

    # LoRA Configuration
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
        r=cfg.lora.r,
        lora_alpha=cfg.lora.alpha,
        lora_dropout=cfg.lora.dropout,
        target_modules=cfg.lora.target_modules,
        bias=cfg.lora.bias
    )
    model = get_peft_model(model, peft_config)

    # The default optimizer is a bitsandbytes-backed paged AdamW, which is not
    # available on every platform (e.g. Apple Silicon). Allow the config to
    # override it via `training.optim` (for example "adamw_torch") while
    # keeping the previous default when the key is absent or empty.
    optimizer_name = cfg.training.get("optim") or "paged_adamw_32bit"
    logger.info(f"Optimizer: {optimizer_name}")

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=cfg.output_dir,
        evaluation_strategy=cfg.training.evaluation_strategy,
        learning_rate=cfg.training.lr,
        per_device_train_batch_size=cfg.training.batch_size,
        per_device_eval_batch_size=cfg.training.eval_batch_size,
        gradient_accumulation_steps=cfg.training.grad_accum_steps,
        num_train_epochs=cfg.training.epochs,
        weight_decay=cfg.training.weight_decay,
        logging_steps=cfg.training.logging_steps,
        save_steps=cfg.training.save_steps,
        eval_steps=cfg.training.eval_steps,
        save_strategy=cfg.training.save_strategy,
        metric_for_best_model=cfg.training.metric_for_best,
        greater_is_better=cfg.training.greater_is_better,
        load_best_model_at_end=cfg.training.load_best_model_at_end,
        save_total_limit=2,
        generation_max_length=cfg.dataset.max_target_tokens,
        generation_num_beams=cfg.generation.num_beams,
        fp16=cfg.training.fp16,
        optim=optimizer_name,
        gradient_checkpointing=cfg.training.gradient_checkpointing
    )

    # Data collator: pads each batch dynamically; label padding uses -100.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8 if cfg.training.fp16 else None
    )

    return model, training_args, data_collator

# Train function


In [None]:
def train(cfg):
    """
    Run the full fine-tuning pipeline described by ``cfg``.

    Loads the tokenizer, prepares and tokenizes the datasets, builds the
    model/training arguments/collator, trains with the metrics and report
    callbacks attached, and saves the final model under
    ``<cfg.output_dir>/final_model``.

    Args:
        cfg: Config object (see the individual helper functions for the keys used).

    Returns:
        str: ``cfg.output_dir``.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        logger.info("Starting training process...")

        # Tokenizer, with a padding-token fallback if the model defines none.
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
        if tokenizer.pad_token is None:
            fallback_pad = tokenizer.eos_token or tokenizer.unk_token
            tokenizer.pad_token = fallback_pad

        # Raw (text) splits.
        logger.info("Loading and preparing datasets...")
        raw_splits = load_and_prepare_dataset(cfg, tokenizer)
        print_memory_usage()  # memory snapshot right after loading the data
        logger.info(f"Loaded {len(raw_splits['train'])} training examples and {len(raw_splits['eval'])} evaluation examples")

        # Tokenized splits.
        tokenized_splits = preprocess_datasets(tokenizer, raw_splits, cfg)

        # Model, arguments and collator.
        logger.info("Setting up model and training configuration...")
        model, training_args, data_collator = setup_training(tokenizer, cfg)

        # Callbacks are constructed in this order: metrics first, report second.
        metrics_callback = MetricsTrackingCallback(cfg.output_dir)
        report_callback = RogerReportCallback(cfg.output_dir, cfg)

        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_splits["train"],
            eval_dataset=tokenized_splits["eval"],
            data_collator=data_collator,
            tokenizer=tokenizer,
            callbacks=[metrics_callback, report_callback],
        )

        logger.info("Starting training...")
        trainer.train()

        # Persist the final weights next to the other run artifacts.
        logger.info("Saving final model...")
        final_model_dir = os.path.join(cfg.output_dir, "final_model")
        trainer.save_model(final_model_dir)

        logger.info(f"Training completed successfully! Model saved to: {final_model_dir}")
        return cfg.output_dir

    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        raise

## Main Execution

Load the configuration and start the training process.


In [None]:
if __name__ == "__main__":
    # Entry point: load the YAML config, run training, and exit with status
    # 0 on success or 1 on any failure (after logging the traceback).
    config_path = "../configs/billsum.yaml"

    try:
        # Load the configuration and echo it at debug level.
        cfg = OmegaConf.load(config_path)
        logger.info(f"✅ Loaded configuration from {config_path}")
        logger.debug(f"🔧 Full config:\n{OmegaConf.to_yaml(cfg)}")

        # Headline settings for this run.
        logger.info(f"📦 Model: {cfg.model.name}")
        logger.info(f"📊 Dataset: {cfg.dataset.name}")
        logger.info(f"📁 Output dir: {cfg.output_dir}")
        logger.info(f"📈 Epochs: {cfg.training.epochs}")

        # The output directory must exist before anything writes into it.
        os.makedirs(cfg.output_dir, exist_ok=True)

        # Run the pipeline.
        output_dir = train(cfg)
        logger.info(f"🏁 Training completed successfully! Artifacts at: {output_dir}")

        # Closing memory snapshot.
        print_memory_usage()

    except Exception as e:
        logger.error(f"❌ Training failed: {e!s}")
        logger.error(f"📉 Full traceback:\n{traceback.format_exc()}")
        sys.exit(1)

    else:
        # Reached only when every step above succeeded.
        sys.exit(0)
