# Legal Case Summarization Assistant
## Domain-Specific Fine-Tuning with LoRA

**Project Overview:**
This notebook implements a production-quality legal case summarization system by fine-tuning a lightweight generative LLM using Parameter-Efficient Fine-Tuning (PEFT) with LoRA on the `joelniklaus/legal_case_document_summarization` dataset.

**Key Components:**
- Dataset: Legal court judgments with human-written summaries
- Model: Gemma-2B with 4-bit quantization
- Training: LoRA for parameter-efficient fine-tuning
- Evaluation: ROUGE metrics comparing base vs fine-tuned models
- Deployment: Gradio interface for interactive inference

**Author:** ML Engineering Project  
**Date:** February 2026

---
## 1. Environment Setup

Install required dependencies and configure the runtime environment for efficient training on Colab's free GPU (T4).

In [1]:
# Install required packages
# CRITICAL: Force upgrade all packages to ensure compatibility
print("Upgrading core packages to latest versions...")

# Upgrade transformers and huggingface_hub together for compatibility
!pip install --upgrade --force-reinstall --no-cache-dir transformers huggingface_hub

print("\nInstalling other required packages...")
!pip install --upgrade -q datasets \
                peft \
                accelerate \
                bitsandbytes \
                evaluate \
                rouge_score \
                gradio \
                sentencepiece \
                protobuf


print("\n" + "="*60)
print("✓ All packages installed successfully")
print("="*60)
print("⚠️  CRITICAL: RESTART KERNEL NOW!")
print("   Kaggle: Session → Restart Session (or Ctrl+M+.)")
print("   Then re-run cells 4-5 to verify installation")
print("="*60)

Upgrading core packages to latest versions...
Collecting transformers
  Downloading transformers-5.1.0-py3-none-any.whl.metadata (31 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-1.4.1-py3-none-any.whl.metadata (13 kB)
Collecting numpy>=1.17 (from transformers)
  Downloading numpy-2.4.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting packaging>=20.0 (from transformers)
  Downloading packaging-26.0-py3-none-any.whl.metadata (3.3 kB)
Collecting pyyaml>=5.1 (from transformers)
  Downloading pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (2.4 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2026.1.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m229.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers<=0.23.0,>=0.2

In [2]:
# Import required libraries
import torch
import numpy as np
import pandas as pd
import time
import random
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel
)
import evaluate
import gradio as gr

print("✓ Libraries imported successfully")

✓ Libraries imported successfully


In [3]:
# Verify transformers version (must be 4.38.0+ for Gemma support)
import transformers
print("="*60)
print("VERSION CHECK")
print("="*60)
print(f"Transformers version: {transformers.__version__}")
print()

# Parse version
version_parts = transformers.__version__.split('.')
major = int(version_parts[0])
minor = int(version_parts[1]) if len(version_parts) > 1 else 0

if major > 4 or (major == 4 and minor >= 38):
    print("✓ Transformers version is compatible with Gemma")
    print("  (Requires 4.38.0+, you have {})".format(transformers.__version__))
else:
    print("❌ ERROR: Transformers version too old!")
    print(f"  Current: {transformers.__version__}")
    print("  Required: 4.38.0+")
    print()
    print("FIX:")
    print("  1. Go back to Cell 3")
    print("  2. Re-run the installation cell")
    print("  3. RESTART KERNEL (Session → Restart Session)")
    print("  4. Re-run this cell")
    raise RuntimeError("Transformers version incompatible with Gemma")

print("="*60)

VERSION CHECK
Transformers version: 5.1.0

✓ Transformers version is compatible with Gemma
  (Requires 4.38.0+, you have 5.1.0)


In [4]:
# Detect and print GPU information
print("="*60)
print("HARDWARE CONFIGURATION")
print("="*60)

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU Memory: {torch.cuda.mem_get_info()[0] / 1e9:.2f} GB")
    print(f"PyTorch Version: {torch.__version__}")
    device = "cuda"
else:
    print("⚠️  WARNING: No GPU detected. Training will be extremely slow.")
    device = "cpu"

print(f"Device: {device}")
print("="*60)

HARDWARE CONFIGURATION
GPU: Tesla P100-PCIE-16GB
CUDA Version: 12.6
Total GPU Memory: 17.06 GB
Available GPU Memory: 16.79 GB
PyTorch Version: 2.8.0+cu126
Device: cuda


In [5]:
# Set seed for reproducibility
SEED = 42

def set_seed(seed):
    """Set seed for reproducibility across all random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(SEED)
print(f"✓ Random seed set to {SEED} for reproducibility")

✓ Random seed set to 42 for reproducibility


In [None]:
# Check internet connectivity (required for Kaggle)
import urllib.request

print("=" * 60)
print("CHECKING INTERNET CONNECTION")
print("=" * 60)

try:
    urllib.request.urlopen('https://huggingface.co', timeout=5)
    print("✓ Internet connection: ACTIVE")
    print("✓ Can access Hugging Face")
except:
    print("❌ NO INTERNET CONNECTION!")
    print()
    raise ConnectionError("Internet not enabled in Kaggle notebook")

print("=" * 60)

CHECKING INTERNET CONNECTION
✓ Internet connection: ACTIVE
✓ Can access Hugging Face


---
## 2. Load Dataset

Load the legal case summarization dataset from Hugging Face and perform exploratory data analysis.

In [7]:
# Load dataset from Hugging Face
print("Loading dataset from Hugging Face...")
print("Dataset: joelniklaus/legal_case_document_summarization")
print("Config: en (English)")
print()

# Load the dataset - use "en" for English, not "default"
dataset = load_dataset("joelniklaus/legal_case_document_summarization", "default")

print("✓ Dataset loaded successfully")
print()
print("Dataset structure:")
print(dataset)

Loading dataset from Hugging Face...
Dataset: joelniklaus/legal_case_document_summarization
Config: en (English)



README.md: 0.00B [00:00, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


train.jsonl.xz:   0%|          | 0.00/50.6M [00:00<?, ?B/s]

test.jsonl.xz:   0%|          | 0.00/2.30M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7773 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/200 [00:00<?, ? examples/s]

✓ Dataset loaded successfully

Dataset structure:
DatasetDict({
    train: Dataset({
        features: ['judgement', 'dataset_name', 'summary'],
        num_rows: 7773
    })
    test: Dataset({
        features: ['judgement', 'dataset_name', 'summary'],
        num_rows: 200
    })
})


In [8]:
# Inspect dataset structure and print example
print("="*60)
print("DATASET INSPECTION")
print("="*60)
print()
print(f"Dataset splits: {list(dataset.keys())}")
print(f"Features: {dataset['train'].features}")
print()
print(f"Training samples: {len(dataset['train'])}")
# Removed the problematic line as no 'validation' split exists
print(f"Test samples: {len(dataset['test'])}")
print()
print("="*60)
print("SAMPLE EXAMPLE")
print("="*60)

sample = dataset['train'][0]
print(f"\nJudgment (first 500 chars):\n{sample['judgement'][:500]}...")
print(f"\n{'='*60}")
print(f"\nSummary:\n{sample['summary']}")
print(f"\n{'='*60}")

DATASET INSPECTION

Dataset splits: ['train', 'test']
Features: {'judgement': Value('string'), 'dataset_name': Value('string'), 'summary': Value('string')}

Training samples: 7773
Test samples: 200

SAMPLE EXAMPLE

Judgment (first 500 chars):
Appeal No. LXVI of 1949.
Appeal from the High Court of judicature, Bombay, in a reference under section 66 of the Indian Income tax Act, 1022.
K.M. Munshi (N. P. Nathvani, with him), for the appel lant. ' M.C. Setalvad, Attorney General for India (H. J. Umrigar, with him), for the respondent. 1950.
May 26.
The judgment of the Court was delivered by MEHR CHAND MAHAJAN J.
This is an appeal against a judgment of the High Court of Judicature at Bombay in an income tax matter and it raises the questi...


Summary:
The charge created in respect of municipal property tax by section 212 of the City of Bombay Municipal Act, 1888, is an "annual charge not being a capital charge" within the mean ing of section 9 (1) (iv) of the Indian Income tax Act, 199.2, 

In [9]:
# Login to Hugging Face to access gated models (like Gemma)
from huggingface_hub import login
import os

print("=" * 60)
print("HUGGING FACE AUTHENTICATION")
print("=" * 60)
print("Gemma is a gated model - you need to authenticate")
print()

# For Kaggle: Use your token directly
HF_TOKEN = "YOUR_HUGGINGFACE_TOKEN_HERE"

try:
    # Login with token
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("✓ Authentication successful!")
    print()
    print("IMPORTANT: Make sure you've accepted the Gemma license:")
    print("Visit: https://huggingface.co/google/gemma-2b")
    print("Click 'Agree and access repository'")
except Exception as e:
    print(f"❌ Authentication failed: {e}")
    print()
    print("To fix:")
    print("1. Get your token from: https://huggingface.co/settings/tokens")
    print("2. Replace HF_TOKEN value above with your token")
    print("3. Accept Gemma license: https://huggingface.co/google/gemma-2b")
    raise

print("=" * 60)

HUGGING FACE AUTHENTICATION
Gemma is a gated model - you need to authenticate

✓ Authentication successful!

IMPORTANT: Make sure you've accepted the Gemma license:
Visit: https://huggingface.co/google/gemma-2b
Click 'Agree and access repository'


In [10]:
# Analyze token lengths in the dataset
from transformers import AutoTokenizer

# Load tokenizer for analysis (using Gemma)
print("Loading Gemma tokenizer...")
analysis_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

def get_token_stats(dataset_split, num_samples=1000):
    """Calculate token length statistics for judgments."""
    judgment_lengths = []
    summary_lengths = []

    sample_size = min(num_samples, len(dataset_split))
    indices = random.sample(range(len(dataset_split)), sample_size)

    for idx in indices:
        example = dataset_split[idx]
        judgment_tokens = len(analysis_tokenizer.encode(example['judgement']))
        summary_tokens = len(analysis_tokenizer.encode(example['summary']))
        judgment_lengths.append(judgment_tokens)
        summary_lengths.append(summary_tokens)

    return judgment_lengths, summary_lengths

print("Analyzing token lengths (sampling 1000 examples)...")
judgment_lengths, summary_lengths = get_token_stats(dataset['train'], 1000)

print("\n" + "="*60)
print("TOKEN LENGTH STATISTICS")
print("="*60)
print(f"\nJudgment tokens:")
print(f"  Mean: {np.mean(judgment_lengths):.0f}")
print(f"  Median: {np.median(judgment_lengths):.0f}")
print(f"  Min: {np.min(judgment_lengths)}")
print(f"  Max: {np.max(judgment_lengths)}")
print(f"  95th percentile: {np.percentile(judgment_lengths, 95):.0f}")

print(f"\nSummary tokens:")
print(f"  Mean: {np.mean(summary_lengths):.0f}")
print(f"  Median: {np.median(summary_lengths):.0f}")
print(f"  Min: {np.min(summary_lengths)}")
print(f"  Max: {np.max(summary_lengths)}")
print("="*60)

Loading Gemma tokenizer...


config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Analyzing token lengths (sampling 1000 examples)...

TOKEN LENGTH STATISTICS

Judgment tokens:
  Mean: 7065
  Median: 4475
  Min: 319
  Max: 101787
  95th percentile: 21519

Summary tokens:
  Mean: 1125
  Median: 898
  Min: 57
  Max: 11965


In [None]:
# Using 4000 training samples for optimal balance between performance and training time

TRAIN_SAMPLES = 4000
VAL_SAMPLES = 500
TEST_SAMPLES = 300

print(f"Subsampling dataset for GPU efficiency...")
print(f"Train: {TRAIN_SAMPLES}, Validation: {VAL_SAMPLES}, Test: {TEST_SAMPLES}")
print()

# Set seed before sampling
set_seed(SEED)

# Check available splits
print(f"Available splits: {list(dataset.keys())}")
print()

# Subsample each split - create validation from training data if needed
train_dataset = dataset['train'].shuffle(seed=SEED).select(range(min(TRAIN_SAMPLES, len(dataset['train']))))

# Create validation split from train data (since dataset might not have validation)
if 'validation' in dataset:
    val_dataset = dataset['validation'].shuffle(seed=SEED).select(range(min(VAL_SAMPLES, len(dataset['validation']))))
else:
    print("⚠️  No validation split found - creating from training data")
    # Use samples after training samples for validation
    val_start = TRAIN_SAMPLES
    val_end = val_start + VAL_SAMPLES
    val_dataset = dataset['train'].shuffle(seed=SEED).select(range(val_start, min(val_end, len(dataset['train']))))

test_dataset = dataset['test'].shuffle(seed=SEED).select(range(min(TEST_SAMPLES, len(dataset['test']))))

print("✓ Dataset subsampled successfully")
print(f"Final split sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Validation: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

Subsampling dataset for GPU efficiency...
Train: 4000, Validation: 500, Test: 300

Available splits: ['train', 'test']

⚠️  No validation split found - creating from training data
✓ Dataset subsampled successfully
Final split sizes:
  Train: 4000
  Validation: 500
  Test: 200


---
## 3. Data Preprocessing

Convert the dataset into instruction-response format suitable for fine-tuning a generative model. We'll use a structured prompt template that clearly separates instruction, input, and response.

In [12]:
# Define instruction template for legal case summarization
INSTRUCTION_TEMPLATE = """Instruction:
Summarize the following legal court judgment.

Input:
{judgement}

Response:
{summary}"""

# Maximum token limits for model context window management
MAX_JUDGMENT_TOKENS = 1024  # Truncate long judgments to fit in context
MAX_SUMMARY_TOKENS = 256    # Maximum summary length
MAX_TOTAL_LENGTH = 1536     # Total sequence length (judgment + summary + prompt)

print("Instruction Template:")
print("="*60)
print(INSTRUCTION_TEMPLATE.format(
    judgement="[JUDGMENT TEXT]",
    summary="[SUMMARY TEXT]"
))
print("="*60)
print(f"\nMax judgment tokens: {MAX_JUDGMENT_TOKENS}")
print(f"Max summary tokens: {MAX_SUMMARY_TOKENS}")
print(f"Max total length: {MAX_TOTAL_LENGTH}")

Instruction Template:
Instruction:
Summarize the following legal court judgment.

Input:
[JUDGMENT TEXT]

Response:
[SUMMARY TEXT]

Max judgment tokens: 1024
Max summary tokens: 256
Max total length: 1536


In [13]:
# Load tokenizer
MODEL_NAME = "google/gemma-2b"

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Add padding token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("✓ Tokenizer loaded successfully")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"Pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")

Loading tokenizer: google/gemma-2b
✓ Tokenizer loaded successfully
Vocabulary size: 256000
Pad token: '<pad>' (ID: 0)
EOS token: '<eos>' (ID: 1)


In [14]:
def truncate_text_by_tokens(text, max_tokens, tokenizer):
    """
    Truncate text to a maximum number of tokens.
    This ensures we don't exceed the model's context window.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    return tokenizer.decode(tokens, skip_special_tokens=True)

def format_instruction(example):
    """
    Format a single example into instruction-response format.
    Truncates judgment to fit within token limits.
    """
    # Truncate judgment to max tokens
    truncated_judgment = truncate_text_by_tokens(
        example['judgement'],
        MAX_JUDGMENT_TOKENS,
        tokenizer
    )

    # Truncate summary to max tokens
    truncated_summary = truncate_text_by_tokens(
        example['summary'],
        MAX_SUMMARY_TOKENS,
        tokenizer
    )

    # Format into instruction template
    formatted_text = INSTRUCTION_TEMPLATE.format(
        judgement=truncated_judgment,
        summary=truncated_summary
    )

    return {"text": formatted_text}

def tokenize_function(examples):
    """
    Tokenize preprocessed examples with padding and truncation.
    """
    # Tokenize with truncation and padding
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_TOTAL_LENGTH,
        return_tensors=None
    )

    # Create labels (copy of input_ids for causal language modeling)
    tokenized["labels"] = tokenized["input_ids"].copy()

    return tokenized

print("✓ Preprocessing functions defined")

✓ Preprocessing functions defined


In [15]:
# Apply formatting to datasets
print("Formatting datasets into instruction-response format...")
print()

# Format each split
train_formatted = train_dataset.map(format_instruction, remove_columns=train_dataset.column_names)
val_formatted = val_dataset.map(format_instruction, remove_columns=val_dataset.column_names)
test_formatted = test_dataset.map(format_instruction, remove_columns=test_dataset.column_names)

print("✓ Datasets formatted")
print()
print("Sample formatted example:")
print("="*60)
print(train_formatted[0]['text'][:800] + "...")
print("="*60)

Formatting datasets into instruction-response format...



Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

✓ Datasets formatted

Sample formatted example:
Instruction:
Summarize the following legal court judgment.

Input:
Appeal No. 945 of 1965.
Appeal by special leave from the judgment and order dated December 14, 1962 of the Gujarat High Court in Sales Tax Re ference No. 16 of 1961.
N. section Bindra and R. H. Dhebar, for the appellant.
M. V. Goswami, for the respondent.
The Judgment of the Court was delivered by Bhargava, J.
This appeal under special leave granted by this Court arises out of proceedings for assessment of sales tax under the Bombay Sales Tax Act III of 1953.
Messrs. Kailash Engineering Co. (hereinafter referred to as "the respondent") was an engineering concern having their workshop at Morvi on the meter gauge section of the Western Railway.
They obtained a contract from the Western Railway Administration for construction ...


In [16]:
# Tokenize datasets
print("Tokenizing datasets...")
print("This may take a few minutes...")
print()

train_tokenized = train_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=train_formatted.column_names,
    desc="Tokenizing train set"
)

val_tokenized = val_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=val_formatted.column_names,
    desc="Tokenizing validation set"
)

test_tokenized = test_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=test_formatted.column_names,
    desc="Tokenizing test set"
)

print("✓ Tokenization complete")
print(f"\nTokenized dataset features: {train_tokenized.features}")
print(f"Example input_ids shape: {len(train_tokenized[0]['input_ids'])}")

Tokenizing datasets...
This may take a few minutes...



Tokenizing train set:   0%|          | 0/4000 [00:00<?, ? examples/s]

Tokenizing validation set:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing test set:   0%|          | 0/200 [00:00<?, ? examples/s]

✓ Tokenization complete

Tokenized dataset features: {'input_ids': List(Value('int32')), 'attention_mask': List(Value('int8')), 'labels': List(Value('int64'))}
Example input_ids shape: 1536


---
## 4. Model Selection and Configuration

Load Gemma-2B model with 4-bit quantization for memory-efficient training on Colab's free GPU. We use QLoRA (Quantized LoRA) to enable fine-tuning on consumer hardware.

In [17]:
# Configure 4-bit quantization for memory efficiency
# This reduces memory footprint by ~75% compared to fp16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",             # Use NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16 for stability
    bnb_4bit_use_double_quant=True,        # Double quantization for extra compression
)

print("4-bit Quantization Configuration:")
print("="*60)
print(f"Quantization type: NF4 (4-bit NormalFloat)")
print(f"Compute dtype: bfloat16")
print(f"Double quantization: Enabled")
print(f"Expected memory reduction: ~75% vs fp16")
print("="*60)

4-bit Quantization Configuration:
Quantization type: NF4 (4-bit NormalFloat)
Compute dtype: bfloat16
Double quantization: Enabled
Expected memory reduction: ~75% vs fp16


In [18]:
# Load base model with quantization
print(f"Loading base model: {MODEL_NAME}")
print("This will take 2-3 minutes...")
print()

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",                    # Automatically distribute across GPUs
    trust_remote_code=True,               # Required for Gemma
)

# Enable gradient checkpointing to reduce memory during training
# Trade-off: ~20% slower training for ~30% memory reduction
base_model.config.use_cache = False
base_model.gradient_checkpointing_enable()

print("✓ Model loaded successfully")
print()
print("Model Configuration:")
print("="*60)
print(f"Model name: {MODEL_NAME}")
print(f"Parameters: {base_model.num_parameters() / 1e9:.2f}B")
print(f"Gradient checkpointing: Enabled")
print(f"Device: {next(base_model.parameters()).device}")
print("="*60)

Loading base model: google/gemma-2b
This will take 2-3 minutes...



model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/164 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

✓ Model loaded successfully

Model Configuration:
Model name: google/gemma-2b
Parameters: 2.51B
Gradient checkpointing: Enabled
Device: cuda:0


---
## 5. Apply LoRA Using PEFT

Configure and apply Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning. LoRA only trains a small subset of parameters (typically <1% of total parameters) while achieving comparable performance to full fine-tuning.

In [19]:
# Prepare model for k-bit training
base_model = prepare_model_for_kbit_training(base_model)

print("✓ Model prepared for k-bit training")

✓ Model prepared for k-bit training


In [20]:
# Configure LoRA parameters
# These hyperparameters balance training efficiency with model expressiveness

lora_config = LoraConfig(
    r=16,                                  # Rank of update matrices (higher = more parameters)
    lora_alpha=16,                         # Scaling factor (typically set equal to r)
    target_modules=["q_proj", "v_proj"],  # Apply LoRA to query and value projection layers
    lora_dropout=0.05,                     # Dropout for regularization
    bias="none",                           # Don't train bias parameters
    task_type="CAUSAL_LM"                  # Task type: Causal Language Modeling
)

print("LoRA Configuration:")
print("="*60)
print(f"Rank (r): {lora_config.r}")
print(f"  → Controls the dimensionality of low-rank matrices")
print(f"  → Higher rank = more expressiveness but more parameters")
print()
print(f"Alpha: {lora_config.lora_alpha}")
print(f"  → Scaling factor for LoRA updates")
print(f"  → Typically set equal to r for balanced scaling")
print()
print(f"Target modules: {lora_config.target_modules}")
print(f"  → Attention query and value projections")
print(f"  → These are most impactful for adaptation")
print()
print(f"Dropout: {lora_config.lora_dropout}")
print(f"  → Prevents overfitting on small dataset")
print("="*60)

LoRA Configuration:
Rank (r): 16
  → Controls the dimensionality of low-rank matrices
  → Higher rank = more expressiveness but more parameters

Alpha: 16
  → Scaling factor for LoRA updates
  → Typically set equal to r for balanced scaling

Target modules: {'q_proj', 'v_proj'}
  → Attention query and value projections
  → These are most impactful for adaptation

Dropout: 0.05
  → Prevents overfitting on small dataset


In [21]:
# Apply LoRA to the model
model = get_peft_model(base_model, lora_config)

print("✓ LoRA applied to model")
print()

# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
trainable_percent = 100 * trainable_params / total_params

print("Parameter Analysis:")
print("="*60)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Trainable percentage: {trainable_percent:.4f}%")
print()
print(f"Memory efficiency: Training only {trainable_percent:.2f}% of parameters!")
print("="*60)

# Print model summary
model.print_trainable_parameters()

✓ LoRA applied to model

Parameter Analysis:
Total parameters: 1,517,111,296
Trainable parameters: 1,843,200
Trainable percentage: 0.1215%

Memory efficiency: Training only 0.12% of parameters!
trainable params: 1,843,200 || all params: 2,508,015,616 || trainable%: 0.0735


---
## 6. Training Setup

Configure training hyperparameters optimized for Colab's T4 GPU. We use gradient accumulation to simulate larger batch sizes without OOM errors.

In [23]:
# Define training hyperparameters
# Each parameter is carefully chosen for optimal training on Colab GPU

OUTPUT_DIR = "./legal-summarization-lora"

training_args = TrainingArguments(
    # Output and logging
    output_dir=OUTPUT_DIR,
    logging_dir="./logs",
    logging_steps=50,                      # Log every 50 steps for monitoring

    # Training schedule
    num_train_epochs=2,                    # 2 epochs sufficient for LoRA fine-tuning
    learning_rate=2e-4,                    # Higher LR works well with LoRA
    lr_scheduler_type="cosine",            # Cosine annealing for smooth convergence
    warmup_ratio=0.05,                     # 5% warmup steps

    # Batch size and accumulation
    per_device_train_batch_size=2,         # Small batch size to fit in 16GB GPU
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,         # Effective batch size = 2 * 8 = 16

    # Mixed precision and optimization
    fp16=True,                             # Use mixed precision for speed
    optim="paged_adamw_8bit",             # Memory-efficient optimizer

    # Evaluation and saving
    eval_strategy="epoch",                 # Evaluate at end of each epoch
    save_strategy="epoch",                 # Save checkpoint each epoch
    save_total_limit=2,                    # Keep only 2 best checkpoints
    load_best_model_at_end=True,          # Load best model after training
    metric_for_best_model="eval_loss",    # Use validation loss for model selection

    # Performance
    dataloader_num_workers=2,              # Parallel data loading
    dataloader_pin_memory=True,            # Faster GPU transfer

    # Miscellaneous
    report_to="none",                      # Disable wandb/tensorboard
    seed=SEED,
)

print("Training Configuration:")
print("="*60)
print(f"Learning Rate: {training_args.learning_rate}")
print(f"  → Higher LR (2e-4) suitable for LoRA small parameter space")
print()
print(f"Epochs: {training_args.num_train_epochs}")
print(f"  → 2 epochs sufficient to avoid overfitting on domain data")
print()
print(f"Batch Size per Device: {training_args.per_device_train_batch_size}")
print(f"  → Small to fit in T4 GPU memory (16GB)")
print()
print(f"Gradient Accumulation Steps: {training_args.gradient_accumulation_steps}")
print(f"  → Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  → Larger effective batch = more stable gradients")
print()
print(f"Mixed Precision (FP16): {training_args.fp16}")
print(f"  → 2x faster training, 50% memory reduction")
print()
print(f"LR Scheduler: {training_args.lr_scheduler_type}")
print(f"  → Cosine annealing for smooth convergence")
print()
print(f"Optimizer: {training_args.optim}")
print(f"  → 8-bit AdamW reduces optimizer memory by 75%")
print("="*60)

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.


Training Configuration:
Learning Rate: 0.0002
  → Higher LR (2e-4) suitable for LoRA small parameter space

Epochs: 2
  → 2 epochs sufficient to avoid overfitting on domain data

Batch Size per Device: 2
  → Small to fit in T4 GPU memory (16GB)

Gradient Accumulation Steps: 8
  → Effective batch size: 16
  → Larger effective batch = more stable gradients

Mixed Precision (FP16): True
  → 2x faster training, 50% memory reduction

LR Scheduler: SchedulerType.COSINE
  → Cosine annealing for smooth convergence

Optimizer: OptimizerNames.PAGED_ADAMW_8BIT
  → 8-bit AdamW reduces optimizer memory by 75%


In [24]:
# Create data collator for causal language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # We're doing causal LM, not masked LM
)

print("✓ Data collator created")

✓ Data collator created


In [25]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
)

print("✓ Trainer initialized")
print()
print(f"Training samples: {len(train_tokenized)}")
print(f"Validation samples: {len(val_tokenized)}")
print(f"Steps per epoch: {len(train_tokenized) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)}")
print(f"Total training steps: {len(train_tokenized) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")

✓ Trainer initialized

Training samples: 4000
Validation samples: 500
Steps per epoch: 250
Total training steps: 500


In [26]:
# Start training with time and memory tracking
print("="*60)
print("STARTING TRAINING")
print("="*60)
print()
print("Training will take approximately 30-45 minutes on Colab T4 GPU")
print()

# Clear GPU cache before training
torch.cuda.empty_cache()

# Record start time
start_time = time.time()
start_memory = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0

# Train the model
train_result = trainer.train()

# Record end time and memory
end_time = time.time()
peak_memory = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
training_time = end_time - start_time

print()
print("="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Training time: {training_time / 60:.2f} minutes")
print(f"Peak GPU memory: {peak_memory:.2f} GB")
print(f"Final training loss: {train_result.training_loss:.4f}")
print("="*60)

STARTING TRAINING

Training will take approximately 30-45 minutes on Colab T4 GPU



Epoch,Training Loss,Validation Loss
1,1.809841,1.804943
2,1.786816,1.7956



TRAINING COMPLETE
Training time: 635.90 minutes
Peak GPU memory: 14.15 GB
Final training loss: 1.8199


In [27]:
# Save the fine-tuned model
print("Saving fine-tuned model...")

# Save LoRA adapters
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"✓ Model saved to {OUTPUT_DIR}")
print()
print("Saved files:")
import os
for file in os.listdir(OUTPUT_DIR):
    if not file.startswith('checkpoint'):
        print(f"  - {file}")

Saving fine-tuned model...
✓ Model saved to ./legal-summarization-lora

Saved files:
  - tokenizer.json
  - adapter_config.json
  - tokenizer_config.json
  - README.md
  - adapter_model.safetensors


---
## 7. Evaluation

Evaluate the fine-tuned model using ROUGE metrics and compare against the base pre-trained model (zero-shot). ROUGE scores measure n-gram overlap between generated and reference summaries.

In [None]:
# Install required packages (run once)
!pip install -q evaluate rouge_score bert_score huggingface_hub

In [None]:
# Configuration - UPDATE THESE PATHS
MODEL_PATH = "/kaggle/input/datasets/orpheusmanga/legal-summarization-lora"  # Path to your trained model
BASE_MODEL_NAME = "google/gemma-2b"       # Base model used for training
MAX_JUDGMENT_TOKENS = 1024
MAX_NEW_TOKENS = 256
NUM_TEST_SAMPLES = 100

# Verify model path exists
import os
if not os.path.exists(MODEL_PATH):
    print(f"⚠️  WARNING: Model path not found: {MODEL_PATH}")
    print("Please update MODEL_PATH to point to your trained model folder")
    print("It should contain: adapter_model.safetensors, adapter_config.json, tokenizer files")
else:
    print(f"✓ Model path found: {MODEL_PATH}")

print("Configuration:")
print(f"  Model path: {MODEL_PATH}")
print(f"  Base model: {BASE_MODEL_NAME}")
print(f"  Test samples: {NUM_TEST_SAMPLES}")

✓ Model path found: /kaggle/input/datasets/orpheusmanga/legal-summarization-lora1
Configuration:
  Model path: /kaggle/input/datasets/orpheusmanga/legal-summarization-lora1
  Base model: google/gemma-2b
  Test samples: 100


In [None]:
# Verify model files
print("="*60)
print("CHECKING MODEL FILES")
print("="*60)

required_files = [
    'adapter_config.json',
    'adapter_model.safetensors',  # or adapter_model.bin
    'tokenizer_config.json',
    'tokenizer.json'
]

import os
if os.path.exists(MODEL_PATH):
    files = os.listdir(MODEL_PATH)
    print(f"\nFiles found in {MODEL_PATH}:")
    for f in sorted(files):
        print(f"  ✓ {f}")
    
    # Check for required files
    print("\nRequired files check:")
    for req_file in required_files:
        if req_file in files or (req_file == 'adapter_model.safetensors' and 'adapter_model.bin' in files):
            print(f"  ✓ {req_file}")
        else:
            print(f"  ⚠️  Missing: {req_file}")
else:
    print(f"\n❌ Path does not exist: {MODEL_PATH}")
    print("Please update MODEL_PATH in the configuration cell above")

print("="*60)

CHECKING MODEL FILES

Files found in /kaggle/input/datasets/orpheusmanga/legal-summarization-lora1:
  ✓ adapter_config.json
  ✓ adapter_model.safetensors
  ✓ tokenizer.json
  ✓ tokenizer_config.json

Required files check:
  ✓ adapter_config.json
  ✓ adapter_model.safetensors
  ✓ tokenizer_config.json
  ✓ tokenizer.json


In [None]:
# Load test dataset
print("Loading test dataset...")
dataset = load_dataset("joelniklaus/legal_case_document_summarization", "default")
test_dataset = dataset['test'].shuffle(seed=42).select(range(NUM_TEST_SAMPLES))
print(f"✓ Loaded {len(test_dataset)} test samples")
print(f"  Dataset type: {type(test_dataset)}")
print(f"  First example keys: {list(test_dataset[0].keys())}")

Loading test dataset...


Repo card metadata block was not found. Setting CardData to empty.


✓ Loaded 100 test samples
  Dataset type: <class 'datasets.arrow_dataset.Dataset'>
  First example keys: ['judgement', 'dataset_name', 'summary']


In [None]:
# Load your trained LoRA adapters
print("Loading LoRA adapters from local trained model...")
print(f"  Path: {MODEL_PATH}")

# Load adapters from local directory
model = PeftModel.from_pretrained(
    base_model, 
    MODEL_PATH,
    is_trainable=False  # Set to evaluation mode
)
model.eval()  # Set to evaluation mode

print("✓ Fine-tuned model loaded successfully!")
print(f"  Device: {next(model.parameters()).device}")

Loading LoRA adapters from local trained model...
  Path: /kaggle/input/datasets/orpheusmanga/legal-summarization-lora1




✓ Fine-tuned model loaded successfully!
  Device: cuda:0


In [None]:
# Load evaluation metrics
print("Loading evaluation metrics...")
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")
print("✓ Metrics loaded")

Loading evaluation metrics...


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

✓ Metrics loaded


In [None]:
# Define helper functions
def truncate_text_by_tokens(text, max_tokens, tokenizer):
    """Truncate text to max tokens"""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    return tokenizer.decode(tokens, skip_special_tokens=True)

def generate_summary(model, tokenizer, judgment_text, max_new_tokens=256):
    """Generate summary for a legal judgment"""
    prompt = f"""Instruction:
Summarize the following legal court judgment.

Input:
{judgment_text}

Response:
"""
    
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1280).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    if "Response:" in generated_text:
        summary = generated_text.split("Response:")[-1].strip()
    else:
        summary = generated_text[len(prompt):].strip()
    
    return summary

print("✓ Helper functions defined")

✓ Helper functions defined


In [None]:
# Run evaluation
print("="*60)
print("STARTING EVALUATION")
print("="*60)
print(f"\nEvaluating {NUM_TEST_SAMPLES} samples...")
print("This will take 10-15 minutes\n")

predictions = []
references = []
perplexities = []

for i in range(NUM_TEST_SAMPLES):
    if i % 20 == 0:
        print(f"Progress: {i}/{NUM_TEST_SAMPLES}")
    
    # Get example and check structure
    example = test_dataset[i]
    
    # Handle both dict and dataset row formats
    if isinstance(example, dict):
        judgment_text = example['judgement']
        reference_summary = example['summary']
    else:
        # If it's a dataset row, access columns properly
        judgment_text = example['judgement'] if 'judgement' in example else str(example)
        reference_summary = example['summary'] if 'summary' in example else ""
    
    judgment = truncate_text_by_tokens(judgment_text, MAX_JUDGMENT_TOKENS, tokenizer)
    
    # Generate summary
    pred_summary = generate_summary(model, tokenizer, judgment, MAX_NEW_TOKENS)
    
    predictions.append(pred_summary)
    references.append(reference_summary)
    
    # Calculate perplexity
    try:
        inputs = tokenizer(pred_summary, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss.item()
            perplexity = torch.exp(torch.tensor(loss)).item()
            perplexities.append(perplexity)
    except Exception as e:
        if i == 0:  # Only print error for first sample
            print(f"Warning: Could not compute perplexity: {e}")
        pass

print(f"\n✓ Generated {len(predictions)} summaries")

STARTING EVALUATION

Evaluating 100 samples...
This will take 10-15 minutes

Progress: 0/100
Progress: 20/100
Progress: 40/100
Progress: 60/100
Progress: 80/100

✓ Generated 100 summaries


In [None]:
# Compute metrics
print("\nComputing metrics...")
print(f"Predictions: {len(predictions)}, References: {len(references)}")

# Ensure all strings are valid
predictions_clean = [str(p) if p else "" for p in predictions]
references_clean = [str(r) if r else "" for r in references]

# ROUGE
print("Computing ROUGE...")
rouge_results = rouge.compute(
    predictions=predictions_clean,
    references=references_clean,
    use_stemmer=True
)

# BLEU
print("Computing BLEU...")
bleu_references = [[ref] for ref in references_clean]
bleu_results = bleu.compute(
    predictions=predictions_clean,
    references=bleu_references
)

# BERTScore
print("Computing BERTScore (this takes a few minutes)...")
try:
    bertscore_results = bertscore.compute(
        predictions=predictions_clean,
        references=references_clean,
        model_type="distilbert-base-uncased",
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    print(f"BERTScore type: {type(bertscore_results)}")
    if isinstance(bertscore_results, dict):
        print(f"BERTScore keys: {bertscore_results.keys()}")
except Exception as e:
    print(f"BERTScore error: {e}")
    # Fallback values if BERTScore fails
    bertscore_results = {
        'precision': [0.0] * len(predictions_clean),
        'recall': [0.0] * len(predictions_clean),
        'f1': [0.0] * len(predictions_clean)
    }

# Perplexity
avg_perplexity = np.mean(perplexities) if perplexities else None

print("✓ All metrics computed")


Computing metrics...
Predictions: 100, References: 100
Computing ROUGE...
Computing BLEU...
Computing BERTScore (this takes a few minutes)...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

BERTScore type: <class 'dict'>
BERTScore keys: dict_keys(['precision', 'recall', 'f1', 'hashcode'])
✓ All metrics computed


In [None]:
# Display results
print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)

results_df = pd.DataFrame({
    'Metric': ['ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'BLEU', 'BERTScore-P', 'BERTScore-R', 'BERTScore-F1', 'Perplexity'],
    'Score': [
        f"{rouge_results['rouge1']:.4f}",
        f"{rouge_results['rouge2']:.4f}",
        f"{rouge_results['rougeL']:.4f}",
        f"{bleu_results['bleu']:.4f}",
        f"{np.mean(bertscore_results['precision']):.4f}",
        f"{np.mean(bertscore_results['recall']):.4f}",
        f"{np.mean(bertscore_results['f1']):.4f}",
        f"{avg_perplexity:.2f}" if avg_perplexity else "N/A"
    ],
    'Interpretation': [
        'Vocabulary coverage',
        'Phrase fluency',
        'Structural coherence',
        'Precision-based overlap',
        'Semantic precision',
        'Semantic recall',
        'Semantic F1 balance',
        'Model confidence (lower=better)'
    ]
}
)

print(results_df.to_string(index=False))
print("\n" + "="*60)

# Save results
results_df.to_csv('evaluation_results.csv', index=False)
print("\n✓ Results saved to evaluation_results.csv")


EVALUATION RESULTS
      Metric  Score                  Interpretation
     ROUGE-1 0.2908             Vocabulary coverage
     ROUGE-2 0.1341                  Phrase fluency
     ROUGE-L 0.1760            Structural coherence
        BLEU 0.0057         Precision-based overlap
 BERTScore-P 0.8218              Semantic precision
 BERTScore-R 0.7614                 Semantic recall
BERTScore-F1 0.7902             Semantic F1 balance
  Perplexity   6.80 Model confidence (lower=better)


✓ Results saved to evaluation_results.csv


In [None]:
# Show sample predictions
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60)

for i in range(min(3, len(predictions))):
    print(f"\n{'='*60}")
    print(f"EXAMPLE {i+1}")
    print("="*60)
    print(f"\nInput (first 200 chars):\n{test_dataset[i]['judgement'][:200]}...")
    print(f"\n{'-'*60}")
    print(f"\nReference Summary:\n{references[i][:400]}...")
    print(f"\n{'-'*60}")
    print(f"\nGenerated Summary:\n{predictions[i][:400]}...")
    print(f"\n{'='*60}")


SAMPLE PREDICTIONS

EXAMPLE 1

Input (first 200 chars):
The appellant brought a claim for judicial review of a decision of the respondent, on 21 February 2012, to approve a Revenue Budget for 2012/13 in relation to the provision of youth services.
In his c...

------------------------------------------------------------

Reference Summary:
Mr Aaron Hunt, born on 17 April 1991, suffers from ADHD, learning difficulties and behavioural problems.
As a result, North Somerset Council (the Council) are statutorily required, so far as reasonably practicable, to secure access for him to sufficient educational and recreational leisure time activities for the improvement of his well being.
On 21 February 2012, the Council made a decision to ap...

------------------------------------------------------------

Generated Summary:
The appellant was a young person with a disability who used to attend a weekly youth club.
He was concerned about the impact which the reduction in the youth services budg

In [33]:
# Create comparison table
print("\n" + "="*60)
print("EVALUATION RESULTS COMPARISON")
print("="*60)
print()

comparison_df = pd.DataFrame({
    'Model': ['Base (Zero-shot)', 'Fine-tuned (LoRA)'],
    'ROUGE-1': [
        f"{base_results['rouge1']:.4f}",
        f"{finetuned_results['rouge1']:.4f}"
    ],
    'ROUGE-2': [
        f"{base_results['rouge2']:.4f}",
        f"{finetuned_results['rouge2']:.4f}"
    ],
    'ROUGE-L': [
        f"{base_results['rougeL']:.4f}",
        f"{finetuned_results['rougeL']:.4f}"
    ],
    'Improvement': [
        '-',
        f"+{(finetuned_results['rougeL'] - base_results['rougeL']) * 100:.2f}%"
    ]
})

print(comparison_df.to_string(index=False))
print()
print("="*60)
print(f"Fine-tuning improved ROUGE-L by {(finetuned_results['rougeL'] - base_results['rougeL']) * 100:.2f}%")
print("="*60)


EVALUATION RESULTS COMPARISON

            Model ROUGE-1 ROUGE-2 ROUGE-L Improvement
 Base (Zero-shot)  0.2213  0.0787  0.1412           -
Fine-tuned (LoRA)  0.2872  0.1323  0.1710      +2.98%

Fine-tuning improved ROUGE-L by 2.98%


In [34]:
# Display example predictions
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60)

for i in range(3):
    print(f"\n{'='*60}")
    print(f"EXAMPLE {i+1}")
    print("="*60)
    print(f"\nJudgment (first 300 chars):\n{test_dataset[i]['judgement'][:300]}...")
    print(f"\n{'-'*60}")
    print(f"\nReference Summary:\n{ft_references[i][:400]}...")
    print(f"\n{'-'*60}")
    print(f"\nBase Model (Zero-shot):\n{base_predictions[i][:400]}...")
    print(f"\n{'-'*60}")
    print(f"\nFine-tuned Model:\n{ft_predictions[i][:400]}...")
    print(f"\n{'='*60}")


SAMPLE PREDICTIONS

EXAMPLE 1

Judgment (first 300 chars):
The appellant brought a claim for judicial review of a decision of the respondent, on 21 February 2012, to approve a Revenue Budget for 2012/13 in relation to the provision of youth services.
In his claim form he applied for declarations that the respondent had failed to comply with section 149 of t...

------------------------------------------------------------

Reference Summary:
Mr Aaron Hunt, born on 17 April 1991, suffers from ADHD, learning difficulties and behavioural problems.
As a result, North Somerset Council (the Council) are statutorily required, so far as reasonably practicable, to secure access for him to sufficient educational and recreational leisure time activities for the improvement of his well being.
On 21 February 2012, the Council made a decision to ap...

------------------------------------------------------------

Base Model (Zero-shot):
This case is about the Court of Appeal’s decision in the case o

---
## 9. Inference Function

Create a clean, reusable function for generating case summaries in production.

In [36]:
def summarize_case(text, model=model, tokenizer=tokenizer, max_length=256):
    """
    Generate a legal case summary from input judgment text.

    Args:
        text (str): Legal court judgment text
        model: Fine-tuned summarization model
        tokenizer: Model tokenizer
        max_length (int): Maximum summary length in tokens

    Returns:
        str: Generated summary
    """
    # Format prompt
    prompt = f"""Instruction:
Summarize the following legal court judgment.

Input:
{text[:3000]}

Response:
"""

    # Tokenize
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1280
    ).to(model.device)

    # Generate with sampling for diversity
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,       # Controls randomness
            top_p=0.9,            # Nucleus sampling
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            num_beams=1,          # Greedy sampling for speed
        )

    # Decode and extract summary
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the generated response
    if "Response:" in full_text:
        summary = full_text.split("Response:")[-1].strip()
    else:
        summary = full_text[len(prompt):].strip()

    return summary

print("✓ Inference function 'summarize_case()' defined")
print()
print("Usage:")
print("  summary = summarize_case(judgment_text)")

✓ Inference function 'summarize_case()' defined

Usage:
  summary = summarize_case(judgment_text)


In [37]:
# Test inference function
print("Testing inference function...\n")
print("="*60)

test_case = test_dataset[5]['judgement']
test_summary = summarize_case(test_case)

print(f"Input (first 400 chars):\n{test_case[:400]}...\n")
print("="*60)
print(f"\nGenerated Summary:\n{test_summary}\n")
print("="*60)
print(f"\nReference Summary:\n{test_dataset[5]['summary']}\n")
print("="*60)

Testing inference function...

Input (first 400 chars):
Appeals, Nos. 275 276 of 1963.
Appeals by special leave from the judgment and order dated May 2, 1960 of the Kerala High Court in Income tax Referred case No. 98 of 1955(M).
section T. Desai, C. V. Mahalingam, B. Parthasarathi and J. B. Dadachanji, for the appellant (in both the appeals).
K. N. Rajagopal Sastri and R. N. Sachthey, for the res pondent (in both the appeals).
95 December 20, 1963.
Th...


Generated Summary:
In the year 1946, the appellant, a partnership firm, entered into a partnership agreement with another firm, to transfer the business of the latter to the former.
The appellant and the other firm executed a deed of transfer on November 13, 1947.
The appellant then made a declaration that it was discontinuing the business of the other firm.
The appellant claimed relief under section 25(4) of the Income tax Act, 1922, on the ground that the business was discontinued on February 7, 1948, by the execution of the deed 

---
## 11. Save and Reload Model

Demonstrate how to save the LoRA adapters and reload for inference.

In [None]:
# Save LoRA adapters (already done during training)
print("LoRA adapters have been saved to:", OUTPUT_DIR)
print()
print("Files saved:")
for file in os.listdir(OUTPUT_DIR):
    if not file.startswith('checkpoint'):
        file_size = os.path.getsize(os.path.join(OUTPUT_DIR, file)) / (1024 * 1024)
        print(f"  - {file} ({file_size:.2f} MB)")

print()
print("Total LoRA adapter size: ~10-50 MB (vs ~5GB for full model)")