# Fine-tuning Flan-T5-base for Legal Document Summarization (Google Colab version)

This notebook is a modified version of the original notebook by Gourab S. (@heygourab).

Author: Gourab S. (@heygourab)

<a href="https://github.com/heygourab"><img src="https://github.com/fluidicon.png" width="20" height="20" alt="github" /> @heygourab</a>

This notebook demonstrates fine-tuning the Flan-T5-base model on the BillSum dataset using LoRA (Low-Rank Adaptation). We'll use the Hugging Face ecosystem (`transformers`, `datasets`, `peft`) for efficient fine-tuning.

## Setup Overview

- Base Model: google/flan-t5-base
- Dataset: BillSum (first 2,000 training samples, controlled by `sample_size`)
- Training: 3 epochs, per-device batch size 16, learning rate 2e-4
- Optimization: LoRA with r=8, alpha=32
- Output: LoRA adapter weights & training metrics
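
To make the LoRA settings concrete, the sketch below estimates how many trainable parameters `r=8` adds when only the attention query and value projections are adapted. The architecture numbers (hidden size 768, 12 encoder and 12 decoder layers) are assumptions taken from the published flan-t5-base configuration rather than from this notebook, and the helper name is illustrative only.

```python
def lora_trainable_params(r, d_in, d_out, num_modules):
    """Each adapted weight W (d_out x d_in) gains A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out) * num_modules

# Assumed flan-t5-base shape: d_model = 768, so q/v projections are 768 x 768
D_MODEL = 768
ENCODER_LAYERS = 12   # one self-attention block each
DECODER_LAYERS = 12   # one self-attention and one cross-attention block each
attention_blocks = ENCODER_LAYERS + 2 * DECODER_LAYERS  # 36 attention blocks
adapted_modules = attention_blocks * 2                   # "q" and "v" in each block

trainable = lora_trainable_params(r=8, d_in=D_MODEL, d_out=D_MODEL, num_modules=adapted_modules)
scaling = 32 / 8  # lora_alpha / r, the factor applied to the low-rank update

print(f"Estimated LoRA parameters: {trainable:,}")
print(f"LoRA scaling factor: {scaling}")
```

The estimate can be checked against the figure that `model.print_trainable_parameters()` reports once LoRA has been applied.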

## Prerequisites

This notebook assumes a Colab environment with a GPU available. If you're running this locally, make sure to install the required packages and set up your GPU environment accordingly.

<a href="https://colab.research.google.com/github/heygourab/pdf_summarization_model_fine_tuning/blob/main/notebooks/billsum_lora_finetune_colab.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"></a>


## 1. Environment Setup

First, install the required dependencies.


In [6]:
# Install required packages
%pip install -q transformers datasets accelerate evaluate peft bitsandbytes rouge_score nltk wandb omegaconf torch 
%pip install -U bitsandbytes -q

Note: you may need to restart the kernel to use updated packages.


## 2. Import Libraries


In [10]:
import os
import json
import torch
import nltk
import evaluate
import numpy as np
import sys
import logging
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import wandb
import psutil

## 3. Logger setup


In [11]:
def setup_logger(name="train_logger", level=logging.INFO, log_file=None):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # Avoid duplicate logs

    # Clear existing handlers
    if logger.hasHandlers():
        logger.handlers.clear()

    # Formatter for log messages
    formatter = logging.Formatter(
        fmt='%(asctime)s — %(name)s — %(levelname)s — %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler setup
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler setup
    if log_file is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_dir = os.path.join(os.getcwd(), 'logs')  # Safe fallback to current dir
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f'training_{timestamp}.log')
    else:
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log header
    logger.info(f"Logger initialized: {name}")
    logger.info(f"Log file created at: {os.path.abspath(log_file)}")
    logger.info(f"Python version: {sys.version}")

    return logger

# Use it
logger = setup_logger("train_logger", logging.INFO)

2025-05-17 17:45:06 — train_logger — INFO — Logger initialized: train_logger
2025-05-17 17:45:06 — train_logger — INFO — Log file created at: /Users/gourabsarkar/Developer/college_project/pdf_summarization_model_fine_tuning/notebooks/logs/training_20250517_174506.log
2025-05-17 17:45:06 — train_logger — INFO — Python version: 3.10.17 (main, Apr  8 2025, 12:10:59) [Clang 16.0.0 (clang-1600.0.26.6)]


## 4. Download the Required NLTK Resources


In [12]:
# Download required NLTK data
for resource in ['punkt', 'punkt_tab']:
    try:
        nltk.download(resource, quiet=True)
        logger.info(f"Successfully downloaded NLTK resource: {resource}")
    except Exception as e:
        logger.error(f"Error downloading {resource}: {e}")

2025-05-17 17:45:11 — train_logger — INFO — Successfully downloaded NLTK resource: punkt
2025-05-17 17:45:14 — train_logger — INFO — Successfully downloaded NLTK resource: punkt_tab


## 4. Memory Usage Monitoring

The `print_memory_usage()` function monitors system resource utilization during model training:

- Tracks RAM usage via the Resident Set Size (RSS) of the current process, in GB, alongside total system RAM
- For GPU-enabled systems:
  - Reports allocated GPU memory
  - Shows total available GPU memory
  - Calculates percentage of GPU memory utilization
  - Reports peak allocated GPU memory, then resets the peak memory tracking statistics

This helps identify potential memory bottlenecks and optimize resource usage during training.
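
Note that the function divides byte counts by 1e9, so it reports decimal gigabytes (GB) rather than binary gibibytes (GiB). The sketch below illustrates the difference; the helper names are hypothetical and not part of the notebook.

```python
def bytes_to_gb(num_bytes):
    """Decimal gigabytes: 1 GB = 10**9 bytes (the convention used in this notebook)."""
    return num_bytes / 1e9

def bytes_to_gib(num_bytes):
    """Binary gibibytes: 1 GiB = 2**30 bytes."""
    return num_bytes / 2**30

ram_bytes = 8 * 2**30  # a machine with 8 GiB of physical memory
print(f"{bytes_to_gb(ram_bytes):.2f} GB")
print(f"{bytes_to_gib(ram_bytes):.2f} GiB")
```

This is why a machine with 8 GiB of memory shows up as roughly 8.59 GB in the log output.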


In [13]:
def print_memory_usage():
    process = psutil.Process(os.getpid()) # Get the current process

    ram_gb = process.memory_info().rss / 1e9 # Convert bytes to GB
    total_gb = psutil.virtual_memory().total / 1e9 # Total system RAM in GB

    logger.info(f"RAM usage: {ram_gb:.2f} GB") # Current process RAM usage
    logger.info(f"Total system RAM: {total_gb:.2f} GB") # Total system RAM

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        gpu_mem = torch.cuda.memory_allocated() / 1e9
        gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9
        peak_gpu_mem = torch.cuda.max_memory_allocated() / 1e9

        logger.info(f"GPU memory usage: {gpu_mem:.2f}/{gpu_total:.2f} GB ({gpu_mem/gpu_total*100:.1f}%)")
        logger.info(f"Peak GPU memory: {peak_gpu_mem:.2f} GB")

        torch.cuda.reset_peak_memory_stats()

print_memory_usage()

2025-05-17 17:45:22 — train_logger — INFO — RAM usage: 0.05 GB
2025-05-17 17:45:22 — train_logger — INFO — Total system RAM: 8.59 GB


## 5. Configuration Parameters

This cell defines all hyperparameters and configuration settings for the model, dataset, LoRA, and training.


In [1]:
# Configuration
CONFIG = {
    # Model
    "model_name": "google/flan-t5-base",
    "model_type": "encoder-decoder",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ),

    # Dataset
    "dataset_name": "billsum",
    "text_col": "text",
    "summary_col": "summary",
    "max_input_tokens": 512,
    "max_target_tokens": 128,
    "sample_size": 2000,
    "filter_by_length": True,
    "split_train_frac": 0.9,

    # Prompt
    "prompt_prefix": "Summarize this legal document:\n",

    # LoRA
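    "lora_r": 8, # rank of the low-rank update matrices
    "lora_alpha": 32, # scaling factor (updates are scaled by lora_alpha / lora_r)
    "lora_target_modules": ["q", "v"], # attention query and value projections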
    "lora_dropout": 0.1,
    "lora_bias": "none",
    "lora_task_type": TaskType.SEQ_2_SEQ_LM,

    # Training
    "output_dir_base": "lora_billsum_flan_t5_base",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "gradient_accumulation_steps": 1,
    "learning_rate": 2e-4, # Initial learning rate
    "weight_decay": 0.01, # Weight decay
    "warmup_steps": 100, # Warmup steps
    "fp16": True,
    "logging_steps": 10,
    "evaluation_strategy": "steps",
    "eval_steps": 100,
    "save_strategy": "steps",
    "save_steps": 500,
    "save_total_limit": 2,
    "load_best_model_at_end": True,
    "metric_for_best_model": "rougeL",
    "greater_is_better": True,
    "report_to": "wandb",
    "gradient_checkpointing": True,
    "overwrite_output_dir": True,

    # Generation
    "gen_num_beams": 4,
    "gen_length_penalty": 1.0,
    "gen_early_stopping": True,

    # Seed
    "seed": 42,

    # Output paths
    "mount_drive": True,
    "drive_path": "MyDrive/ML_models/pdf_summarization",
    "training_report_filename": "training_report.json",
    "lora_adapter_name": "lora_billsum"
}


# Create timestamped output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG["output_dir"] = f"{CONFIG['output_dir_base']}_{timestamp}"

logger.info(f"Output directory: {CONFIG['output_dir']}")
if CONFIG["mount_drive"]:
    CONFIG["gdrive_output_dir"] = os.path.join("/content/drive", CONFIG["drive_path"], CONFIG["output_dir"])
    logger.info(f"Google Drive output directory: {CONFIG['gdrive_output_dir']}")



## 6. Login to Hugging Face Hub and Weights & Biases

You'll need to log in to Hugging Face to download models/datasets and to Weights & Biases for experiment tracking.
You can get your Hugging Face token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
and your W&B API key from [https://wandb.ai/authorize](https://wandb.ai/authorize).
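
As a hedged sketch of keeping credentials out of the notebook itself, the snippet below reads them from environment variables. `HF_TOKEN` and `WANDB_API_KEY` are the variable names those libraries document; `read_secret` is a hypothetical helper, and the optional `environ` argument exists only to make it easy to try out.

```python
import os

def read_secret(name, environ=None):
    """Return a non-empty secret from the environment, or None if it is unset."""
    environ = os.environ if environ is None else environ
    value = environ.get(name, "").strip()
    return value or None

hf_token = read_secret("HF_TOKEN")
wandb_key = read_secret("WANDB_API_KEY")

# Only report presence; never print the secret values themselves
print("HF token found:", hf_token is not None)
print("W&B key found:", wandb_key is not None)
```

In Colab, the same variables can be populated from the notebook's secrets panel before this cell runs.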


In [None]:
from huggingface_hub import HfFolder, notebook_login

try:
    if HfFolder.get_token() is None:
        logger.info("Hugging Face token not found. Please log in.")
        notebook_login()
    else:
        logger.info("Already logged in to Hugging Face Hub.")
except Exception as e:
    logger.error(f"An error occurred during Hugging Face login check: {e}")
    logger.info("Attempting login...")
    notebook_login()

# Login to Weights & Biases
try:
    wandb.login()
    wandb.init(project="flan-t5-billsum-lora", name=CONFIG["output_dir"], config=CONFIG)
    logger.info("Successfully logged in to W&B and initialized experiment.")
except Exception as e:
    logger.error(f"Could not login to W&B: {e}. Ensure you have run `wandb login` or set WANDB_API_KEY.")
    CONFIG["report_to"] = "tensorboard" # Fallback to tensorboard
    logger.info("Falling back to TensorBoard for logging.")
    # No explicit init for tensorboard here, Trainer handles it via TrainingArguments
# Note: Use environment variables or notebook secrets to store your tokens securely

## 7. Mount Google Drive (Optional)

If you want to save your model checkpoints and outputs to Google Drive, mount it here.


In [None]:
def is_colab():
    """Check if the current environment is Google Colab."""
    try:
        import google.colab
        return True
    except ImportError:
        return False
        
if is_colab() and CONFIG["mount_drive"]:
    from google.colab import drive
    try:
        drive.mount('/content/drive')
        logger.info("Google Drive mounted successfully.")
        # Create the output directory on Drive if it doesn't exist
        if not os.path.exists(CONFIG["gdrive_output_dir"]):
            os.makedirs(CONFIG["gdrive_output_dir"], exist_ok=True)
            logger.info(f"Created Google Drive output directory: {CONFIG['gdrive_output_dir']}")
    except Exception as e:
        logger.error(f"Failed to mount Google Drive: {e}")
        logger.info("Proceeding without Google Drive. Outputs will be saved to Colab ephemeral storage.")
        CONFIG["mount_drive"] = False # Disable drive features if mount fails
else:
    logger.info("Not running in Colab or Google Drive is disabled in the configuration.")
    CONFIG["mount_drive"] = False

## 8. Load Model, Tokenizer and Configure LoRA

This section loads the Flan-T5-base tokenizer and model from Hugging Face (4-bit quantized where possible), prepares the model for k-bit training, and applies the LoRA configuration.
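
For a rough sense of what quantization saves, the sketch below estimates the memory taken by the model weights at different precisions. The parameter count of about 250 million is an assumption, and the result is only a lower bound: it ignores quantization constants, layers that stay unquantized, activations, and optimizer state.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to store the weights alone, in decimal GB."""
    return num_params * bits_per_param / 8 / 1e9

APPROX_PARAMS = 250_000_000  # assumption: flan-t5-base has roughly 250M parameters

for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: {weight_memory_gb(APPROX_PARAMS, bits):.3f} GB")
```

For a model this small the absolute saving is modest, so the unquantized fallback path remains practical on most GPUs.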


In [None]:
# Install latest bitsandbytes and required dependencies
%pip install -U bitsandbytes --no-cache-dir -q
%pip install accelerate --upgrade -q
%pip install transformers --upgrade -q

# Import required libraries
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
import bitsandbytes as bnb

# Function to verify CUDA and bitsandbytes setup
def verify_installation():
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info(f"bitsandbytes version: {bnb.__version__}")
    
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. 4-bit quantization requires a CUDA-enabled GPU.")
    
    # Test BitsAndBytes CUDA kernels
    try:
        _ = bnb.matmul(torch.zeros(2, 2).cuda(), torch.zeros(2, 2).cuda())
        logger.info("BitsAndBytes CUDA kernels working correctly")
    except Exception as e:
        logger.error(f"BitsAndBytes CUDA test failed: {e}")
        raise

# Verify installation
verify_installation()

# Load tokenizer
logger.info(f"Loading tokenizer for model: {CONFIG['model_name']}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        CONFIG["model_name"],
        use_fast=True,
        padding_side="left",
        model_max_length=CONFIG["max_input_tokens"]
    )
    logger.info("Tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load tokenizer: {e}")
    raise


# Configure quantization
logger.info("Configuring 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # Changed from bfloat16 for better compatibility
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load model with proper error handling
logger.info(f"Loading base model: {CONFIG['model_name']} with 4-bit quantization...")
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        CONFIG["model_name"],
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.float16,  # Match compute dtype
        low_cpu_mem_usage=True
    )
    logger.info("Model loaded successfully")

except Exception as e:
    logger.error(f"Failed to load model: {e}")
    logger.info("Attempting to load without quantization as fallback...")
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            CONFIG["model_name"],
            device_map="auto",
            torch_dtype=torch.float16
        )
        logger.warning("Model loaded without quantization")
    except Exception as e:
        logger.error(f"Failed to load model even without quantization: {e}")
        raise

# Prepare model for k-bit training
logger.info("Preparing model for k-bit training...")
try:
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=CONFIG["gradient_checkpointing"]
    )
    logger.info("Model prepared for k-bit training")
except Exception as e:
    logger.error(f"Failed to prepare model for kbit training: {e}")
    raise

# LoRA configuration and application
lora_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    target_modules=CONFIG["lora_target_modules"],
    lora_dropout=CONFIG["lora_dropout"],
    bias=CONFIG["lora_bias"],
    task_type=CONFIG["lora_task_type"],
)

logger.info('Applying LoRA to the model')
try:
    model = get_peft_model(model, lora_config)
    logger.info("LoRA configured and applied to the model")
except Exception as e:
    logger.error(f"Failed to apply LoRA to the model: {e}")
    raise

# Print model statistics
model.print_trainable_parameters()
print_memory_usage()

## 9. Load and Preprocess Dataset

Load the BillSum dataset, preprocess it for Flan-T5, and split into training and evaluation sets.
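
When targets are padded to a fixed length, the padding positions are usually excluded from the loss by replacing them with -100, the index that PyTorch's cross-entropy loss ignores by default. A minimal sketch with made-up token ids (the helper name is hypothetical):

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default

def mask_padding(label_ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Replace padding ids in each label sequence with the ignore index."""
    return [
        [tok if tok != pad_token_id else ignore_index for tok in seq]
        for seq in label_ids
    ]

# Made-up ids for illustration: 0 = pad, 1 = end-of-sequence
padded = [[71, 1139, 1, 0, 0], [37, 1, 0, 0, 0]]
masked = mask_padding(padded, pad_token_id=0)
print(masked)
```

Without this step, the model is also trained to predict padding tokens, which tends to distort the reported loss.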


In [None]:
def clean_text(text):
    """Clean and normalize text"""
    text = text.strip()
    text = " ".join(text.split())  # Normalize whitespace
    return text

def preprocess_function(examples):
    try:
        if not examples[CONFIG["text_col"]]:
            raise ValueError("Empty input texts received")
        
        # Clean and normalize texts
        cleaned_inputs = [clean_text(doc) for doc in examples[CONFIG["text_col"]]]
        
        # Generate input text with prompt prefix
        inputs = [CONFIG["prompt_prefix"] + doc for doc in cleaned_inputs]
        
        # Tokenize inputs with improved settings
        try:
            model_inputs = tokenizer(
                inputs,
                max_length=CONFIG["max_input_tokens"],
                padding="max_length",
                truncation=True,
                return_tensors="pt",  # Return PyTorch tensors
                return_attention_mask=True  # Explicitly request attention masks
            )
        except Exception as e:
            logger.error(f"Error tokenizing inputs: {e}")
            raise

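        # Tokenize targets (text_target replaces the deprecated as_target_tokenizer context)
        try:
            labels = tokenizer(
                text_target=examples[CONFIG["summary_col"]],
                max_length=CONFIG["max_target_tokens"],
                truncation=True,
                padding="max_length"
            )

        except Exception as e:
            logger.error(f"Error tokenizing targets: {e}")
            raise

        # Replace padding ids in the labels with -100 so the loss ignores them
        model_inputs["labels"] = [
            [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
            for seq in labels["input_ids"]
        ]
        return model_inputs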

    except KeyError as e:
        logger.error(f"Missing column in dataset: {e}")
        raise
    
    except Exception as e:
        logger.error(f"Preprocessing failed: {e}")
        raise

In [None]:
logger.warning("Attempting to update some libraries.")
%pip install datasets --upgrade -q
%pip install fsspec --upgrade -q
%pip install pyarrow --upgrade -q
logger.warning("Library update attempts finished. If issues persist, ensure runtime was restarted after updates.")

# --- Dataset Loading and Processing ---
import shutil
from datasets import load_dataset, Dataset

logger.info(f"Starting dataset loading and processing for: {CONFIG['dataset_name']}")

# Cache Directory Setup
hf_cache_home = os.path.expanduser(os.path.join("~", ".cache", "huggingface"))
default_datasets_cache_dir = os.path.join(hf_cache_home, "datasets")

# Construct the dataset-specific cache path (assumes the cache folder is the dataset name with '/' replaced by '--'; if the actual folder name differs, nothing is cleared)
dataset_name_for_cache = CONFIG["dataset_name"].replace("/", "--")
specific_dataset_cache_dir = os.path.join(default_datasets_cache_dir, dataset_name_for_cache)

# Cache Clearing
logger.info("Attempting to clear dataset cache...")
try:
    if os.path.exists(specific_dataset_cache_dir):
        shutil.rmtree(specific_dataset_cache_dir)
        logger.info(f"Successfully cleared specific cache for {CONFIG['dataset_name']} at {specific_dataset_cache_dir}")
    else:
        logger.info(f"Specific cache for {CONFIG['dataset_name']} not found at {specific_dataset_cache_dir}, no need to clear.")

except Exception as e:
    logger.warning(f"Could not clear dataset cache: {e}. Continuing...")


dataset = None

try:
    logger.info("Attempt 1: Loading full 'train' split with streaming=False...")
    full_train_dataset = load_dataset(CONFIG["dataset_name"], split='train', streaming=False, trust_remote_code=True) # Added trust_remote_code
    logger.info(f"Successfully loaded full 'train' split. Total records: {len(full_train_dataset)}")
    
    if len(full_train_dataset) >= CONFIG["sample_size"]:
        dataset = full_train_dataset.select(range(CONFIG["sample_size"]))
        logger.info(f"Selected first {CONFIG['sample_size']} samples. Dataset size: {len(dataset)}")
    else:
        logger.warning(f"Full 'train' split has {len(full_train_dataset)} samples (less than {CONFIG['sample_size']}). Using all available.")
        dataset = full_train_dataset
    logger.info("Dataset loaded and sliced successfully using Attempt 1 (full load then select).")

except Exception as e1:
    logger.warning(f"Attempt 1 (streaming=False, full load then slice) failed: {e1}")
    
    # Attempt 2: Load the sliced split 'train[:sample_size]' directly with streaming=False
    try:
        logger.info(f"Attempt 2: Loading with 'train[:{CONFIG['sample_size']}]' split directly with streaming=False...")
        dataset = load_dataset(CONFIG["dataset_name"], split=f'train[:{CONFIG["sample_size"]}]', streaming=False, trust_remote_code=True)
        logger.info(f"Attempt 2: Dataset loaded successfully with split='train[:{CONFIG['sample_size']}]' and streaming=False.")

    except Exception as e2:
        logger.warning(f"Attempt 2 (streaming=False, direct slice) failed: {e2}")
        
        # Attempt 3: Stream the 'train' split and take the first sample_size records
        try:
            logger.info(f"Attempt 3: Loading with streaming=True and taking the first {CONFIG['sample_size']} records...")
            # Slice syntax such as 'train[:N]' is generally not supported in streaming mode, so use .take()
            iterable_dataset = load_dataset(
                CONFIG["dataset_name"],
                split='train',
                streaming=True,
                trust_remote_code=True
            ).take(CONFIG["sample_size"])
            logger.info("Attempt 3: IterableDataset loaded successfully with streaming=True.")

            logger.info("Converting IterableDataset to map-style Dataset...")
            records = list(iterable_dataset)  # Materialize at most sample_size records
            if not records:
                raise ValueError(f"Streaming the first {CONFIG['sample_size']} records of 'train' resulted in no records.")
            dataset = Dataset.from_list(records)
            logger.info(f"Converted to map-style Dataset. Size: {len(dataset)}")

        except Exception as e_stream:
            logger.error(f"Attempt 3 (streaming=True) failed: {e_stream}")
            logger.error("All dataset loading attempts failed.")
            raise # Re-raise the last error to halt execution

if dataset is None:
    logger.critical("FATAL: Dataset could not be loaded by any implemented method.")
    raise RuntimeError("Failed to load dataset after multiple attempts.")

# --- Dataset Renaming and Processing ---

logger.info(f"Dataset loaded. Original columns: {dataset.column_names}")
logger.info(f"Number of samples in loaded dataset: {len(dataset)}")

if 'text' in dataset.column_names:
    dataset = dataset.rename_column('text', 'article')
    logger.info("Renamed dataset column: 'text' -> 'article'")
    CONFIG['text_col'] = 'article'
    logger.info(f"Updated CONFIG['text_col'] to '{CONFIG['text_col']}'")
else:
    logger.warning(f"'text' column not found in dataset columns: {dataset.column_names}. Skipping rename. Current text column in CONFIG: {CONFIG['text_col']}")

logger.info("Creating dataset splits...")
try:
    total_size = len(dataset)
    logger.info(f"Total dataset size for splitting: {total_size}")

    train_size = int(total_size * CONFIG["split_train_frac"])
    eval_size = total_size - train_size

    if train_size <= 0 or eval_size <= 0:
        logger.error(f"Calculated train_size ({train_size}) or eval_size ({eval_size}) is non-positive. Aborting split.")
        raise ValueError("Train or evaluation set size is not positive. Check dataset size and split_train_frac.")

    dataset_shuffled = dataset.shuffle(seed=CONFIG["seed"]) # Shuffle before selecting
    train_dataset = dataset_shuffled.select(range(train_size))
    eval_dataset = dataset_shuffled.select(range(train_size, total_size))

    logger.info(f"Created splits - Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
except Exception as e:
    logger.error(f"Error creating dataset splits: {e}")
    raise

logger.info("Processing datasets (tokenization, etc.)...")
try:
    tokenized_datasets = {
        'train': train_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=train_dataset.column_names
        ),
        'eval': eval_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )
    }
    logger.info("Dataset processing complete.")
    logger.info(f"Final tokenized dataset sizes - Training samples: {len(tokenized_datasets['train'])}, Evaluation samples: {len(tokenized_datasets['eval'])}")
except Exception as e:
    logger.error(f"Error processing datasets: {e}")
    raise

## 10. Define Training Arguments

Configure the training process using Seq2SeqTrainingArguments.
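
Step-based settings such as `warmup_steps`, `eval_steps`, and `save_steps` only make sense relative to the total number of optimizer updates. The sketch below gives a rough single-device estimate from the configured sample size, split fraction, and batch size. It is an approximation with a hypothetical helper name; the `Trainer` may round slightly differently depending on the library version.

```python
import math

def estimate_training_steps(num_samples, train_frac, batch_size, grad_accum, epochs):
    """Rough count of optimizer updates for a single-device run."""
    train_size = int(num_samples * train_frac)
    batches_per_epoch = math.ceil(train_size / batch_size)
    updates_per_epoch = max(1, math.ceil(batches_per_epoch / grad_accum))
    return train_size, batches_per_epoch, updates_per_epoch * epochs

for grad_accum in (1, 4):
    train_size, batches, total = estimate_training_steps(
        num_samples=2000, train_frac=0.9, batch_size=16, grad_accum=grad_accum, epochs=3
    )
    print(f"grad_accum={grad_accum}: {train_size} samples, "
          f"{batches} batches/epoch, ~{total} optimizer steps")
```

If the estimate is lower than `warmup_steps` or `eval_steps`, warmup never completes or step-based evaluation never triggers, so it is worth comparing these numbers before launching a run.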


In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir=CONFIG["output_dir"],
    num_train_epochs=CONFIG["num_train_epochs"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=CONFIG["per_device_eval_batch_size"],
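    gradient_accumulation_steps=4,  # Hardcoded: overrides CONFIG["gradient_accumulation_steps"] (1); effective batch size is 16 x 4 = 64 per device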
    learning_rate=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
    warmup_steps=CONFIG["warmup_steps"],
    
    # Optimization
    fp16=CONFIG["fp16"],
    fp16_full_eval=True,
    bf16=False,
    optim="adafactor",
    
    # Logging & Evaluation
    logging_dir=f"{CONFIG['output_dir']}/logs",
    logging_strategy="steps",
    logging_steps=50,  # Hardcoded: less frequent than CONFIG["logging_steps"] (10)
    evaluation_strategy=CONFIG["evaluation_strategy"],
    eval_steps=100,
    
    # Saving
    save_strategy="steps",
    save_steps=200,  # Must be a round multiple of eval_steps when load_best_model_at_end=True
    save_total_limit=3,  # One more checkpoint than CONFIG["save_total_limit"] (2)
    
    # Model Loading
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    
    # Generation
    predict_with_generate=True,
    generation_max_length=CONFIG["max_target_tokens"],
    generation_num_beams=CONFIG["gen_num_beams"],
    
    # Other
    report_to=CONFIG["report_to"],
    seed=CONFIG["seed"],
    gradient_checkpointing=CONFIG["gradient_checkpointing"],
    overwrite_output_dir=CONFIG["overwrite_output_dir"],
)

logger.info(f"Training arguments: {training_args}")


## 11. Define Metrics Computation

Define a function that computes ROUGE and BLEU scores, along with summary-length statistics, for evaluation.
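
To show what a ROUGE-1 score measures, here is a deliberately simplified unigram-overlap F1 in plain Python. It is a sketch for intuition only: it splits on whitespace and skips the stemming and tokenization that the `rouge_score` package applies, so its numbers will not match the library exactly.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Simplified ROUGE-1: F1 over clipped unigram matches."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1(
    "the bill amends the tax code",
    "the bill amends the internal revenue code",
)
print(round(score, 4))
```

Here five of the six predicted tokens and five of the seven reference tokens match, giving precision 5/6, recall 5/7, and an F1 of 10/13.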


In [None]:
from functools import lru_cache
import numpy as np
import nltk
from typing import Dict, Any, List

@lru_cache(maxsize=1)
def get_metrics():
    """Load and cache evaluation metrics."""
    # Only the metrics that compute_metrics actually uses are loaded
    return {
        "rouge": evaluate.load("rouge"),
        "bleu": evaluate.load("bleu")
    }

def process_texts(texts: List[str]) -> List[str]:
    """Clean and process texts for evaluation."""
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

def compute_metrics(eval_preds, batch_size: int = 32) -> Dict[str, float]:
    """Compute evaluation metrics with improved error handling and statistics."""
    try:
        metrics = get_metrics()
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # Replace -100 (padding/ignore index) with the pad token id before decoding
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Process texts
        decoded_preds = process_texts(decoded_preds)
        decoded_labels = process_texts(decoded_labels)

        # Compute metrics
        rouge_results = metrics["rouge"].compute(
            predictions=decoded_preds, 
            references=decoded_labels
        )
        
        decoded_labels_bleu = [[label] for label in decoded_labels]
        bleu_results = metrics["bleu"].compute(
            predictions=decoded_preds, 
            references=decoded_labels_bleu
        )

        # Compute additional statistics
        pred_lengths = [len(p.split()) for p in decoded_preds]
        ref_lengths = [len(r.split()) for r in decoded_labels]

        results = {
            "rouge1": rouge_results["rouge1"],
            "rouge2": rouge_results["rouge2"],
            "rougeL": rouge_results["rougeL"],
            "rougeLsum": rouge_results["rougeLsum"],
            "bleu": bleu_results["bleu"],
            "avg_pred_length": np.mean(pred_lengths),
            "avg_ref_length": np.mean(ref_lengths),
            "compression_ratio": np.mean([p / r for p, r in zip(pred_lengths, ref_lengths) if r > 0] or [0.0])  # skip empty references to avoid division by zero
        }

        # Add generation length metric
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        results["gen_len"] = np.mean(prediction_lens)

        return {k: round(v, 4) for k, v in results.items()}

    except Exception as e:
        logger.error(f"Error computing metrics: {e}")
        return {
            "rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0,
            "rougeLsum": 0.0, "bleu": 0.0, "gen_len": 0,
            "error": str(e)
        }

logger.info("Enhanced metrics computation function defined.")

## 11. Initialize Trainer

Set up the `Seq2SeqTrainer` with the model, arguments, datasets, tokenizer, and metrics function.
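
Early stopping with a patience value halts training once the monitored metric has failed to improve for that many consecutive evaluations. The sketch below simulates that rule on a made-up ROUGE-L history. It is a simplified illustration, not the `EarlyStoppingCallback` implementation, and the function name is hypothetical.

```python
def evaluations_until_stop(scores, patience, greater_is_better=True):
    """Return the 1-based evaluation index at which training would stop, or None."""
    best = None
    bad_evals = 0
    for index, score in enumerate(scores, start=1):
        improved = best is None or (score > best if greater_is_better else score < best)
        if improved:
            best, bad_evals = score, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return index
    return None

# Made-up metric values, one per evaluation run
rouge_l_history = [0.21, 0.24, 0.26, 0.25, 0.26, 0.255, 0.27]
stop_at = evaluations_until_stop(rouge_l_history, patience=3)
print(stop_at)
```

With a patience of 3 the run stops before the late improvement to 0.27 is ever seen, which is the trade-off a small patience value makes.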


In [None]:
from transformers import EarlyStoppingCallback, TrainerCallback
import gc
import torch


def clear_memory():
    """Clear unused memory before training"""
    gc.collect()
    torch.cuda.empty_cache()
    print_memory_usage()

class MemoryTrackingCallback(TrainerCallback):
    """Callback to track memory usage during training"""
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:  # Monitor every 100 steps
            print_memory_usage()

def validate_training_args(args, model):
    """Validate training arguments for potential issues"""
    if args.per_device_train_batch_size * args.gradient_accumulation_steps > 64:
        logger.warning("Total batch size might be too large for available memory")
    
    if args.fp16 and not torch.cuda.is_available():
        raise ValueError("FP16 requires CUDA")

# Clear memory before initialization
clear_memory()

# Initialize data collator with error handling
try:
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,  # -100 is ignored by the loss; the pad token id would be scored as a real target
        pad_to_multiple_of=8 if CONFIG["fp16"] else None
    )
    logger.info("Data collator initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize data collator: {e}")
    raise

# Validate training arguments
validate_training_args(training_args, model)

# Initialize trainer with enhanced monitoring
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3),
        MemoryTrackingCallback()
    ]
)

logger.info("Trainer initialized with enhanced monitoring")
print_memory_usage()

## 12. Train the Model

Start the fine-tuning process. This will take some time depending on the dataset size and hardware. 🥳


In [None]:
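logger.info("Starting training...")

try:
    # NOTE: the original definition of checkpoint_dir is missing from this cell;
    # the path below is an assumption.
    checkpoint_dir = os.path.join(CONFIG["output_dir"], "final_checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)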

    # Execute training
    train_result = trainer.train(
        resume_from_checkpoint=None  # Set to checkpoint path if resuming
    )
    
    # Save training metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Store and log training loss
    training_loss = metrics.get("train_loss", "N/A")
    logger.info(f"Final Training Loss: {training_loss}")
    
    # Save checkpoint
    trainer.save_model(checkpoint_dir)
    logger.info(f"Model checkpoint saved to {checkpoint_dir}")

except Exception as e:
    logger.error(f"An error occurred during training: {e}")
    if wandb.run:
        wandb.log({"training_error": str(e)})
        wandb.run.finish(exit_code=1)
    raise e

finally:
    # Always clean up memory
    print_memory_usage()
    torch.cuda.empty_cache()

## 13. Evaluate the Model

Evaluate the fine-tuned model on the evaluation set to get final performance metrics.


In [None]:
logger.info("Evaluating model...")
eval_metrics = trainer.evaluate()

logger.info("Evaluation metrics:")
for key, value in eval_metrics.items():
    logger.info(f"{key}: {value}")

# Log evaluation metrics
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics) # Saves to eval_results.json and updates all_results.json

# Prepare the training_report.json
training_report = {
    "model_name": CONFIG["model_name"],
    "dataset_name": CONFIG["dataset_name"],
    "lora_adapter_name": CONFIG["lora_adapter_name"],
    "output_directory": CONFIG["output_dir"],
    "training_arguments": {k: str(v) if isinstance(v, (torch.device, BitsAndBytesConfig)) else v for k, v in training_args.to_dict().items()}, # Convert non-serializable items
    "train_metrics": trainer.state.log_history[:-1], # All logged steps except final eval
    "eval_metrics": eval_metrics,
    # After trainer.evaluate() the last log entry holds eval metrics, so search
    # backwards for the run-level "train_loss" instead of indexing by position
    "final_training_loss": next((entry["train_loss"] for entry in reversed(trainer.state.log_history) if "train_loss" in entry), "N/A")
}


# Add ROUGE and BLEU from eval_metrics to the top level for easier access
for metric_key in ["eval_rouge1", "eval_rouge2", "eval_rougeL", "eval_bleu"]:
    if metric_key in eval_metrics:
        training_report[metric_key.replace("eval_", "")] = eval_metrics[metric_key]


# Save training_report.json locally
report_path = os.path.join(CONFIG["output_dir"], CONFIG["training_report_filename"])
with open(report_path, "w") as f:
    json.dump(training_report, f, indent=4)
logger.info(f"Training report saved to {report_path}")

if wandb.run:
    wandb.log(eval_metrics) # Log final eval metrics
    wandb.save(report_path) # Save report to W&B artifacts
    logger.info("Evaluation metrics and report logged to W&B.")

print_memory_usage()
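
The report above is built from `trainer.state.log_history`, which is a plain list of dictionaries: per-step training logs carry a `loss` key, evaluation logs carry `eval_*` keys, and the run-level summary written at the end of `trainer.train()` carries `train_loss`. The sketch below shows one way to pick these apart. The helper names are hypothetical and the numbers in `example_history` are made up for illustration; they are not results from this notebook.

```python
def final_train_loss(log_history, default="N/A"):
    """Return the most recent run-level 'train_loss', or `default` if none was logged."""
    for entry in reversed(log_history):
        if "train_loss" in entry:
            return entry["train_loss"]
    return default


def split_log_history(log_history):
    """Separate per-step training logs from evaluation logs."""
    train_steps = [entry for entry in log_history if "loss" in entry]
    eval_steps = [entry for entry in log_history if "eval_loss" in entry]
    return train_steps, eval_steps


# Made-up entries that mimic the shape of trainer.state.log_history
example_history = [
    {"loss": 2.31, "learning_rate": 1.5e-4, "epoch": 0.5, "step": 25},
    {"eval_loss": 2.05, "eval_rouge1": 0.41, "epoch": 1.0, "step": 50},
    {"loss": 1.97, "learning_rate": 5.0e-5, "epoch": 1.5, "step": 75},
    {"train_runtime": 120.0, "train_loss": 2.14, "epoch": 2.0, "step": 100},
    {"eval_loss": 1.98, "eval_rouge1": 0.44, "epoch": 2.0, "step": 100},
]

train_steps, eval_steps = split_log_history(example_history)
print(f"{len(train_steps)} training logs, {len(eval_steps)} evaluation logs")
print(f"Final training loss: {final_train_loss(example_history)}")
```

Searching by key rather than by position keeps the lookup correct even when an extra evaluation entry is appended after training.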

## 14. Save Model and LoRA Adapter

Save the fine-tuned LoRA adapter and, optionally, the merged full model. If Google Drive is mounted, the outputs are also copied there.
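
To see why the adapter is so much smaller than the full model, and what merging does, recall that LoRA leaves each adapted weight matrix `W` frozen and learns a low-rank update, so the merged weight is `W + (alpha / r) * B @ A`. The sketch below works through that formula on toy matrices in plain Python. `matmul` and `merge_lora` are hypothetical helpers, not the `peft` implementation, and the 768 x 768 layer shape used for the parameter count is an assumption about the adapted layers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * (B @ A), the merged weight matrix."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy example: W is 3x3 and the rank is 1, so B is 3x1 and A is 1x3
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
lora_a = [[0.5, -0.5, 1.0]]
lora_b = [[1.0], [2.0], [0.0]]
merged = merge_lora(weight, lora_a, lora_b, alpha=4, r=1)
print(merged)

# Parameter count for one assumed 768 x 768 layer with r = 8
d_in, d_out, rank = 768, 768, 8
adapter_params = rank * (d_in + d_out)
full_params = d_in * d_out
print(f"adapter: {adapter_params} parameters, full matrix: {full_params} parameters")
```

Saving only `A` and `B` for each adapted layer is what keeps the adapter small; merging trades that compactness for a standalone model that no longer needs `peft` at inference time.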


In [None]:
# Save the LoRA adapter
lora_adapter_path = os.path.join(CONFIG["output_dir"], CONFIG["lora_adapter_name"])
model.save_pretrained(lora_adapter_path) # Saves only the LoRA adapter
tokenizer.save_pretrained(lora_adapter_path) # Save tokenizer with adapter
logger.info(f"LoRA adapter and tokenizer saved to {lora_adapter_path}")

# To save the full model (optional, requires more space)
# merged_model_path = os.path.join(CONFIG["output_dir"], "merged_model_flan_t5_base_billsum")
# try:
#     # Merge LoRA weights with the base model
#     merged_model = model.merge_and_unload()
#     merged_model.save_pretrained(merged_model_path)
#     tokenizer.save_pretrained(merged_model_path)
#     logger.info(f"Full merged model saved to {merged_model_path}")
# except Exception as e:
#     logger.error(f"Could not merge and save full model: {e}. This might happen if the base model is not fully on CPU or due to memory constraints.")
#     logger.info("Only the LoRA adapter was saved.")


# If Google Drive is mounted, copy outputs there
if CONFIG["mount_drive"] and os.path.exists(CONFIG["gdrive_output_dir"]):
    import shutil  # File operations that raise on failure instead of failing silently

    logger.info(f"Copying outputs to Google Drive: {CONFIG['gdrive_output_dir']}")
    # Copy LoRA adapter (replace any previous copy)
    gdrive_lora_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["lora_adapter_name"])
    if os.path.exists(gdrive_lora_path):
        logger.info(f"Removing existing LoRA adapter from GDrive: {gdrive_lora_path}")
        shutil.rmtree(gdrive_lora_path)
    shutil.copytree(lora_adapter_path, gdrive_lora_path)
    logger.info(f"LoRA adapter copied to {gdrive_lora_path}")

    # Copy training report
    gdrive_report_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["training_report_filename"])
    shutil.copy(report_path, gdrive_report_path)
    logger.info(f"Training report copied to {gdrive_report_path}")

    # Copy all_results.json (contains eval metrics)
    all_results_path = os.path.join(CONFIG["output_dir"], "all_results.json")
    if os.path.exists(all_results_path):
        gdrive_all_results_path = os.path.join(CONFIG["gdrive_output_dir"], "all_results.json")
        shutil.copy(all_results_path, gdrive_all_results_path)
        logger.info(f"all_results.json copied to {gdrive_all_results_path}")

    # If merged model was saved and exists, copy it too
    # if 'merged_model' in locals() and os.path.exists(merged_model_path):
    #     gdrive_merged_model_path = os.path.join(CONFIG["gdrive_output_dir"], "merged_model_flan_t5_base_billsum")
    #     if os.path.exists(gdrive_merged_model_path):
    #         logger.info(f"Removing existing merged model from GDrive: {gdrive_merged_model_path}")
    #         os.system(f"rm -rf '{gdrive_merged_model_path}'")
    #     os.system(f"cp -r '{merged_model_path}' '{CONFIG['gdrive_output_dir']}/'")
    #     logger.info(f"Full merged model copied to {gdrive_merged_model_path}")
else:
    logger.warning("Google Drive not mounted or GDrive output path does not exist. Outputs saved locally.")

if wandb.run:
    # Log LoRA adapter as artifact if desired
    # lora_artifact = wandb.Artifact(CONFIG["lora_adapter_name"], type="model")
    # lora_artifact.add_dir(lora_adapter_path)
    # wandb.log_artifact(lora_artifact)
    # logger.info(f"LoRA adapter logged as W&B artifact: {CONFIG['lora_adapter_name']}")
    wandb.finish()

logger.info("Script finished.")