# Fine-Tune Flan-T5-base on BillSum
Author: Gourab S.(heygourab),
github: https://github.com/heygourab

This notebook fine-tunes Flan-T5-base with LoRA on the BillSum dataset (~800 samples) for legal document summarization. Outputs a LoRA adapter (`lora_billsum`) and a `training_report.json` via RogerReportCallback.

The model is trained for 3 epochs with a batch size of 16 and a learning rate of 2e-4.
The model is saved in the `lora_billsum` directory.

## For macOS M1/Apple Silicon
Drop this cell into your Jupyter Notebook (assuming venv is active and you're not using Colab):

In [1]:
%pip install -r ../requirements/mac.txt -q

Note: you may need to restart the kernel to use updated packages.


## Import necessary libraries


In [14]:
 
import os
import json
import logging
import sys
from datetime import datetime
from typing import Dict
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq, TrainerCallback
)
from peft import LoraConfig, get_peft_model, TaskType
from accelerate import Accelerator
import evaluate
from omegaconf import OmegaConf
import pandas as pd
import matplotlib.pyplot as plt
import psutil


### Setup logger
Sets up logging to track training progress and debugging information. The logger is configured to:
- Display timestamp, log level, and message
- Output logs to standard output (stdout)
- Use INFO level logging
- Create a logger instance named after the current module

In [15]:
# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

## Memory Usage Monitoring
The `print_memory_usage()` function monitors system resource utilization during model training:

- Tracks RAM usage by getting the Resident Set Size (RSS) of current process in GB
- For GPU-enabled systems:
    - Reports allocated GPU memory
    - Shows total available GPU memory
    - Calculates percentage of GPU memory utilization
    - Resets peak memory tracking statistics

This helps identify potential memory bottlenecks and optimize resource usage during training.

In [16]:
# Memory usage
def print_memory_usage():
    process = psutil.Process(os.getpid()) # Get current process
    # Get memory usage in GB
    ram_gb = process.memory_info().rss / 1e9
    # Get total memory in GB
    total_gb = psutil.virtual_memory().total / 1e9
    # Print memory usage
    logger.info(f"RAM usage: {ram_gb:.2f} GB")
    logger.info(f"Total memory: {total_gb:.2f} GB")

    # if running on GPU, print GPU memory usage
    if torch.cuda.is_available():
        gpu_mem = torch.cuda.memory_allocated() / 1e9
        gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU memory: {gpu_mem:.2f}/{gpu_total:.2f} GB ({gpu_mem/gpu_total*100:.1f}%)")
        torch.cuda.reset_peak_memory_stats()

In [17]:
# Testing memory usage
print_memory_usage()

2025-05-15 18:00:42,740 - INFO - RAM usage: 0.03 GB
2025-05-15 18:00:42,743 - INFO - Total memory: 8.59 GB


## MetricsTrackingCallback
The `MetricsTrackingCallback` is a custom callback for tracking and logging training metrics during the training process. It is designed to work with the Hugging Face Trainer API and provides functionality to log various metrics at specified intervals.

### This callback is particularly useful because it:

- Provides real-time monitoring of training progress
- Creates visualizations to help understand model performance
- Saves metrics for later analysis
- Helps identify potential issues during training (like overfitting or unstable training)


In [18]:
# MetricsTrackingCallback 
class MetricsTrackingCallback(TrainerCallback):

    """
    A callback to track and plot training metrics during training.
    """
    def __init__(self, output_dir: str, plot_every_n: int = 5):
        self.output_dir = output_dir # Directory to save plots and metrics
        self.plot_every_n = plot_every_n # Frequency of plotting
        self.training_loss = [] # List to store training loss
        self.eval_metrics = [] # List to store evaluation metrics
        self.step_numbers = [] # List to store step numbers

    # function for logging metrics
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs: # Check if logs are empty
            return   # No logs to process
        step = state.global_step 
        if 'loss' in logs:
            # Append loss to training loss list
            self.training_loss.append((step, logs['loss'])) 
        if 'eval_rougeL_f1' in logs: # Check if eval_rougeL_f1 is in logs
            # Append eval_rougeL_f1 to eval metrics list
            self.eval_metrics.append((step, logs['eval_rougeL_f1']))
            # Append step number to step numbers list
            self.step_numbers.append(step) 

        if self.eval_metrics and len(self.eval_metrics) % self.plot_every_n == 0:
            self._generate_plots() # Generate plots every n steps

    # function for on_train_end
    def on_train_end(self, args, state, control, **kwargs):
        self._generate_plots() # Generate plots at the end of training
        self._save_metrics_data() # Save metrics data to CSV
        logger.info(f"Training completed. Metrics saved to {self.output_dir}")

    # function for generating plots
    def _generate_plots(self):
        plot_dir = os.path.join(self.output_dir, 'plots') # Directory to save plots
        os.makedirs(plot_dir, exist_ok=True) # Create directory if it doesn't exist

        # Plot training loss and evaluation metrics
        plt.figure(figsize=(12, 8))
        if self.training_loss:
            steps, losses = zip(*self.training_loss)
            plt.subplot(2, 1, 1)
            plt.plot(steps, losses, label='Training Loss', color='blue')
            plt.xlabel('Steps')
            plt.ylabel('Loss')
            plt.title('Training Loss')
            plt.legend()

        # Plot evaluation metrics
        if self.eval_metrics:
            steps, metrics = zip(*self.eval_metrics)
            plt.subplot(2, 1, 2)
            plt.plot(steps, metrics, label='ROUGE-L F1', color='red')
            plt.xlabel('Steps')
            plt.ylabel('ROUGE-L F1')
            plt.title('Evaluation Metric')
            plt.legend()

        # Adjust layout and save plot
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, 'training_metrics.png'))
        plt.close()

    # function for saving metrics data to CSV
    def _save_metrics_data(self):
        if self.training_loss:
            pd.DataFrame(self.training_loss, columns=['step', 'loss'])\
              .to_csv(os.path.join(self.output_dir, 'training_loss.csv'), index=False)
        if self.eval_metrics:
            pd.DataFrame(self.eval_metrics, columns=['step', 'rougeL_f1'])\
              .to_csv(os.path.join(self.output_dir, 'eval_metrics.csv'), index=False)

## RogerReportCallback class:
1. Purpose:
- Creates a detailed JSON report of the training process
- Captures timing information
- Records model and dataset details
- Stores training metrics and system information

In [19]:
# Roger report callback
class RogerReportCallback(TrainerCallback):
    """
    A callback to generate a training report at the end of training.
    """
    def __init__(self, output_dir: str, config):
        self.output_dir = output_dir # Directory to save report
        self.config = config # Configuration object
        self.start_time = datetime.now() # Start time of training
        logger.info(f"Training started at {self.start_time.isoformat()}")
        # Print system information

    def on_train_end(self, args, state, control, **kwargs):
        end_time = datetime.now() # End time of training

        # Calculate duration in minutes
        duration = (end_time - self.start_time).total_seconds() / 60 

        # Safe extraction with fallback
        def safe_get(obj, attr, default='N/A'):
            return getattr(obj, attr, default)

        # Get training state and arguments
        report = {
            'training_summary': {
                'model': self.config.model.name,
                'dataset': self.config.dataset.name,
                'start_time': self.start_time.isoformat(),
                'end_time': end_time.isoformat(),
                'duration_minutes': duration,
                'epochs': safe_get(self.config.training, 'epochs'),
                'train_examples': safe_get(state, 'num_train_examples'),
                'eval_examples': safe_get(state, 'num_eval_examples'),
                'best_metric': safe_get(state, 'best_metric'),
                'best_model_checkpoint': safe_get(state, 'best_model_checkpoint'),
            },
            'state': {
                k: v for k, v in state.__dict__.items()
                if isinstance(v, (int, float, str, bool))
            },
            'training_args': {
                k: v for k, v in args.__dict__.items()
                if isinstance(v, (int, float, str, bool))
            },
            'config': OmegaConf.to_container(self.config),
            'system_info': {
                'torch_version': torch.__version__,
                'cuda_available': torch.cuda.is_available(),
                'cuda_device_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
                'cuda_device_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
            }
        }
        
        # create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)
        # Save the report as a JSON file
        report_path = os.path.join(self.output_dir, 'training_report.json')
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        logger.info(f"Training report saved to {report_path}")

## load_and_preprocess_data function:

- Loads the BillSum dataset from Hugging Face
- Preprocesses the data by tokenizing the input and output sequences
- Splits the dataset into training and validation sets
- Returns the preprocessed datasets for training and validation


### This function is crucial for the fine-tuning process because it:

1. Ensures data quality by removing invalid examples
2. Prevents memory issues by filtering too-long sequences
3. Provides consistent train/eval splits
4. Saves processing time through caching
5. Maintains reproducibility through fixed seeds
6. Gives visibility into the data preparation process

In [20]:
def filter_non_empty(example):
    return len(example['article']) > 0 and len(example['highlights']) > 0

# Function to load and prepare dataset
def load_and_prepare_dataset(cfg, tokenizer=None):
    """
    Load and prepare dataset with progress tracking and caching support.
    Args:
        cfg: Configuration object containing dataset parameters
        tokenizer: Optional tokenizer for length filtering
    Returns:
        dict: Contains 'train' and 'eval' datasets
    """

    # Check cache first
    cache_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(
        cache_dir, 
        f"{cfg.dataset.name}_{cfg.dataset.sample_size if cfg.dataset.sample_size else 'full'}.cache"
    )

    # Try loading from cache
    if os.path.exists(cache_file):
        logger.info(f"Loading dataset from cache: {cache_file}")
        return torch.load(cache_file)

    logger.info(f"Loading dataset: {cfg.dataset.name}")
    ds = load_dataset(cfg.dataset.name, split=f"train[:{cfg.dataset.sample_size}]") if cfg.dataset.sample_size \
         else load_dataset(cfg.dataset.name)['train']

    # Log initial dataset size
    logger.info(f"Initial dataset size: {len(ds)} examples")

    # Validate columns
    for col in [cfg.dataset.text_col, cfg.dataset.summary_col]:
        if col not in ds.column_names:
            raise ValueError(f"Column '{col}' not found in dataset")

    # Rename to standard
    ds = ds.rename_columns({cfg.dataset.text_col: "article", cfg.dataset.summary_col: "highlights"})
    original_len = len(ds)

    # Filter empty examples with progress tracking
    logger.info("Filtering empty examples...")
    ds = ds.filter(
        filter_non_empty,
        desc="Filtering empty examples",
        load_from_cache_file=True
    )
    logger.info(f"Removed {original_len - len(ds)} empty examples")

    if tokenizer and cfg.dataset.get('filter_by_length', True):
        logger.info("Filtering by token length...")
        def within_token_limit(example):
            return len(tokenizer.encode(example['article'])) < cfg.dataset.max_input_tokens and \
                   len(tokenizer.encode(example['highlights'])) < cfg.dataset.max_target_tokens
        
        length_filtered_len = len(ds)
        ds = ds.filter(
            within_token_limit,
            desc="Filtering by length",
            load_from_cache_file=True
        )
        logger.info(f"Removed {length_filtered_len - len(ds)} examples exceeding token limits")

    # Shuffle and split
    logger.info("Shuffling and splitting dataset...")
    ds = ds.shuffle(seed=cfg.seed) # 42 is the default seed
    train_size = int(cfg.split.train_frac * len(ds))
    
    # Create the split datasets
    dataset_dict = {
        "train": ds.select(range(train_size)),
        "eval": ds.select(range(train_size, len(ds)))
    }

    # Log split sizes
    logger.info(f"Train set size: {len(dataset_dict['train'])} examples")
    logger.info(f"Eval set size: {len(dataset_dict['eval'])} examples")

    # Cache the processed dataset
    logger.info(f"Caching processed dataset to: {cache_file}")
    torch.save(dataset_dict, cache_file)

    return dataset_dict

## setup_model_and_tokenizer function:
- Loads the Flan-T5-base model and tokenizer from Hugging Face
- Configures the model for LoRA training
- Returns the model and tokenizer objects

### This function is important because it:
1. Minimizes memory usage through LoRA
2. Provides extensive configuration options
3. Handles common edge cases
4. Gives clear visibility into the model setup process

In [None]:
def setup_model_and_tokenizer(cfg):
    logger.info(f"Loading tokenizer from {cfg.model.name}") 
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
    tokenizer.padding_side = cfg.tokenizer.get("padding_side", "right")
    
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer has no pad token. Setting pad_token = eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading base model from {cfg.model.name}")
    low_cpu_mem = getattr(cfg.model.loading_args, "low_cpu_mem_usage", True)
    device_map = getattr(cfg.model.loading_args, "device_map", "auto")
    
    # model loading
    model = AutoModelForSeq2SeqLM.from_pretrained(
        cfg.model.name,
        torch_dtype=torch.float16 if cfg.training.fp16 else None,
        low_cpu_mem_usage=low_cpu_mem,
        device_map=device_map
    )
    
    logger.info(f"Model class: {model.__class__.__name__}")
    
    # setup LoRA
    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=cfg.lora.r,
        lora_alpha=cfg.lora.alpha,
        target_modules=cfg.lora.target_modules,
        lora_dropout=cfg.lora.dropout,
        bias=cfg.lora.bias
    )

    logger.info(f"LoRA config: {lora_cfg}")
    logger.info(f"LoRA targeting modules: {lora_cfg.target_modules}")

    # Apply LoRA
    logger.info("Applying LoRA to the model...")
    peft_model = get_peft_model(model, lora_cfg)
    logger.info("LoRA applied successfully.")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,} ({trainable_params / total_params * 100:.2f}%)") 

    return tokenizer, peft_model

## Preprocessing Function
The `preprocess_datasets` function handles:
- Tokenization of input texts and summaries
- Proper padding and truncation
- Handling of special tokens
- Batch processing for efficiency

In [None]:
def preprocess_datasets(tokenizer, datasets, cfg):
    """
    Preprocess datasets by tokenizing inputs and targets.
    Args:
        tokenizer: HuggingFace tokenizer
        datasets: Dictionary containing train and eval datasets
        cfg: Configuration object
    Returns:
        trainer: Trained Seq2SeqTrainer object
    """
    # Setup output directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if cfg.add_timestamp_to_output:
        cfg.output_dir = f"{cfg.output_dir}_{timestamp}"
    os.makedirs(cfg.output_dir, exist_ok=True)

    # Save configuration
    config_path = os.path.join(cfg.output_dir, 'config.yaml')
    OmegaConf.save(cfg, config_path)
    logger.info(f"Configuration saved to {config_path}")

    # Initialize accelerator
    accelerator = Accelerator()
    logger.info(f"Using accelerator: {accelerator.device}")
    print_memory_usage()

    # Load and prepare data
    logger.info("Loading and preparing datasets...")
    datasets = load_and_prepare_dataset(cfg)
    tokenizer, peft_model = setup_model_and_tokenizer(cfg)
    processed_datasets = preprocess_datasets(tokenizer, datasets, cfg)

    # Setup data collator
    logger.info("Setting up data collator...")
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=peft_model,
        padding='longest',
        pad_to_multiple_of=8,
            max_length=cfg.dataset.max_input_tokens
    )

    # Setup ROUGE metrics
    logger.info("Setting up evaluation metrics...")
    rouge = evaluate.load('rouge')

    def compute_metrics(eval_pred):
        """Compute ROUGE metrics for evaluation."""
        preds, labels = eval_pred
        
        # Handle padding in labels
        decoded_labels = []
        for label_seq in labels:
            label_seq = [l if l != -100 else tokenizer.pad_token_id for l in label_seq]
            decoded_labels.append(tokenizer.decode(label_seq, skip_special_tokens=True).strip())
        
        # Decode predictions
        decoded_preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() 
                        for pred in preds]
        
        # Compute ROUGE scores
        res = rouge.compute(
            predictions=decoded_preds, 
            references=decoded_labels, 
            use_stemmer=True
        )

        return {
            'rougeL_f1': res['rougeL'].mid.fmeasure,
            'rouge1_f1': res['rouge1'].mid.fmeasure,
            'rouge2_f1': res['rouge2'].mid.fmeasure,
        }

    # Setup training arguments
    logger.info("Configuring training arguments...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.training.batch_size,
        per_device_eval_batch_size=cfg.training.eval_batch_size,
        gradient_accumulation_steps=cfg.training.grad_accum_steps,
        learning_rate=cfg.training.lr,
        num_train_epochs=cfg.training.epochs,
        weight_decay=cfg.training.weight_decay,
        fp16=cfg.training.fp16,
        logging_steps=cfg.training.logging_steps,
        save_strategy=cfg.training.save_strategy,
        save_steps=cfg.training.save_steps,
        evaluation_strategy=cfg.training.evaluation_strategy,
        eval_steps=cfg.training.eval_steps,
        predict_with_generate=True,
        generation_max_length=cfg.dataset.max_target_tokens,
        load_best_model_at_end=cfg.training.load_best_model_at_end,
        metric_for_best_model=cfg.training.metric_for_best,
        greater_is_better=cfg.training.greater_is_better,
        warmup_steps=cfg.training.warmup_steps,
        report_to=cfg.training.report_to,
        push_to_hub=False,
        gradient_checkpointing=cfg.training.gradient_checkpointing,
        label_names=["labels"]
    )

    # Setup callbacks
    logger.info("Setting up training callbacks...")
    callbacks = [
        RogerReportCallback(cfg.output_dir, cfg),
        MetricsTrackingCallback(cfg.output_dir)
    ]

    # Initialize trainer
    logger.info("Initializing trainer...")
    trainer = Seq2SeqTrainer(
        model=peft_model,
        args=training_args,
        train_dataset=processed_datasets["train"],
        eval_dataset=processed_datasets["eval"],
        data_collator=collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=callbacks
    )

    # Start training
    logger.info("Starting training...")
    train_result = trainer.train()
    
    # Save final model
    logger.info("Saving final model...")
    trainer.save_model(cfg.output_dir)
    
    # Log and save training results
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    
    # Final evaluation
    logger.info("Running final evaluation...")
    eval_metrics = trainer.evaluate(
        max_length=cfg.dataset.max_target_tokens,
        num_beams=cfg.generation.num_beams,
        metric_key_prefix="eval"
    )
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)
    
    logger.info("Training completed successfully!")
    return trainer

In [21]:
def train(cfg):
    """
    Main training function that orchestrates the fine-tuning process.
    Args:
        cfg: Configuration object containing all training parameters
    Returns:
        trainer: Trained Seq2SeqTrainer object
    """
    # Setup output directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if cfg.add_timestamp_to_output:
        cfg.output_dir = f"{cfg.output_dir}_{timestamp}"
    os.makedirs(cfg.output_dir, exist_ok=True)

    # Save configuration
    config_path = os.path.join(cfg.output_dir, 'config.yaml')
    OmegaConf.save(cfg, config_path)
    logger.info(f"Configuration saved to {config_path}")

    # Initialize accelerator
    accelerator = Accelerator()
    logger.info(f"Using accelerator: {accelerator.device}")
    print_memory_usage()

    # Load and prepare data
    logger.info("Loading and preparing datasets...")
    datasets = load_and_prepare_dataset(cfg)
    tokenizer, peft_model = setup_model_and_tokenizer(cfg)
    processed_datasets = preprocess_datasets(tokenizer, datasets, cfg)

    # Setup data collator
    logger.info("Setting up data collator...")
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=peft_model,
        padding='longest',
        pad_to_multiple_of=8,
        max_length=cfg.dataset.max_input_tokens
    )

    # Setup ROUGE metrics
    logger.info("Setting up evaluation metrics...")
    rouge = evaluate.load('rouge')

    def compute_metrics(eval_pred):
        """Compute ROUGE metrics for evaluation."""
        preds, labels = eval_pred
        
        # Handle padding in labels
        decoded_labels = []
        for label_seq in labels:
            label_seq = [l if l != -100 else tokenizer.pad_token_id for l in label_seq]
            decoded_labels.append(tokenizer.decode(label_seq, skip_special_tokens=True).strip())
        
        # Decode predictions
        decoded_preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() 
                        for pred in preds]
        
        # Compute ROUGE scores
        res = rouge.compute(
            predictions=decoded_preds, 
            references=decoded_labels, 
            use_stemmer=True
        )
        
        return {
            'rougeL_f1': res['rougeL'].mid.fmeasure,
            'rouge1_f1': res['rouge1'].mid.fmeasure,
            'rouge2_f1': res['rouge2'].mid.fmeasure,
        }

    # Setup training arguments
    logger.info("Configuring training arguments...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.training.batch_size,
        per_device_eval_batch_size=cfg.training.eval_batch_size,
        gradient_accumulation_steps=cfg.training.grad_accum_steps,
        learning_rate=cfg.training.lr,
        num_train_epochs=cfg.training.epochs,
        weight_decay=cfg.training.weight_decay,
        fp16=cfg.training.fp16,
        logging_steps=cfg.training.logging_steps,
        save_strategy=cfg.training.save_strategy,
        save_steps=cfg.training.save_steps,
        evaluation_strategy=cfg.training.evaluation_strategy,
        eval_steps=cfg.training.eval_steps,
        predict_with_generate=True,
        generation_max_length=cfg.dataset.max_target_tokens,
        load_best_model_at_end=cfg.training.load_best_model_at_end,
        metric_for_best_model=cfg.training.metric_for_best,
        greater_is_better=cfg.training.greater_is_better,
        warmup_steps=cfg.training.warmup_steps,
        report_to=cfg.training.report_to,
        push_to_hub=False,
        gradient_checkpointing=cfg.training.gradient_checkpointing,
        label_names=["labels"]
    )

    # Setup callbacks
    logger.info("Setting up training callbacks...")
    callbacks = [
        RogerReportCallback(cfg.output_dir, cfg),
        MetricsTrackingCallback(cfg.output_dir)
    ]

    # Initialize trainer
    logger.info("Initializing trainer...")
    trainer = Seq2SeqTrainer(
        model=peft_model,
        args=training_args,
        train_dataset=processed_datasets["train"],
        eval_dataset=processed_datasets["eval"],
        data_collator=collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=callbacks
    )

    # Start training
    logger.info("Starting training...")
    train_result = trainer.train()
    
    # Save final model
    logger.info("Saving final model...")
    trainer.save_model(cfg.output_dir)
    
    # Log and save training results
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    
    # Final evaluation
    logger.info("Running final evaluation...")
    eval_metrics = trainer.evaluate(
        max_length=cfg.dataset.max_target_tokens,
        num_beams=cfg.generation.num_beams,
        metric_key_prefix="eval"
    )
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)
    
    logger.info("Training completed successfully!")
    return trainer

## Main Execution
Load the configuration and start the training process.

In [22]:
if __name__ == "__main__":
    config_path = "../configs/billsum.yaml"
    
    try:
        # 1. Load configuration
        cfg = OmegaConf.load(config_path)
        logger.info(f"✅ Loaded configuration from {config_path}")
        logger.debug(f"🔧 Full config:\n{OmegaConf.to_yaml(cfg)}")

        # 2. Log core info
        logger.info(f"📦 Model: {cfg.model.name}")
        logger.info(f"📊 Dataset: {cfg.dataset.name}")
        logger.info(f"📁 Output dir: {cfg.output_dir}")
        logger.info(f"📈 Epochs: {cfg.training.epochs}")
        
        # 3. Create output directory early
        os.makedirs(cfg.output_dir, exist_ok=True)

        # 4. Launch training
        output_dir = train(cfg)
        logger.info(f"🏁 Training completed successfully! Artifacts at: {output_dir}")

        # 5. Final memory log
        print_memory_usage()
        sys.exit(0)

    except Exception as e:
        logger.error(f"❌ Training failed: {str(e)}")
        traceback_str = traceback.format_exc()
        logger.error(f"📉 Full traceback:\n{traceback_str}")
        sys.exit(1)


2025-05-15 18:04:43,947 - INFO - ✅ Loaded configuration from ../configs/billsum.yaml
2025-05-15 18:04:43,957 - INFO - 📦 Model: google/flan-t5-base
2025-05-15 18:04:43,958 - INFO - 📊 Dataset: billsum
2025-05-15 18:04:43,959 - INFO - 📁 Output dir: lora_billsum
2025-05-15 18:04:43,959 - INFO - 📈 Epochs: 1
2025-05-15 18:04:43,962 - INFO - Configuration saved to lora_billsum_20250515_180443/config.yaml
2025-05-15 18:04:44,135 - INFO - Using accelerator: mps
2025-05-15 18:04:44,136 - INFO - RAM usage: 0.06 GB
2025-05-15 18:04:44,136 - INFO - Total memory: 8.59 GB
2025-05-15 18:04:44,137 - INFO - Loading and preparing datasets...
2025-05-15 18:04:44,137 - ERROR - ❌ Training failed: name '__file__' is not defined


NameError: name 'traceback' is not defined