# Fine-tuning Flan-T5-base for Legal Document Summarization - Colab GPU Version

Author: Gourab S. (@heygourab), <a href="https://github.com/heygourab"><img src="https://github.com/fluidicon.png" width="20" height="20" alt="github" /> @heygourab</a>

This notebook fine-tunes Flan-T5-base on the BillSum dataset using LoRA with 4-bit quantization on Google Colab's GPU (T4, 16GB VRAM). The settings are chosen for stability: they aim to avoid NaN loss and to fit within the T4's memory.

## Setup Overview

- Base Model: google/flan-t5-base
- Dataset: BillSum (~1000 samples)
- Hardware: Colab GPU (T4, 16GB VRAM)

## Prerequisites

- Colab environment with GPU enabled
- Install dependencies: `pip install transformers datasets accelerate evaluate peft bitsandbytes nltk wandb torch psutil`

<a href="https://colab.research.google.com"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"></a>


## 1. Install Dependencies

**What it does**: Installs the required packages. The commands below are unpinned, so pin versions if you run into mismatches, and check that the installed bitsandbytes build supports CUDA.


In [None]:
%pip install -q torch --force-reinstall
%pip install -q bitsandbytes accelerate datasets evaluate peft transformers nltk wandb psutil

ERROR: Could not find a version that satisfies the requirement bitsandbytes==0.44.1 (from versions: 0.31.8, 0.32.0, 0.32.1, 0.32.2, 0.32.3, 0.33.0, 0.33.1, 0.34.0, 0.35.0, 0.35.1, 0.35.2, 0.35.3, 0.35.4, 0.36.0, 0.36.0.post1, 0.36.0.post2, 0.37.0, 0.37.1, 0.37.2, 0.38.0, 0.38.0.post1, 0.38.0.post2, 0.38.1, 0.39.0, 0.39.1, 0.40.0, 0.40.0.post1, 0.40.0.post2, 0.40.0.post3, 0.40.0.post4, 0.40.1, 0.40.1.post1, 0.40.2, 0.41.0, 0.41.1, 0.41.2, 0.41.2.post1, 0.41.2.post2, 0.41.3, 0.41.3.post1, 0.41.3.post2, 0.42.0)
ERROR: No matching distribution found for bitsandbytes==0.44.1
Note: you may need to restart the kernel to use updated packages.
Dependencies installed. Restart runtime if prompted.


## 2. Import Libraries

**What it does**: Imports all necessary Python libraries for fine-tuning, quantization, and logging.


In [None]:
import os
import json
import torch
import nltk
import evaluate
import numpy as np
import logging
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
import wandb
import psutil
from functools import lru_cache

## 3. Logger setup

**What it does**: Configures a logger to output to console and a timestamped log file for debugging.


In [None]:
def setup_logger(name="train_logger", level=logging.INFO):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(
        fmt='%(asctime)s — %(name)s — %(levelname)s — %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    log_dir = os.path.join(os.getcwd(), 'logs')
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(f"Logger initialized: {name}")
    logger.info(f"Log file: {os.path.abspath(log_file)}")
    return logger

logger = setup_logger("train_logger", logging.INFO)

2025-05-19 13:52:07 — train_logger — INFO — Logger initialized: train_logger


2025-05-19 13:52:07 — train_logger — INFO — Log file created at: /Users/gourabsarkar/Developer/college_project/pdf_summarization_model_fine_tuning/notebooks/logs/training_20250519_135207.log
2025-05-19 13:52:07 — train_logger — INFO — Python version: 3.10.17 (main, Apr  8 2025, 12:10:59) [Clang 16.0.0 (clang-1600.0.26.6)]


## 4. NLTK Setup

**What it does**: Downloads NLTK resources for sentence tokenization used in metrics.


In [None]:
for resource in ['punkt', 'punkt_tab']:
    try:
        nltk.download(resource, quiet=True)
        logger.info(f"Downloaded NLTK resource: {resource}")
    except Exception as e:
        logger.error(f"Error downloading {resource}: {e}")

2025-05-19 13:52:31 — train_logger — INFO — Successfully downloaded NLTK resource: punkt
2025-05-19 13:52:33 — train_logger — INFO — Successfully downloaded NLTK resource: punkt_tab


## 5. Memory Usage Monitoring

The `print_memory_usage()` function logs system resource utilization during model training:

- Reports RAM usage as the Resident Set Size (RSS) of the current process in GB, alongside total system RAM
- For GPU-enabled systems:
  - Reports allocated GPU memory
  - Shows total GPU memory
  - Calculates the percentage of GPU memory in use

The companion `clear_memory()` function runs garbage collection, empties the CUDA cache, and then logs usage again.

This helps identify potential memory bottlenecks and optimize resource usage during training.


In [None]:
def print_memory_usage():
    process = psutil.Process(os.getpid())
    ram_gb = process.memory_info().rss / 1e9
    total_ram_gb = psutil.virtual_memory().total / 1e9
    logger.info(f"RAM usage: {ram_gb:.2f} GB / {total_ram_gb:.2f} GB ({ram_gb/total_ram_gb*100:.1f}%)")
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        gpu_mem = torch.cuda.memory_allocated() / 1e9
        gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU memory: {gpu_mem:.2f}/{gpu_total:.2f} GB ({gpu_mem/gpu_total*100:.1f}%)")

def clear_memory():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print_memory_usage()

print_memory_usage()

2025-05-19 13:52:43 — train_logger — INFO — RAM usage: 0.04 GB
2025-05-19 13:52:43 — train_logger — INFO — Total system RAM: 8.59 GB
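The percentages above are computed in decimal gigabytes (1 GB = 1e9 bytes), which is worth keeping in mind when comparing against a card's advertised size. The sketch below repeats that arithmetic in isolation; `format_usage` and the byte counts are hypothetical and only illustrate the unit conversion.

```python
def format_usage(used_bytes: int, total_bytes: int) -> str:
    """Format 'used/total GB (pct%)' using decimal GB, as print_memory_usage() does."""
    used_gb = used_bytes / 1e9
    total_gb = total_bytes / 1e9
    # Guard against a zero total so the helper never divides by zero
    pct = used_gb / total_gb * 100 if total_gb else 0.0
    return f"{used_gb:.2f}/{total_gb:.2f} GB ({pct:.1f}%)"

# Hypothetical numbers: 2.5e9 bytes allocated on a device with 2**34 bytes (16 GiB)
print(format_usage(2_500_000_000, 2**34))
```

With these inputs a 16 GiB device shows up as 17.18 GB, so a reported total that differs from the nominal "16GB" is usually a units effect and not missing memory.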


## 6. Configuration Parameters

Defines all hyperparameters and configuration settings for the model, dataset, LoRA, and training in a single `CONFIG` dictionary.


In [None]:
CONFIG = {
    # Model & Quantization
    "model_name": "google/flan-t5-base",
    "model_type": "encoder-decoder",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    ),
    # Dataset
    "dataset_name": "billsum",
    "text_col": "text",
    "summary_col": "summary",
    "max_input_tokens": 512,
    "max_target_tokens": 256,
    "sample_size": 1000,
    "filter_by_length": True,
    "split_train_frac": 0.9,
    # Prompt
    "prompt_prefix": "summarize: ",
    # LoRA
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_target_modules": ["q", "k", "v"],
    "lora_dropout": 0.1,
    "lora_bias": "none",
    "lora_task_type": TaskType.SEQ_2_SEQ_LM,
    "lora_adapter_name": "lora_billsum_legal",
    # Training
    "do_train": True,
    "do_eval": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "weight_decay": 0.01,
    "warmup_steps": 50,
    "fp16": True,
    "bf16": False,
    "gradient_checkpointing": True,
    "optim": "adamw_8bit",
    # Logging & Checkpointing
    "logging_steps": 10,
    "evaluation_strategy": "steps",
    "eval_steps": 50,
    "save_strategy": "steps",
    "save_steps": 50,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
    "report_to": "wandb",
    "overwrite_output_dir": True,
    # Generation
    "gen_num_beams": 4,
    "gen_length_penalty": 0.8,
    "gen_early_stopping": True,
    # Random seed
    "seed": 42,
    # Google Drive
    "mount_drive": True,
    "drive_path": "MyDrive/ML_models/pdf_summarization",
    "training_report_filename": "training_report.json"
}

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG["output_dir"] = f"{CONFIG['lora_adapter_name']}_{timestamp}"
CONFIG["gdrive_output_dir"] = f"/content/drive/{CONFIG['drive_path']}/{CONFIG['output_dir']}" if CONFIG["mount_drive"] else CONFIG["output_dir"]

logger.info("CONFIG:")
for k, v in CONFIG.items():
    logger.info(f"  {k}: {v}")

'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


ImportError: cannot import name 'BitsAndBytesConfig' from 'bitsandbytes' (/Users/gourabsarkar/Developer/college_project/pdf_summarization_model_fine_tuning/.venv/lib/python3.10/site-packages/bitsandbytes/__init__.py)

## 6. Login to Hugging Face Hub and Weights & Biases

You'll need to log in to Hugging Face to download models/datasets and to Weights & Biases for experiment tracking.
You can get your Hugging Face token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
and your W&B API key from [https://wandb.ai/authorize](https://wandb.ai/authorize).


In [None]:
from huggingface_hub import HfFolder, notebook_login

try:
    if HfFolder.get_token() is None:
        logger.info("Hugging Face token not found. Please log in.")
        notebook_login()
    else:
        logger.info("Already logged in to Hugging Face Hub.")
except Exception as e:
    logger.error(f"Hugging Face login error: {e}")
    notebook_login()

try:
    wandb.login()
    wandb.init(project="flan-t5-billsum-lora", name=CONFIG["output_dir"], config=CONFIG)
    logger.info("Logged in to W&B and initialized experiment.")
except Exception as e:
    logger.error(f"W&B login failed: {e}. Falling back to TensorBoard.")
    CONFIG["report_to"] = "tensorboard"

## 7. Mount Google Drive (Optional)

If you want to save your model checkpoints and outputs to Google Drive, mount it here.


In [None]:
if CONFIG["mount_drive"]:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        os.makedirs(CONFIG["gdrive_output_dir"], exist_ok=True)
        logger.info(f"Google Drive mounted: {CONFIG['gdrive_output_dir']}")
    except Exception as e:
        logger.error(f"Google Drive mount failed: {e}. Saving locally.")
        CONFIG["mount_drive"] = False
else:
    logger.info("Google Drive mount disabled. Saving locally.")

## 8. Load Model, Tokenizer and Configure LoRA

Loads the Flan-T5-base tokenizer and the 4-bit quantized model from Hugging Face, prepares the model for k-bit training, and wraps it with a LoRA adapter.


In [None]:
logger.info(f"Loading tokenizer: {CONFIG['model_name']}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        CONFIG["model_name"],
        use_fast=True,
        padding_side="right",
        model_max_length=CONFIG["max_input_tokens"]
    )
    logger.info("Tokenizer loaded.")
except Exception as e:
    logger.error(f"Tokenizer loading failed: {e}")
    raise

logger.info(f"Loading model: {CONFIG['model_name']} with 4-bit quantization")
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        CONFIG["model_name"],
        quantization_config=CONFIG["quantization_config"],
        device_map="auto",
        torch_dtype=torch.float16
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=CONFIG["gradient_checkpointing"])
    logger.info("Model loaded with quantization.")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

try:
    lora_config = LoraConfig(
        r=CONFIG["lora_r"],
        lora_alpha=CONFIG["lora_alpha"],
        target_modules=CONFIG["lora_target_modules"],
        lora_dropout=CONFIG["lora_dropout"],
        bias=CONFIG["lora_bias"],
        task_type=CONFIG["lora_task_type"]
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    logger.info("LoRA adapter applied.")
except Exception as e:
    logger.error(f"LoRA setup failed: {e}")
    raise

clear_memory()
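As a sanity check on what `print_trainable_parameters()` reports, the LoRA parameter count can be estimated by hand. This is a rough sketch that assumes the usual Flan-T5-base shape (hidden size 768, 12 encoder and 12 decoder blocks, `q`/`k`/`v` projections mapping 768 to 768); those dimensions and the helper `lora_param_count` are assumptions, not values read from the loaded model.

```python
def lora_param_count(r: int, d_in: int, d_out: int, n_modules: int) -> int:
    # Each adapted Linear layer gains two matrices: A (r x d_in) and B (d_out x r)
    return n_modules * r * (d_in + d_out)

D_MODEL = 768          # assumed hidden size of flan-t5-base
N_LAYERS = 12          # assumed number of encoder blocks (and of decoder blocks)
TARGETS = 3            # q, k, v as in CONFIG["lora_target_modules"]

encoder_modules = N_LAYERS * TARGETS        # self-attention only
decoder_modules = N_LAYERS * TARGETS * 2    # self-attention plus cross-attention

trainable = lora_param_count(8, D_MODEL, D_MODEL, encoder_modules + decoder_modules)
print(f"Estimated LoRA parameters: {trainable:,}")
```

If the number printed by PEFT differs noticeably from this estimate, the target module names may be matching more (or fewer) layers than intended.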

## 9. Load and Preprocess Dataset

Load the BillSum dataset, preprocess it for Flan-T5, and split into training and evaluation sets.


In [None]:
import re

def clean_text(text):
    if not isinstance(text, str):
        return ""
    text = text.strip()
    text = " ".join(text.split())
    text = re.sub(r'\s*\([^)]{0,40}\)\s*', ' ', text)
    text = re.sub(r'\s*\[[^\]]{0,40}\]\s*', ' ', text)
    return text

def preprocess_function(examples):
    try:
        input_texts = examples.get(CONFIG["text_col"], [])
        summary_texts = examples.get(CONFIG["summary_col"], [])

        valid_pairs = [(i, s) for i, s in zip(input_texts, summary_texts) if i and s and isinstance(i, str) and isinstance(s, str)]
        if not valid_pairs:
            logger.warning("No valid input-summary pairs found.")
            return {"input_ids": [], "attention_mask": [], "labels": []}

        input_texts, summary_texts = zip(*valid_pairs)
        cleaned_inputs = [clean_text(doc) for doc in input_texts]
        prompts = [f'{CONFIG["prompt_prefix"]}{doc}' for doc in cleaned_inputs]

        logger.debug(f"Sample raw input: {input_texts[0][:150]}...")
        logger.debug(f"Sample cleaned input: {cleaned_inputs[0][:150]}...")

        model_inputs = tokenizer(
            prompts,
            max_length=CONFIG["max_input_tokens"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        summaries = [clean_text(s) if s else "No summary provided." for s in summary_texts]
        labels = tokenizer(
            summaries,
            max_length=CONFIG["max_target_tokens"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        model_inputs["labels"] = labels["input_ids"]
        model_inputs["labels"][model_inputs["labels"] == tokenizer.pad_token_id] = -100
        return model_inputs
    except Exception as e:
        logger.error(f"Preprocessing failed: {e}")
        raise

logger.info(f"Loading dataset: {CONFIG['dataset_name']}")
try:
    dataset = load_dataset(CONFIG["dataset_name"], split=f"train[:{CONFIG['sample_size']}]")
    logger.info(f"Dataset loaded: {len(dataset)} samples")
except Exception as e:
    logger.error(f"Dataset loading failed: {e}")
    raise

try:
    total_size = len(dataset)
    train_size = int(total_size * CONFIG["split_train_frac"])
    eval_size = total_size - train_size
    dataset_shuffled = dataset.shuffle(seed=CONFIG["seed"])
    train_dataset = dataset_shuffled.select(range(train_size))
    eval_dataset = dataset_shuffled.select(range(train_size, total_size))
    logger.info(f"Splits created - Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
except Exception as e:
    logger.error(f"Dataset splitting failed: {e}")
    raise

try:
    tokenized_datasets = {
        'train': train_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=train_dataset.column_names
        ),
        'eval': eval_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )
    }
    logger.info(f"Tokenized datasets - Train: {len(tokenized_datasets['train'])}, Eval: {len(tokenized_datasets['eval'])}")
except Exception as e:
    logger.error(f"Dataset tokenization failed: {e}")
    raise
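To see what the cleaning step actually removes, it helps to run `clean_text` on a short string. The function is repeated below only so that the sketch runs on its own; the sample sentence is made up for illustration.

```python
import re

def clean_text(text):
    # Same logic as the notebook's clean_text, copied so this sketch is self-contained
    if not isinstance(text, str):
        return ""
    text = " ".join(text.strip().split())
    text = re.sub(r'\s*\([^)]{0,40}\)\s*', ' ', text)   # drop short (...) spans
    text = re.sub(r'\s*\[[^\]]{0,40}\]\s*', ' ', text)  # drop short [...] spans
    return text

sample = "(a) This Act may   be cited [sic] as the 'Privacy Act'.\n(b) Penalties (up to $500,000) apply."
print(repr(clean_text(sample)))
```

Two side effects are visible here: a leading space survives when the text starts with an enumerator such as `(a)`, and short parentheticals that carry substance (the penalty amount) are removed along with the enumerators. Whether that trade-off is acceptable depends on the documents being summarized.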

## 10. Define Metrics Computation

**What it does**: Defines a function to compute ROUGE and BLEU metrics, with error handling to prevent NaN issues.


In [None]:
@lru_cache(maxsize=1)
def get_metrics():
    """Load and cache evaluation metrics.

    Returns:
        Dict: Dictionary containing loaded metrics for evaluation
    """
    return {
        "rouge": evaluate.load("rouge"),
        "bleu": evaluate.load("bleu")
    }

def process_texts(texts):
    """Process and clean texts for evaluation.

    Args:
        texts (List[str]): List of texts to process

    Returns:
        List[str]: Sentence-split texts, one per input. Empty texts are kept
            as empty strings so predictions and references stay aligned.
    """
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

def compute_metrics(eval_preds: tuple):
    """Compute evaluation metrics for model predictions.

    Args:
        eval_preds (tuple): Tuple containing predictions and labels

    Returns:
        Dict[str, float]: Dictionary containing computed metrics
    """
    try:
        metrics = get_metrics()
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # Decode predictions and labels; -100 marks ignored positions and is not a valid token id
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Process texts
        decoded_preds = process_texts(decoded_preds)
        decoded_labels = process_texts(decoded_labels)

        if not decoded_preds or not decoded_labels:
            logger.warning("Empty predictions or labels. Returning default metrics.")
            return {
                "rouge1": 0.0, 
                "rouge2": 0.0, 
                "rougeL": 0.0, 
                "bleu": 0.0, 
                "gen_len": 0
            }

        # Compute ROUGE scores
        rouge_results = metrics["rouge"].compute(
            predictions=decoded_preds,
            references=decoded_labels
        )

        # Compute BLEU scores
        bleu_results = metrics["bleu"].compute(
            predictions=decoded_preds,
            references=[[label] for label in decoded_labels]
        )

        # Calculate length statistics
        pred_lengths = [len(p.split()) for p in decoded_preds]
        ref_lengths = [len(r.split()) for r in decoded_labels]
        
        # Compile results
        results = {
            "rouge1": rouge_results["rouge1"],
            "rouge2": rouge_results["rouge2"],
            "rougeL": rouge_results["rougeL"],
            "bleu": bleu_results["bleu"],
            "gen_len": np.mean(pred_lengths) if pred_lengths else 0,
            "compression_ratio": np.mean([p / max(r, 1) for p, r in zip(pred_lengths, ref_lengths)])
                               if pred_lengths and ref_lengths else 0
        }

        return {k: round(v, 4) for k, v in results.items()}

    except Exception as e:
        logger.error(f"Metrics computation failed: {e}")
        return {
            "rouge1": 0.0,
            "rouge2": 0.0,
            "rougeL": 0.0,
            "bleu": 0.0,
            "gen_len": 0,
            "error": str(e)
        }
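For intuition about what the ROUGE-1 number means, here is a simplified unigram-overlap F1 computed by hand. It is only a sketch: it uses lowercase whitespace tokenization with no stemming, so its scores will not match the `evaluate` ROUGE implementation exactly, and `rouge1_f1` is a hypothetical helper name.

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between a prediction and a single reference."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("the act protects privacy", "the act protects personal data")
print(round(score, 4))
```

Here three of the four predicted words appear in the five-word reference, giving precision 0.75, recall 0.6, and an F1 of about 0.67.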

## 11. Initialize Trainer

**What it does**: Sets up the trainer with training arguments, data collator, and callbacks for loss and memory tracking.


In [None]:
class LossTrackingCallback(TrainerCallback):
    def __init__(self):
        self.training_loss = []
        self.eval_loss = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            step = state.global_step
            # Only send to W&B when a run is active (W&B may be disabled by the TensorBoard fallback)
            if "loss" in logs:
                self.training_loss.append((step, logs["loss"]))
                if wandb.run:
                    wandb.log({"training_loss": logs["loss"]}, step=step)
            if "eval_loss" in logs:
                self.eval_loss.append((step, logs["eval_loss"]))
                if wandb.run:
                    wandb.log({"eval_loss": logs["eval_loss"]}, step=step)

class MemoryTrackingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50 == 0:
            print_memory_usage()

training_args = Seq2SeqTrainingArguments(
    output_dir=CONFIG["output_dir"],
    num_train_epochs=CONFIG["num_train_epochs"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=CONFIG["per_device_eval_batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    learning_rate=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
    warmup_steps=CONFIG["warmup_steps"],
    fp16=CONFIG["fp16"],
    bf16=CONFIG["bf16"],
    optim=CONFIG["optim"],
    logging_dir=f"{CONFIG['output_dir']}/logs",
    logging_strategy="steps",
    logging_steps=CONFIG["logging_steps"],
    eval_strategy=CONFIG["evaluation_strategy"],
    eval_steps=CONFIG["eval_steps"],
    save_strategy=CONFIG["save_strategy"],
    save_steps=CONFIG["save_steps"],
    save_total_limit=CONFIG["save_total_limit"],
    load_best_model_at_end=CONFIG["load_best_model_at_end"],
    metric_for_best_model=CONFIG["metric_for_best_model"],
    greater_is_better=CONFIG["greater_is_better"],
    predict_with_generate=True,
    generation_max_length=CONFIG["max_target_tokens"],
    generation_num_beams=CONFIG["gen_num_beams"],
    report_to=CONFIG["report_to"],
    seed=CONFIG["seed"],
    gradient_checkpointing=CONFIG["gradient_checkpointing"],
    overwrite_output_dir=CONFIG["overwrite_output_dir"]
)

try:
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8
    )
    logger.info("Data collator initialized.")
except Exception as e:
    logger.error(f"Data collator initialization failed: {e}")
    raise

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        LossTrackingCallback(),
        MemoryTrackingCallback()
    ]
)

logger.info("Trainer initialized.")
clear_memory()
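The batch, accumulation, and step settings interact, so it can help to work out roughly how many optimizer updates the run will perform. The sketch below assumes a single GPU and that all 900 training samples survive preprocessing; the exact count also depends on how the installed Trainer version handles the final partial accumulation group, so treat the result as approximate.

```python
import math

train_samples = int(1000 * 0.9)   # sample_size * split_train_frac
per_device_batch = 4              # per_device_train_batch_size
grad_accum = 4                    # gradient_accumulation_steps
epochs = 3                        # num_train_epochs

effective_batch = per_device_batch * grad_accum
micro_batches_per_epoch = math.ceil(train_samples / per_device_batch)
updates_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum)
total_updates = updates_per_epoch * epochs

print(f"Effective batch size: {effective_batch}")
print(f"Approx. optimizer updates: {total_updates}")
print(f"Warmup share: {50 / total_updates:.0%}")          # warmup_steps = 50
print(f"Evaluations at eval_steps=50: {total_updates // 50}")
```

Under these assumptions warmup covers a little under a third of training and only about three evaluations (and checkpoints) occur, which is worth knowing when reading the loss curve or choosing `eval_steps`.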

## 12. Train the Model 😁

**What it does**: Runs the training loop, resumes from checkpoints if available, and logs metrics.


In [None]:
def get_latest_checkpoint(checkpoint_dir):
    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint-")]
    if not checkpoints:
        return None
    return os.path.join(checkpoint_dir, sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1])

logger.info("Starting training...")
try:
    # Trainer writes its checkpoint-<step> folders directly into output_dir
    checkpoint_dir = CONFIG["output_dir"]
    os.makedirs(checkpoint_dir, exist_ok=True)
    resume_checkpoint = get_latest_checkpoint(checkpoint_dir)
    if resume_checkpoint:
        logger.info(f"Resuming from checkpoint: {resume_checkpoint}")

    logger.info(f"Training samples: {len(trainer.train_dataset)}")
    logger.info(f"Validation samples: {len(trainer.eval_dataset)}")
    logger.info(f"Epochs: {CONFIG['num_train_epochs']}")

    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("Training completed!")
    logger.info(f"Final Training Loss: {metrics.get('train_loss', 'N/A')}")
except Exception as e:
    logger.error(f"Training failed: {e}")
    if wandb.run:
        wandb.log({"training_error": str(e)})
        wandb.run.finish(exit_code=1)
    raise
finally:
    clear_memory()
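The checkpoint lookup sorts by the numeric step and not by name, and the difference matters once step counts have different numbers of digits. A small sketch on plain strings, using a hypothetical helper `latest_checkpoint` that also skips folders whose suffix is not a number:

```python
def latest_checkpoint(names):
    """Return the checkpoint-<step> name with the highest step, or None."""
    steps = []
    for name in names:
        prefix, _, suffix = name.partition("-")
        # Ignore anything that is not exactly 'checkpoint-<digits>'
        if prefix == "checkpoint" and suffix.isdigit():
            steps.append((int(suffix), name))
    return max(steps)[1] if steps else None

names = ["checkpoint-50", "checkpoint-100", "checkpoint-150", "runs"]
numeric_latest = latest_checkpoint(names)
lexicographic_latest = max(n for n in names if n.startswith("checkpoint-"))

print(numeric_latest)        # picked by step number
print(lexicographic_latest)  # picked by string comparison
```

String comparison ranks `checkpoint-50` above `checkpoint-150` because `'5' > '1'`, which would resume from a stale checkpoint.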

## 13. Evaluate the Model

**What it does**: Evaluates the model on the validation set and computes ROUGE and BLEU scores.


In [None]:
logger.info("Evaluating model...")
eval_metrics = trainer.evaluate()

logger.info("Evaluation metrics:")
for key, value in eval_metrics.items():
    logger.info(f"{key}: {value}")

trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)

training_report = {
    "model_name": CONFIG["model_name"],
    "dataset_name": CONFIG["dataset_name"],
    "lora_adapter_name": CONFIG["lora_adapter_name"],
    "output_directory": CONFIG["output_dir"],
    "training_arguments": training_args.to_dict(),
    "train_metrics": trainer.state.log_history[:-1],
    "eval_metrics": eval_metrics,
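    # The end-of-training summary entry in log_history stores the average loss under "train_loss"
    "final_training_loss": next(
        (entry["train_loss"] for entry in reversed(trainer.state.log_history) if "train_loss" in entry),
        "N/A"
    )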
}

for metric_key in ["eval_rouge1", "eval_rouge2", "eval_rougeL", "eval_bleu"]:
    if metric_key in eval_metrics:
        training_report[metric_key.replace("eval_", "")] = eval_metrics[metric_key]

report_path = os.path.join(CONFIG["output_dir"], CONFIG["training_report_filename"])
with open(report_path, "w") as f:
    json.dump(training_report, f, indent=4)
logger.info(f"Training report saved to {report_path}")

if CONFIG["mount_drive"] and os.path.exists(CONFIG["gdrive_output_dir"]):
    gdrive_report_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["training_report_filename"])
    os.system(f"cp '{report_path}' '{gdrive_report_path}'")
    logger.info(f"Training report copied to {gdrive_report_path}")

if wandb.run:
    wandb.log(eval_metrics)
    wandb.save(report_path)

clear_memory()

## 13. Save Model and LoRA Adapter

**What it does**: Saves the LoRA adapter weights and the tokenizer to the output directory, copies them to Google Drive if it is mounted, and closes the W&B run.


In [None]:
lora_adapter_path = os.path.join(CONFIG["output_dir"], CONFIG["lora_adapter_name"])
model.save_pretrained(lora_adapter_path)
tokenizer.save_pretrained(lora_adapter_path)
logger.info(f"LoRA adapter and tokenizer saved to {lora_adapter_path}")

if CONFIG["mount_drive"] and os.path.exists(CONFIG["gdrive_output_dir"]):
    gdrive_lora_path = os.path.join(CONFIG["gdrive_output_dir"], CONFIG["lora_adapter_name"])
    os.system(f"cp -r '{lora_adapter_path}' '{CONFIG['gdrive_output_dir']}/'")
    logger.info(f"LoRA adapter copied to {gdrive_lora_path}")

if wandb.run:
    wandb.finish()

clear_memory()

## 14. Test the Model

**What it does**: Tests the model on a sample input and prints the generated summary.


In [None]:
test_document = """
CYBERSECURITY AND PRIVACY PROTECTION ACT OF 2025

SECTION 1. SHORT TITLE AND PURPOSE

    (a) This Act may be cited as the 'Cybersecurity and Privacy Protection Act of 2025'.
    (b) The purpose of this Act is to enhance cybersecurity measures and protect individual privacy in the digital age.

SECTION 2. DEFINITIONS

In this Act:
(1) 'Personal Data' means any information relating to an identified or identifiable natural person.
(2) 'Data Controller' means any entity that determines the purposes and means of processing personal data.
(3) 'Critical Infrastructure' means systems and assets vital to national security.

SECTION 3. CYBERSECURITY REQUIREMENTS

    (a) MANDATORY SECURITY MEASURES.—
        (1) All Data Controllers shall implement:
            (A) End-to-end encryption for data transmission
            (B) Multi-factor authentication for system access
            (C) Regular security audits and vulnerability assessments

    (b) INCIDENT REPORTING.—
        (1) Data Controllers shall report any security breach within 48 hours.
        (2) Penalties for non-compliance shall be up to $500,000 per incident.

SECTION 4. PRIVACY PROTECTIONS

    (a) CONSENT REQUIREMENTS.—
        (1) Explicit consent required for data collection
        (2) Right to access and delete personal data
        (3) Annual privacy impact assessments

    (b) CHILDREN'S PRIVACY.—
        (1) Enhanced protections for users under 13
        (2) Parental consent requirements

SECTION 5. ENFORCEMENT

    (a) The Federal Trade Commission shall enforce this Act.
    (b) State Attorneys General may bring civil actions.

SECTION 6. AUTHORIZATION OF APPROPRIATIONS

    There is authorized to be appropriated $275,000,000 for fiscal year 2026 to carry out this Act.
"""

logger.info("Loading trained model for testing...")
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        CONFIG["model_name"],
        quantization_config=CONFIG["quantization_config"],
        device_map="auto",
        torch_dtype=torch.float16
    )
    model = get_peft_model(model, lora_config)
    model.load_adapter(lora_adapter_path, CONFIG["lora_adapter_name"])
    model.set_adapter(CONFIG["lora_adapter_name"])
    model.eval()
    logger.info("Model loaded for testing.")
except Exception as e:
    logger.error(f"Failed to load model for testing: {e}")
    raise

inputs = tokenizer(
    f"{CONFIG['prompt_prefix']}{clean_text(test_document)}",
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=CONFIG["max_input_tokens"]
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=CONFIG["max_target_tokens"],
        num_beams=CONFIG["gen_num_beams"],
        length_penalty=CONFIG["gen_length_penalty"],
        early_stopping=CONFIG["gen_early_stopping"]
    )
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Generated Summary:")
print(summary)
clear_memory()
logger.info("Test completed.")