# Fine-tuning Flan-T5-base for Legal Document Summarization

This notebook is a modified version of the original notebook by Gourab S. (@heygourab).
Author: Gourab S. (@heygourab)
This notebook demonstrates fine-tuning the Flan-T5-base model on the BillSum dataset using LoRA (Low-Rank Adaptation). We'll use the Hugging Face ecosystem (`transformers`, `datasets`, `peft`) for efficient fine-tuning.

## Setup Overview

- Base Model: google/flan-t5-base
- Dataset: BillSum (~2000 samples)

## Prerequisites

This notebook assumes a Colab environment with a GPU available. If you're running this locally, make sure to install the required packages and set up your GPU environment accordingly.
[Open in Colab](https://colab.research.google.com/github/heygourab/pdf_summarization_model_fine_tuning/blob/main/notebooks/billsum_lora_finetune_colab.ipynb)


## 1. Environment Setup

First, let's install the required dependencies and set up GPU monitoring.


In [1]:
# M1/M2 Mac Setup for PyTorch with MPS (Metal GPU)
%pip install -q torch torchvision torchaudio
%pip install -q transformers datasets accelerate evaluate peft nltk wandb omegaconf fsspec pyarrow
# Note: bitsandbytes is not supported on M1/M2 Mac with MPS.


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# 🔥 Device setup for M1/M2 (Metal Performance Shaders)
import torch

def get_device():
    if torch.backends.mps.is_available():
        print("MPS (Metal Performance Shaders) is available! Using MPS for acceleration.")
        return torch.device('mps')
    elif torch.cuda.is_available():
        print("CUDA is available! Using CUDA for acceleration.")
        return torch.device('cuda')
    else:
        print("No GPU acceleration available. Using CPU.")
        return torch.device('cpu')

device = get_device()
print(f'Using device: {device}')

MPS (Metal Performance Shaders) is available! Using MPS for acceleration.
Using device: mps


## 2. Import Libraries


In [3]:
import os
import json
import torch
import nltk
import evaluate
import numpy as np
import sys
import logging
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import wandb
import psutil

## 3. Logger setup


In [4]:
def setup_logger(name="train_logger", level=logging.INFO, log_file=None):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # Avoid duplicate logs

    # Clear existing handlers
    if logger.hasHandlers():
        logger.handlers.clear()

    # Formatter for log messages
    formatter = logging.Formatter(
        fmt='%(asctime)s — %(name)s — %(levelname)s — %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler setup
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler setup
    if log_file is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_dir = os.path.join(os.getcwd(), 'logs')  # Safe fallback to current dir
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f'training_{timestamp}.log')
    else:
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log header
    logger.info(f"Logger initialized: {name}")
    logger.info(f"Log file created at: {os.path.abspath(log_file)}")
    logger.info(f"Python version: {sys.version}")

    return logger

# Use it
logger = setup_logger("train_logger", logging.INFO)

2025-05-20 12:42:41 — train_logger — INFO — Logger initialized: train_logger
2025-05-20 12:42:41 — train_logger — INFO — Log file created at: /Users/gourabsarkar/Developer/college_project/pdf_summarization_model_fine_tuning/notebooks/logs/training_20250520_124241.log
2025-05-20 12:42:41 — train_logger — INFO — Python version: 3.10.17 (main, Apr  8 2025, 12:10:59) [Clang 16.0.0 (clang-1600.0.26.6)]
2025-05-20 12:42:41 — train_logger — INFO — Log file created at: /Users/gourabsarkar/Developer/college_project/pdf_summarization_model_fine_tuning/notebooks/logs/training_20250520_124241.log
2025-05-20 12:42:41 — train_logger — INFO — Python version: 3.10.17 (main, Apr  8 2025, 12:10:59) [Clang 16.0.0 (clang-1600.0.26.6)]


## 4. Loading the required NLTK libraries


In [5]:
# Download required NLTK data
for resource in ['punkt', 'punkt_tab']:
    try:
        nltk.download(resource, quiet=True)
        logger.info(f"Successfully downloaded NLTK resource: {resource}")
    except Exception as e:
        logger.error(f"Error downloading {resource}: {e}")

2025-05-20 12:42:42 — train_logger — INFO — Successfully downloaded NLTK resource: punkt
2025-05-20 12:42:42 — train_logger — INFO — Successfully downloaded NLTK resource: punkt_tab
2025-05-20 12:42:42 — train_logger — INFO — Successfully downloaded NLTK resource: punkt_tab


## 5. Memory Usage Monitoring

The `print_memory_usage()` function monitors system resource utilization during model training:

- Tracks RAM usage by getting the Resident Set Size (RSS) of current process in GB
- For GPU-enabled systems:
  - Reports allocated GPU memory
  - Shows total available GPU memory
  - Calculates percentage of GPU memory utilization
  - Resets peak memory tracking statistics

This helps identify potential memory bottlenecks and optimize resource usage during training.


In [6]:
def print_memory_usage():
    process = psutil.Process(os.getpid()) # Get the current process

    ram_gb = process.memory_info().rss / 1e9 # Convert bytes to GB
    total_gb = psutil.virtual_memory().total / 1e9 # Total system RAM in GB

    logger.info(f"RAM usage: {ram_gb:.2f} GB") # Current process RAM usage
    logger.info(f"Total system RAM: {total_gb:.2f} GB") # Total system RAM

    # Check for CUDA device
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        gpu_mem = torch.cuda.memory_allocated() / 1e9
        gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9
        peak_gpu_mem = torch.cuda.max_memory_allocated() / 1e9

        logger.info(f"GPU memory usage: {gpu_mem:.2f}/{gpu_total:.2f} GB ({gpu_mem/gpu_total*100:.1f}%)")
        logger.info(f"Peak GPU memory: {peak_gpu_mem:.2f} GB")

        torch.cuda.reset_peak_memory_stats()

    # Check for MPS device
    elif torch.backends.mps.is_available():
        # MPS doesn't have built-in memory tracking like CUDA
        # We can only log that we're using MPS
        logger.info("Using MPS (Metal Performance Shaders) - Memory stats not available")

print_memory_usage()

2025-05-20 12:42:42 — train_logger — INFO — RAM usage: 0.46 GB
2025-05-20 12:42:42 — train_logger — INFO — Total system RAM: 8.59 GB
2025-05-20 12:42:42 — train_logger — INFO — Using MPS (Metal Performance Shaders) - Memory stats not available
2025-05-20 12:42:42 — train_logger — INFO — Total system RAM: 8.59 GB
2025-05-20 12:42:42 — train_logger — INFO — Using MPS (Metal Performance Shaders) - Memory stats not available


## 6. Configuration Parameters

all hyperparameters and configuration settings for the model, dataset, LoRA, and training.


In [7]:
from peft import TaskType 
from datetime import datetime
import logging

logger = logging.getLogger(__name__)

CONFIG = {
    # ====== MODEL & ARCHITECTURE ======
    "model_name": "google/flan-t5-base",  # Lightweight encoder-decoder model
    "model_type": "encoder-decoder",       # Needed for proper Trainer setup

    # ====== DATASET CONFIG ======
    "dataset_name": "billsum",
    "text_col": "text",
    "summary_col": "summary",
    "max_input_tokens": 512,
    "max_target_tokens": 256,

    "sample_size": 2000,  # Increased sample size to stabilize training
    "filter_by_length": True,
    "split_train_frac": 0.9,  # Give model more data to learn from

    # ====== PROMPT INJECTION ======
    "prompt_prefix": "Summarize the following legal bill: ",

    # ====== LoRA CONFIG ======
    "lora_r": 8,  # Reduced rank for MPS stability
    "lora_alpha": 16,  # Lower alpha to reduce overfitting
    "lora_target_modules": ["q", "k", "v"],  # Cover more layers (T5 specifics)
    "lora_dropout": 0.1,  # Slightly increased to help generalization
    "lora_bias": "none",
    "lora_task_type": TaskType.SEQ_2_SEQ_LM,

    # ====== TRAINING CONFIG ======
    "do_train": True,
    "do_eval": True,
    "num_train_epochs": 3,  # 3-5 epochs is safe for small datasets
    "per_device_train_batch_size": 1,  # Lower batch size for stability on MPS
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 8,  # Effective batch size = 8

    "learning_rate": 1e-4,  # Reduced LR to avoid erratic weight updates
    "weight_decay": 0.1,   # Lower decay to avoid underfitting
    "warmup_steps": 20,     # Faster warmup for small dataset

    "fp16": False,  # MPS doesn’t support it
    "bf16": False,
    "torch_compile": False,  # Not stable on MPS
    "gradient_checkpointing": True,  # Saves memory

    "optim": "adamw_torch",  # MPS-safe optimizer

    # ====== EVALUATION & LOGGING ======
    "logging_steps": 10,
    "evaluation_strategy": "steps",
    "eval_steps": 50,  # 🔥 Slightly slower eval to reduce jitter
    "save_strategy": "steps",
    "save_steps": 100,
    "save_total_limit": 1,

    "load_best_model_at_end": True,
    "metric_for_best_model": "rougeL",  # Logical for summarization
    "greater_is_better": True,
    "report_to": "wandb",  # Replace with "none" if not using WandB
    "overwrite_output_dir": True,

    # ====== GENERATION CONFIG ======
    "gen_num_beams": 4,         # Slightly cheaper beam search
    "gen_length_penalty": 0.8,
    "gen_early_stopping": True,

    # ====== MISC ======
    "seed": 42,

    # ====== GOOGLE DRIVE SUPPORT (COLAB) ======
    "mount_drive": False,
    "drive_path": "MyDrive/ML_models/pdf_summarization",
    "training_report_filename": "training_report.json",
    "lora_adapter_name": "lora_billsum_legal",
}

# ====== OUTPUT DIR HANDLER ======
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG["output_dir"] = f"{CONFIG['lora_adapter_name']}_{timestamp}"

if CONFIG["mount_drive"]:
    CONFIG["gdrive_output_dir"] = f"/content/drive/{CONFIG['drive_path']}/{CONFIG['output_dir']}"

# ====== SANITY LOG ======
logger.info("🔧 MPS CONFIG DUMP:")
for k, v in CONFIG.items():
    logger.info(f"  {k}: {v}")


## 6. Login to Hugging Face Hub and Weights & Biases

You'll need to log in to Hugging Face to download models/datasets and to Weights & Biases for experiment tracking.
You can get your Hugging Face token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
and your W&B API key from [https://wandb.ai/authorize](https://wandb.ai/authorize).


In [8]:
from huggingface_hub import HfFolder, notebook_login

try:
    if HfFolder.get_token() is None:
        logger.info("Hugging Face token not found. Please log in.")
        notebook_login()
    else:
        logger.info("Already logged in to Hugging Face Hub.")
except Exception as e:
    logger.error(f"An error occurred during Hugging Face login check: {e}")
    logger.info("Attempting login...")
    notebook_login()

# Login to Weights & Biases
try:
    wandb.login()
    # Use a different run name than output_dir
    run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    wandb.init(
        project="flan-t5-billsum-lora", 
        name=run_name,  # Use different run name
        config=CONFIG
    )
    logger.info("Successfully logged in to W&B and initialized experiment.")
except Exception as e:
    logger.error(f"Could not login to W&B: {e}. Ensure you have run `wandb login` or set WANDB_API_KEY.")
    CONFIG["report_to"] = "tensorboard" # Fallback to tensorboard
    logger.info("Falling back to TensorBoard for logging.")
    # No explicit init for tensorboard here, Trainer handles it via TrainingArguments
# Note: Use environment variables or notebook secrets to store your tokens securely


[34m[1mwandb[0m: Currently logged in as: [33mgoheygourab[0m ([33mgoheygourab-self[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## 7. Mount Google Drive (Optional)

If you want to save your model checkpoints and outputs to Google Drive, mount it here.


In [9]:
def is_colab():
    """Check if the current environment is Google Colab."""
    try:
        import google.colab
        return True
    except ImportError:
        return False
        
if is_colab() and CONFIG["mount_drive"]:
    from google.colab import drive
    try:
        drive.mount('/content/drive')
        logger.info("Google Drive mounted successfully.")
        # Create the output directory on Drive if it doesn't exist
        if not os.path.exists(CONFIG["gdrive_output_dir"]):
            os.makedirs(CONFIG["gdrive_output_dir"], exist_ok=True)
            logger.info(f"Created Google Drive output directory: {CONFIG['gdrive_output_dir']}")
    except Exception as e:
        logger.error(f"Failed to mount Google Drive: {e}")
        logger.info("Proceeding without Google Drive. Outputs will be saved to Colab ephemeral storage.")
        CONFIG["mount_drive"] = False # Disable drive features if mount fails
else:
    logger.info("Not running in Colab or Google Drive is disabled in the configuration.")
    CONFIG["mount_drive"] = False

## 8. Load Model, Tokenizer and Configure LoRA

Loads the Flan-T5-base model and tokenizer from Hugging Face, configures the model for LoRA training and returns the model and tokenizer objects.


In [10]:
# Install latest bitsandbytes and required dependencies
%pip install -U bitsandbytes --no-cache-dir -q
%pip install accelerate --upgrade -q
%pip install transformers --upgrade -q

# Don't install bitsandbytes on Mac M1 as it's not compatible with MPS

# Import required libraries
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
import bitsandbytes as bnb

# Print versions and available devices 
def check_environment():
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"MPS available: {torch.backends.mps.is_available()}")
    logger.info(f"MPS built: {torch.backends.mps.is_built()}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

check_environment()

# Function to verify CUDA and bitsandbytes setup
def verify_installation():
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info(f"bitsandbytes version: {bnb.__version__}")
    
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. 4-bit quantization requires a CUDA-enabled GPU.")
    
    # Test BitsAndBytes CUDA kernels
    try:
        _ = bnb.matmul(torch.zeros(2, 2).cuda(), torch.zeros(2, 2).cuda())
        logger.info("BitsAndBytes CUDA kernels working correctly")
    except Exception as e:
        logger.error(f"BitsAndBytes CUDA test failed: {e}")
        raise

# Verify installation
# verify_installation()

# Load tokenizer
logger.info(f"Loading tokenizer for model: {CONFIG['model_name']}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        CONFIG["model_name"],
        use_fast=True,
        padding_side="right",
        model_max_length=CONFIG["max_input_tokens"]
    )
    logger.info("Tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load tokenizer: {e}")
    raise

# Load model for MPS
logger.info(f"Loading base model: {CONFIG['model_name']} for MPS...")
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        CONFIG["model_name"],
        device_map=None,  # Don't use auto device mapping on MPS
        torch_dtype=torch.float32,  # Use FP32 on MPS for better compatibility
        use_cache=False  # Disable KV cache to avoid past_key_values warning
    )
    # Move model to MPS device after loading
    if torch.backends.mps.is_available():
        model = model.to(device)
    logger.info(f"Model loaded successfully and moved to {device}")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

# Prepare model for training with proper error handling
logger.info("Configuring model for LoRA training...")
try:
    # Configure LoRA for the model
    lora_config = LoraConfig(
        r=CONFIG["lora_r"],
        lora_alpha=CONFIG["lora_alpha"],
        target_modules=CONFIG["lora_target_modules"],
        lora_dropout=CONFIG["lora_dropout"],
        bias=CONFIG["lora_bias"],
        task_type=CONFIG["lora_task_type"],
    )
    
    # Enable gradients before applying LoRA
    for param in model.parameters():
        param.requires_grad = True
    
    model = get_peft_model(model, lora_config)
    
    # Explicitly verify trainable parameters
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    logger.info(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
    
    logger.info("LoRA adapter applied successfully")
    model.print_trainable_parameters()
except Exception as e:
    logger.error(f"Failed to apply LoRA adapter: {e}")
    raise

# Show memory usage
print_memory_usage()

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
'NoneType' object has no attribute 'cadam32bit_grad_fp32'
Note: you may need to restart the kernel to use updated packages.
'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


trainable params: 1,327,104 || all params: 248,904,960 || trainable%: 0.5332


In [11]:
model

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
               

## 9. Load and Preprocess Dataset

Load the BillSum dataset, preprocess it for Flan-T5, and split into training and evaluation sets.


In [12]:
import re
def clean_text(text):
    text = text.strip()
    text = " ".join(text.split())
    # Only strip metadata if required
    text = re.sub(r'\s*\([^\)]{0,40}\)\s*', ' ', text)  # remove very short inlines
    text = re.sub(r'\s*\[[^\]]{0,40}\]\s*', ' ', text)
    return text

def preprocess_function(examples):
    try:
        input_texts = examples.get(CONFIG["text_col"], [])
        summary_texts = examples.get(CONFIG["summary_col"], [])

        if not input_texts:
            raise ValueError("Empty input text")

        cleaned_inputs = [clean_text(doc) for doc in input_texts]

        prompts = [f'{CONFIG["prompt_prefix"]}{doc}' for doc in cleaned_inputs]

        print(f"Raw: {input_texts[0][:150]}...")
        print(f"Cleaned: {cleaned_inputs[0][:150]}...")

        model_inputs = tokenizer(
            prompts,
            max_length=CONFIG["max_input_tokens"] - 32,
            padding="max_length",
            truncation=True,
        )

        # Use plain summaries — no extra formatting
        summaries = [s if s else "No summary provided." for s in summary_texts]

        labels = tokenizer(
            summaries,
            max_length=CONFIG["max_target_tokens"],
            truncation=True,
            padding="max_length",
        )

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    except Exception as e:
        print(f"[ERROR] Preprocessing failed: {e}")
        raise


In [13]:
logger.warning("Attempting to update some libraries.")
%pip install datasets --upgrade -q
%pip install fsspec --upgrade -q
%pip install pyarrow --upgrade -q
logger.warning("Library update attempts finished. If issues persist, ensure runtime was restarted after updates.")


from datasets import load_dataset
import pandas as pd

logger.info(f"Starting dataset loading and processing for: {CONFIG['dataset_name']}")

dataset = None
load_dataset_kwargs = {
    "path": CONFIG["dataset_name"],
    "split": f"train[:{CONFIG['sample_size']}]",
}

try:
    logger.info(f"Loading dataset with streaming=False and split sample size: {CONFIG['sample_size']}")
    dataset = load_dataset(
        **load_dataset_kwargs,
    )
    logger.info("Dataset loaded successfully with streaming=False")

except Exception as e:
    logger.warning(f"Attempt 1 failed: {str(e)}")
    logger.info("Attempting alternative loading method using pandas...")
    try:
        # Define splits and paths
        splits = {
            'train': 'data/train-00000-of-00001.parquet',
            'test': 'data/test-00000-of-00001.parquet',
            'ca_test': 'data/ca_test-00000-of-00001.parquet'
        }
        
        # Load training data using pandas
        df = pd.read_parquet(f"hf://datasets/FiscalNote/billsum/{splits['train']}")
        
        # Convert to Dataset format
        from datasets import Dataset
        dataset = Dataset.from_pandas(df)
        
        # Sample if needed
        if CONFIG['sample_size'] and CONFIG['sample_size'] < len(dataset):
            dataset = dataset.shuffle(seed=CONFIG['seed']).select(range(CONFIG['sample_size']))
            
        logger.info("Dataset loaded successfully using pandas alternative method")
        
    except Exception as e2:
        logger.error(f"Alternative loading method also failed: {str(e2)}")
        raise RuntimeError(f"Both loading methods failed. Original error: {str(e)}, Alternative error: {str(e2)}")
        
if dataset is None:
    logger.critical("FATAL: Dataset could not be loaded by any implemented method.")
    raise RuntimeError("Dataset loading failed. Please check the dataset name and parameters.")


logger.info(f"Dataset loaded. Original columns: {dataset.column_names}")
logger.info(f"Number of samples in loaded dataset: {len(dataset)}")

if 'text' in dataset.column_names:
    dataset = dataset.rename_column('text', 'article')
    logger.info("Renamed dataset column: 'text' -> 'article'")
    CONFIG['text_col'] = 'article'
    logger.info(f"Updated CONFIG['text_col'] to '{CONFIG['text_col']}'")
else:
    logger.warning(f"'text' column not found in dataset columns: {dataset.column_names}. Skipping rename. Current text column in CONFIG: {CONFIG['text_col']}")

logger.info("Creating dataset splits...")
try:
    total_size = len(dataset)
    logger.info(f"Total dataset size for splitting: {total_size}")

    train_size = int(total_size * CONFIG["split_train_frac"])
    eval_size = total_size - train_size

    if train_size <= 0 or eval_size <= 0:
        logger.error(f"Calculated train_size ({train_size}) or eval_size ({eval_size}) is non-positive. Aborting split.")
        raise ValueError("Train or evaluation set size is not positive. Check dataset size and split_train_frac.")

    dataset_shuffled = dataset.shuffle(seed=CONFIG["seed"]) # Shuffle before selecting
    train_dataset = dataset_shuffled.select(range(train_size))
    eval_dataset = dataset_shuffled.select(range(train_size, total_size))

    logger.info(f"Created splits - Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
except Exception as e:
    logger.error(f"Error creating dataset splits: {e}")
    raise

logger.info("Processing datasets (tokenization, etc.)...")
try:
    tokenized_datasets = {
        'train': train_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=train_dataset.column_names
        ),
        'eval': eval_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )
    }
    logger.info("Dataset processing complete.")
    logger.info(f"Final tokenized dataset sizes - Training samples: {len(tokenized_datasets['train'])}, Evaluation samples: {len(tokenized_datasets['eval'])}")
except Exception as e:
    logger.error(f"Error processing datasets: {e}")
    raise

Attempting to update some libraries.


Note: you may need to restart the kernel to use updated packages.
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.3.2 which is incompatible.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.3.2 which is incompatible.[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


Library update attempts finished. If issues persist, ensure runtime was restarted after updates.


Note: you may need to restart the kernel to use updated packages.


Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Raw: SECTION 1. SHORT TITLE.

    This Act may be cited as the ``Pentagon 9/11 Memorial Commemorative 
Coin Act of 2005''.

SEC. 2. FINDINGS.

    The Cong...
Cleaned: SECTION 1. SHORT TITLE. This Act may be cited as the ``Pentagon 9/11 Memorial Commemorative Coin Act of 2005''. SEC. 2. FINDINGS. The Congress finds a...
Raw: SECTION 1. SHORT TITLE.

    This Act may be cited as the ``Children's Hope Act of 2003''.

SEC. 2. TAX CREDIT FOR CONTRIBUTIONS TO EDUCATION INVESTME...
Cleaned: SECTION 1. SHORT TITLE. This Act may be cited as the ``Children's Hope Act of 2003''. SEC. 2. TAX CREDIT FOR CONTRIBUTIONS TO EDUCATION INVESTMENT ORG...
Raw: SECTION 1. SHORT TITLE.

    This Act may be cited as the ``Children's Hope Act of 2003''.

SEC. 2. TAX CREDIT FOR CONTRIBUTIONS TO EDUCATION INVESTME...
Cleaned: SECTION 1. SHORT TITLE. This Act may be cited as the ``Children's Hope Act of 2003''. SEC. 2. TAX CREDIT FOR CONTRIBUTIONS TO EDUCATION INVESTMENT ORG...


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Raw: SECTION 1. SHORT TITLE.

    This Act may be cited as the ``Social Security Account Number Anti-
Fraud Act''.

SEC. 2. STATEMENT OF PURPOSE.

    The ...
Cleaned: SECTION 1. SHORT TITLE. This Act may be cited as the ``Social Security Account Number Anti- Fraud Act''. SEC. 2. STATEMENT OF PURPOSE. The purposes of...


## 10. Define Training Arguments

Configure the training process using Seq2SeqTrainingArguments.


In [14]:
training_args = Seq2SeqTrainingArguments(
    output_dir=CONFIG["output_dir"],
    num_train_epochs=CONFIG["num_train_epochs"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=CONFIG["per_device_eval_batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    learning_rate=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
    warmup_steps=CONFIG["warmup_steps"],
    
    # Optimization - modified for MPS
    fp16=False,  # Disable FP16 for MPS
    fp16_full_eval=False,
    bf16=False,
    optim="adamw_torch",  # Use adamw_torch optimizer
    
    # Disable features that might cause issues on MPS
    gradient_checkpointing=False,  # Disable gradient checkpointing on MPS
    group_by_length=False,  # Disable length batching
    dataloader_pin_memory=False,  # Disable pin_memory on MPS
    
    # Logging & Evaluation
    logging_dir=f"{CONFIG['output_dir']}/logs",
    logging_strategy="steps",
    logging_steps=CONFIG["logging_steps"],
    eval_strategy=CONFIG["evaluation_strategy"],
    eval_steps=CONFIG["eval_steps"],
    
    # Saving
    save_strategy=CONFIG["save_strategy"],
    save_steps=CONFIG["save_steps"],
    save_total_limit=CONFIG["save_total_limit"],
    
    # Model Loading
    load_best_model_at_end=CONFIG["load_best_model_at_end"],
    metric_for_best_model=CONFIG["metric_for_best_model"],
    greater_is_better=CONFIG["greater_is_better"],
    
    # Generation
    predict_with_generate=True,
    generation_max_length=CONFIG["max_target_tokens"],
    generation_num_beams=CONFIG["gen_num_beams"],
    
    # Other
    report_to=CONFIG["report_to"],
    seed=CONFIG["seed"],
    remove_unused_columns=False,  # Keep all columns
    overwrite_output_dir=CONFIG["overwrite_output_dir"],
    use_legacy_prediction_loop=True,  # Use legacy prediction loop to avoid past_key_values warning
)

logger.info(f"Training arguments configured for MPS")


## 11. Define Metrics Computation

Function to compute ROUGE and BLEU scores for evaluation.


In [15]:
from functools import lru_cache
import numpy as np
import nltk
from typing import Dict, List

@lru_cache(maxsize=1)
def get_metrics():
    """Load and cache evaluation metrics."""
    return {
        "rouge": evaluate.load("rouge"),
        "bleu": evaluate.load("bleu"),
        "bertscore": evaluate.load("bertscore")
    }

def process_texts(texts: List[str]) -> List[str]:
    """Clean and process texts for evaluation."""
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

def compute_metrics(eval_preds, batch_size: int = 32) -> Dict[str, float]:
    """Compute evaluation metrics with improved error handling and statistics."""
    try:
        metrics = get_metrics()
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # Decode predictions and labels
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Process texts
        decoded_preds = process_texts(decoded_preds)
        decoded_labels = process_texts(decoded_labels)

        # Compute metrics
        rouge_results = metrics["rouge"].compute(
            predictions=decoded_preds, 
            references=decoded_labels
        )
        
        decoded_labels_bleu = [[label] for label in decoded_labels]
        bleu_results = metrics["bleu"].compute(
            predictions=decoded_preds, 
            references=decoded_labels_bleu
        )

        # Compute additional statistics
        pred_lengths = [len(p.split()) for p in decoded_preds]
        ref_lengths = [len(r.split()) for r in decoded_labels]

        results = {
            "rouge1": rouge_results["rouge1"],
            "rouge2": rouge_results["rouge2"],
            "rougeL": rouge_results["rougeL"],
            "rougeLsum": rouge_results["rougeLsum"],
            "bleu": bleu_results["bleu"],
            "avg_pred_length": np.mean(pred_lengths),
            "avg_ref_length": np.mean(ref_lengths),
            "compression_ratio": np.mean([p/r for p, r in zip(pred_lengths, ref_lengths)])
        }

        # Add generation length metric
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        results["gen_len"] = np.mean(prediction_lens)

        return {k: round(v, 4) for k, v in results.items()}

    except Exception as e:
        logger.error(f"Error computing metrics: {e}")
        return {
            "rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0,
            "rougeLsum": 0.0, "bleu": 0.0, "gen_len": 0,
            "error": str(e)
        }

logger.info("Enhanced metrics computation function defined.")

## 11. Initialize Trainer

Set up the `Seq2SeqTrainer` with the model, arguments, datasets, tokenizer, and metrics function.


In [16]:
from transformers import EarlyStoppingCallback,TrainerCallback
import gc
import torch


def clear_memory():
    """Clear unused memory before training"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print_memory_usage()

class MemoryTrackingCallback(TrainerCallback):
    """Callback to track memory usage during training"""
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:  # Monitor every 100 steps
            print_memory_usage()

def validate_training_args(args, model):
    """Validate training arguments for potential issues"""
    if args.per_device_train_batch_size * args.gradient_accumulation_steps > 32:
        logger.warning("Total batch size might be too large for available memory")
    
    if args.fp16 and not torch.cuda.is_available() and not torch.backends.mps.is_available():
        raise ValueError("FP16 requires CUDA or MPS")
    
    if args.fp16 and torch.backends.mps.is_available():
        logger.warning("FP16 is not fully supported on MPS. Turning it off.")
        args.fp16 = False
        args.fp16_full_eval = False

# Clear memory before initialization
clear_memory()

# Initialize data collator with error handling
try:
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=tokenizer.pad_token_id,
        pad_to_multiple_of=None  # Changed for MPS compatibility
    )
    logger.info("Data collator initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize data collator: {e}")
    raise

# Validate training arguments
validate_training_args(training_args, model)

# Initialize trainer with enhanced monitoring
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3),
        MemoryTrackingCallback()
    ]
)

logger.info("Trainer initialized with enhanced monitoring")
print_memory_usage()

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## 12. Train the Model

Start the fine-tuning process. This will take some time depending on the dataset size and hardware. 🥳


In [17]:
def get_latest_checkpoint(checkpoint_dir):
    """Find most recent checkpoint in the directory"""
    checkpoints = [d for d in os.listdir(checkpoint_dir) 
                  if d.startswith("checkpoint-")]
    if not checkpoints:
        return None
    return os.path.join(checkpoint_dir, 
                       sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1])

In [18]:
logger.info("Starting training...")
try:
    # Setup checkpoint directory
    checkpoint_dir = os.path.join(CONFIG["output_dir"], "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Check for existing checkpoint
    resume_checkpoint = get_latest_checkpoint(checkpoint_dir)
    if resume_checkpoint:
        logger.info(f"Resuming from checkpoint: {resume_checkpoint}")
    
    # Print training info
    logger.info("Training Configuration:")
    logger.info(f"- Number of training examples: {len(trainer.train_dataset)}")
    logger.info(f"- Number of validation examples: {len(trainer.eval_dataset)}")
    logger.info(f"- Training Epochs: {CONFIG['num_train_epochs']}")
    logger.info(f"- Batch size: {CONFIG['per_device_train_batch_size']}")
    
    # Execute training
    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    
    # Log final metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    
    logger.info("Training completed successfully!")
    logger.info(f"Final Training Loss: {metrics.get('train_loss', 'N/A')}")
    
except Exception as e:
    logger.error(f"Training failed: {e}")
    if wandb.run:
        wandb.log({"training_error": str(e)})
        wandb.run.finish(exit_code=1)
    raise e

finally:
    print_memory_usage()
    torch.cuda.empty_cache()



Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Bleu,Gen Len,Error
50,13.2812,14.751487,0.0,0.0,0.0,0.0,0.0,0,out of range integral type conversion attempted


Error computing metrics: out of range integral type conversion attempted
Error computing metrics: out of range integral type conversion attempted


KeyboardInterrupt: 

## 13. Evaluate the Model

Evaluate the fine-tuned model on the evaluation set to get final performance metrics.


In [None]:
logger.log("Evaluating model...")
eval_metrics = trainer.evaluate()

logger.info("Evaluation metrics:")
for key, value in eval_metrics.items():
    logger.info(f"{key}: {value}")

# Log evaluation metrics
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics) # Saves to all_results.json

# Prepare the training_report.json
training_report = {
    "model_name": CONFIG["model_name"],
    "dataset_name": CONFIG["dataset_name"],
    "lora_adapter_name": CONFIG["lora_adapter_name"],
    "output_directory": CONFIG["output_dir"],
    "training_arguments": {k: str(v) if isinstance(v, (torch.device, BitsAndBytesConfig)) else v for k, v in training_args.to_dict().items()}, # Convert non-serializable items
    "train_metrics": trainer.state.log_history[:-1], # All logged steps except final eval
    "eval_metrics": eval_metrics,
    "final_training_loss": trainer.state.log_history[-2].get('loss') if len(trainer.state.log_history) > 1 and 'loss' in trainer.state.log_history[-2] else trainer.state.log_history[-1].get('train_loss', 'N/A')
}


# Add ROUGE and BLEU from eval_metrics to the top level for easier access
for metric_key in ["eval_rouge1", "eval_rouge2", "eval_rougeL", "eval_bleu"]:
    if metric_key in eval_metrics:
        training_report[metric_key.replace("eval_", "")] = eval_metrics[metric_key]


# Save training_report.json locally
report_path = os.path.join(CONFIG["output_dir"], CONFIG["training_report_filename"])
with open(report_path, "w") as f:
    json.dump(training_report, f, indent=4)
logger.info(f"Training report saved to {report_path}")

if wandb.run:
    wandb.log(eval_metrics) # Log final eval metrics
    wandb.save(report_path) # Save report to W&B artifacts
    logger.info("Evaluation metrics and report logged to W&B.")

print_memory_usage()

## 14. Save Model and LoRA Adapter

Save the fine-tuned LoRA adapter and the full model if needed.


In [None]:
# Save the LoRA adapter
lora_adapter_path = os.path.join(CONFIG["output_dir"], CONFIG["lora_adapter_name"])
model.save_pretrained(lora_adapter_path) # Saves only the LoRA adapter
tokenizer.save_pretrained(lora_adapter_path) # Save tokenizer with adapter
logger.info(f"LoRA adapter and tokenizer saved to {lora_adapter_path}")

# To save the full model (optional, requires more space)
# merged_model_path = os.path.join(CONFIG["output_dir"], "merged_model_flan_t5_base_billsum")
# try:
#     # Merge LoRA weights with the base model
#     merged_model = model.merge_and_unload()
#     merged_model.save_pretrained(merged_model_path)
#     tokenizer.save_pretrained(merged_model_path)
#     logger.info(f"Full merged model saved to {merged_model_path}")
# except Exception as e:
#     logger.error(f"Could not merge and save full model: {e}. This might happen if the base model is not fully on CPU or due to memory constraints.")
#     logger.info("Only the LoRA adapter was saved.")


# If Google Drive is mounted, copy outputs there
if CONFIG["mount_drive"] and os.path.exists(CONFIG["gdrive_output_dir"]):
    logger.info(f"Copying outputs to Google Drive: {CONFIG['gdrive_output_dir']}")
    # Copy LoRA adapter
    gdrive_lora_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["lora_adapter_name"])
    if os.path.exists(gdrive_lora_path):
        logger.info(f"Removing existing LoRA adapter from GDrive: {gdrive_lora_path}")
        os.system(f"rm -rf '{gdrive_lora_path}'") # Use os.system for `rm -rf`
    os.system(f"cp -r '{lora_adapter_path}' '{CONFIG['gdrive_output_dir']}/'")
    logger.info(f"LoRA adapter copied to {gdrive_lora_path}")

    # Copy training report
    gdrive_report_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["training_report_filename"])
    os.system(f"cp '{report_path}' '{gdrive_report_path}'")
    logger.info(f"Training report copied to {gdrive_report_path}")

    # Copy all_results.json (contains eval metrics)
    all_results_path = os.path.join(CONFIG["output_dir"], "all_results.json")
    if os.path.exists(all_results_path):
        gdrive_all_results_path = os.path.join(CONFIG["gdrive_output_dir"], "all_results.json")
        os.system(f"cp '{all_results_path}' '{gdrive_all_results_path}'")
        logger.info(f"all_results.json copied to {gdrive_all_results_path}")

    # If merged model was saved and exists, copy it too
    # if 'merged_model' in locals() and os.path.exists(merged_model_path):
    #     gdrive_merged_model_path = os.path.join(CONFIG["gdrive_output_dir"], "merged_model_flan_t5_base_billsum")
    #     if os.path.exists(gdrive_merged_model_path):
    #         logger.info(f"Removing existing merged model from GDrive: {gdrive_merged_model_path}")
    #         os.system(f"rm -rf '{gdrive_merged_model_path}'")
    #     os.system(f"cp -r '{merged_model_path}' '{CONFIG['gdrive_output_dir']}/'")
    #     logger.info(f"Full merged model copied to {gdrive_merged_model_path}")
else:
    logger.warning("Google Drive not mounted or GDrive output path does not exist. Outputs saved locally.")

if wandb.run:
    # Log LoRA adapter as artifact if desired
    # lora_artifact = wandb.Artifact(CONFIG["lora_adapter_name"], type="model")
    # lora_artifact.add_dir(lora_adapter_path)
    # wandb.log_artifact(lora_artifact)
    # logger.info(f"LoRA adapter logged as W&B artifact: {CONFIG['lora_adapter_name']}")
    wandb.finish()

logger.info("Script finished.")

## Test the model with a sample input


In [None]:
# Test document - Cybersecurity and Privacy Protection Act
test_document = """
CYBERSECURITY AND PRIVACY PROTECTION ACT OF 2025

SECTION 1. SHORT TITLE AND PURPOSE

    (a) This Act may be cited as the 'Cybersecurity and Privacy Protection Act of 2025'.
    (b) The purpose of this Act is to enhance cybersecurity measures and protect individual privacy in the digital age.

SECTION 2. DEFINITIONS

In this Act:
    (1) 'Personal Data' means any information relating to an identified or identifiable natural person.
    (2) 'Data Controller' means any entity that determines the purposes and means of processing personal data.
    (3) 'Critical Infrastructure' means systems and assets vital to national security.

SECTION 3. CYBERSECURITY REQUIREMENTS

    (a) MANDATORY SECURITY MEASURES.—
        (1) All Data Controllers shall implement:
            (A) End-to-end encryption for data transmission
            (B) Multi-factor authentication for system access
            (C) Regular security audits and vulnerability assessments

    (b) INCIDENT REPORTING.—
        (1) Data Controllers shall report any security breach within 48 hours.
        (2) Penalties for non-compliance shall be up to $500,000 per incident.

SECTION 4. PRIVACY PROTECTIONS

    (a) CONSENT REQUIREMENTS.—
        (1) Explicit consent required for data collection
        (2) Right to access and delete personal data
        (3) Annual privacy impact assessments

    (b) CHILDREN'S PRIVACY.—
        (1) Enhanced protections for users under 13
        (2) Parental consent requirements

SECTION 5. ENFORCEMENT

    (a) The Federal Trade Commission shall enforce this Act.
    (b) State Attorneys General may bring civil actions.

SECTION 6. AUTHORIZATION OF APPROPRIATIONS

    There is authorized to be appropriated $275,000,000 for fiscal year 2026 to carry out this Act.
"""

# Test the model with the adapter
try:
    # Load the trained model with LoRA adapter
    logger.info(f"Loading model with LoRA adapter from {lora_adapter_path}...")
    
    # Get the device
    current_device = get_device()
    
    model = AutoModelForSeq2SeqLM.from_pretrained(
        lora_adapter_path,
        torch_dtype=torch.float32  # Use float32 for MPS compatibility
    )
    model = get_peft_model(model, lora_config)
    model.load_adapter(lora_adapter_path, CONFIG["lora_adapter_name"])
    model.set_adapter(CONFIG["lora_adapter_name"])
    model.eval()
    model.to(current_device)  # Move the model to the appropriate device
    
    # Process the test document
    logger.info("Tokenizing test document...")
    prompt = f"{CONFIG['prompt_prefix']}{test_document}"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(current_device)
    
    # Generate summary
    logger.info("Generating summary...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs, 
            max_length=CONFIG["max_target_tokens"],
            num_beams=CONFIG["gen_num_beams"],
            length_penalty=CONFIG["gen_length_penalty"],
            early_stopping=CONFIG["gen_early_stopping"]
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print("\nGenerated Summary:")
    print("-" * 50)
    print(summary)
    print("-" * 50)
    
except Exception as e:
    logger.error(f"Error during inference: {e}")
    print(f"An error occurred during inference: {str(e)}")

print_memory_usage()
