In [2]:
# ============================================================================
# MEMORY-OPTIMIZED PHI-2 TEXT SUMMARIZATION
# Fixed for OOM errors on 15GB GPU
# ============================================================================

import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
import evaluate
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CRITICAL: MEMORY OPTIMIZATION SETTINGS
# ============================================================================
# Set before importing torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def aggressive_memory_cleanup():
    """Aggressively clear GPU memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

print("="*70)
print("MEMORY-OPTIMIZED PHI-2 TEXT SUMMARIZATION")
print("="*70)
print(f"‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("="*70)

# ============================================================================
# OPTIMIZED CONFIGURATION (REDUCED FOR 15GB GPU)
# ============================================================================
class Config:
    # Model
    MODEL_NAME = "microsoft/phi-2"

    # CRITICAL: Reduced dataset sizes for memory
    TRAIN_SIZE = 1000    # Reduced from 1000
    VAL_SIZE = 50       # Reduced from 100
    TEST_SIZE = 50      # Reduced from 100

    # CRITICAL: Reduced sequence length
    MAX_LENGTH = 256    # Reduced from 512 (saves 50% memory!)

    # Training parameters
    NUM_EPOCHS = 2      # Reduced from 3
    BATCH_SIZE = 1      # Reduced from 2 (CRITICAL!)
    GRAD_ACCUMULATION = 16  # Increased from 8 to maintain effective batch
    LEARNING_RATE = 2e-4
    WARMUP_STEPS = 30

    # LoRA parameters
    LORA_R = 8
    LORA_ALPHA = 16     # Reduced from 32
    LORA_DROPOUT = 0.05

    # Generation parameters
    MAX_NEW_TOKENS = 40  # Reduced from 50
    NUM_BEAMS = 2        # Reduced from 4 (saves memory)
    TEMPERATURE = 0.7
    TOP_P = 0.9
    REPETITION_PENALTY = 1.2

    # Paths
    OUTPUT_DIR = "./fine-tuned-phi2-summarization"
    RESULTS_DIR = "./results_phi2_summarization"

    SEED = 42

config = Config()

print("\n‚ö†Ô∏è  MEMORY-OPTIMIZED CONFIGURATION:")
print(f"  Max sequence length: {config.MAX_LENGTH} (50% reduction)")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Effective batch: {config.BATCH_SIZE * config.GRAD_ACCUMULATION}")
print(f"  Training samples: {config.TRAIN_SIZE} (smaller dataset)")

# ============================================================================
# STEP 1: LOAD DATASET
# ============================================================================
print("\n" + "="*70)
print("STEP 1: LOADING XSUM DATASET")
print("="*70)

full_dataset = load_dataset("EdinburghNLP/xsum")

dataset = DatasetDict({
    'train': full_dataset['train'].shuffle(seed=config.SEED).select(range(config.TRAIN_SIZE)),
    'validation': full_dataset['validation'].shuffle(seed=config.SEED).select(range(config.VAL_SIZE)),
    'test': full_dataset['test'].shuffle(seed=config.SEED).select(range(config.TEST_SIZE))
})

print(f"‚úÖ Training: {len(dataset['train'])} samples")
print(f"‚úÖ Validation: {len(dataset['validation'])} samples")
print(f"‚úÖ Test: {len(dataset['test'])} samples")

# ============================================================================
# STEP 2: QUICK EDA (MINIMAL MEMORY)
# ============================================================================
print("\n" + "="*70)
print("STEP 2: DATA ANALYSIS")
print("="*70)

train_df = pd.DataFrame(dataset['train'])
train_df['doc_len'] = train_df['document'].str.split().str.len()
train_df['sum_len'] = train_df['summary'].str.split().str.len()

print(f"Document length: {train_df['doc_len'].mean():.0f} words (avg)")
print(f"Summary length: {train_df['sum_len'].mean():.0f} words (avg)")

# Simple visualization
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(train_df['doc_len'], bins=20, color='skyblue', edgecolor='black')
plt.title('Document Lengths')
plt.xlabel('Words')

plt.subplot(1, 2, 2)
plt.hist(train_df['sum_len'], bins=20, color='lightcoral', edgecolor='black')
plt.title('Summary Lengths')
plt.xlabel('Words')

plt.tight_layout()
plt.savefig('xsum_lengths.png', dpi=200)
plt.close()
print("‚úÖ Saved: xsum_lengths.png")

del train_df
aggressive_memory_cleanup()

# ============================================================================
# STEP 3: LOAD MODEL WITH AGGRESSIVE QUANTIZATION
# ============================================================================
print("\n" + "="*70)
print("STEP 3: LOADING PHI-2 (4-BIT QUANTIZATION)")
print("="*70)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Loading model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    config.MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    max_memory={0: "10GB"}  # Limit memory usage
)

model.config.use_cache = False
model.config.pretraining_tp = 1

print("‚úÖ Model loaded with 4-bit quantization")
print(f"‚úÖ Memory footprint: ~{torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

aggressive_memory_cleanup()

# ============================================================================
# STEP 4: CONFIGURE LORA (MINIMAL PARAMETERS)
# ============================================================================
print("\n" + "="*70)
print("STEP 4: CONFIGURING LORA")
print("="*70)

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=config.LORA_R,
    lora_alpha=config.LORA_ALPHA,
    target_modules=["q_proj", "v_proj"],  # Minimal targets
    lora_dropout=config.LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.enable_input_require_grads()

# Only enable gradient checkpointing if needed
if config.BATCH_SIZE == 1:
    model.gradient_checkpointing_enable()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"‚úÖ Trainable params: {trainable:,} ({trainable/total*100:.2f}%)")

aggressive_memory_cleanup()

# ============================================================================
# STEP 5: TOKENIZATION (MEMORY EFFICIENT)
# ============================================================================
print("\n" + "="*70)
print("STEP 5: TOKENIZATION")
print("="*70)

def create_prompt(document, summary=None):
    # Truncate document to fit in memory
    doc_words = document.split()
    if len(doc_words) > 150:  # Further truncate
        document = ' '.join(doc_words[:150]) + "..."

    prompt = f"""Summarize this article in one sentence.

Article: {document}

Summary:"""

    if summary:
        prompt += " " + summary + tokenizer.eos_token
    return prompt

def tokenize_function(examples):
    prompts = [
        create_prompt(doc, summ)
        for doc, summ in zip(examples['document'], examples['summary'])
    ]

    model_inputs = tokenizer(
        prompts,
        max_length=config.MAX_LENGTH,
        truncation=True,
        padding="max_length"
    )

    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs

print("Tokenizing (this may take a moment)...")
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=50,
    remove_columns=dataset['train'].column_names,
    desc="Tokenizing"
)

print("‚úÖ Tokenization complete")

aggressive_memory_cleanup()

# ============================================================================
# STEP 6: SETUP METRICS
# ============================================================================
print("\n" + "="*70)
print("STEP 6: LOADING METRICS")
print("="*70)

rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    predictions, labels = eval_preds

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Extract summaries
    decoded_preds = [p.split("Summary:")[-1].strip() for p in decoded_preds]
    decoded_labels = [l.split("Summary:")[-1].strip() for l in decoded_labels]

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {
        'rouge1': result['rouge1'],
        'rouge2': result['rouge2'],
        'rougeL': result['rougeL']
    }

print("‚úÖ ROUGE metric loaded")

# ============================================================================
# STEP 7: TRAINING (MEMORY OPTIMIZED)
# ============================================================================
print("\n" + "="*70)
print("STEP 7: TRAINING CONFIGURATION")
print("="*70)

training_args = TrainingArguments(
    output_dir=config.RESULTS_DIR,
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    gradient_accumulation_steps=config.GRAD_ACCUMULATION,
    learning_rate=config.LEARNING_RATE,
    warmup_steps=config.WARMUP_STEPS,
    max_grad_norm=0.3,
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,  # Keep only 1 checkpoint
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    fp16=True,
    gradient_checkpointing=True,
    optim="adamw_8bit",  # 8-bit optimizer
    dataloader_num_workers=0,
    dataloader_pin_memory=False,  # Disable pinning
    report_to="none",
    seed=config.SEED,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

print(f"‚úÖ Trainer ready")
print(f"   Effective batch size: {config.BATCH_SIZE * config.GRAD_ACCUMULATION}")

print("\n" + "="*70)
print("üöÄ STARTING TRAINING (Estimated: 8-12 minutes)")
print("="*70 + "\n")

aggressive_memory_cleanup()

# Train
try:
    train_result = trainer.train()
    print("\n‚úÖ Training completed successfully!")
except RuntimeError as e:
    if "out of memory" in str(e):
        print("\n‚ùå Still out of memory. Try these fixes:")
        print("1. Restart runtime completely")
        print("2. Further reduce TRAIN_SIZE to 250")
        print("3. Set MAX_LENGTH to 128")
        print("4. Use CPU for evaluation only")
        raise
    else:
        raise

# ============================================================================
# STEP 8: EVALUATION
# ============================================================================
print("\n" + "="*70)
print("STEP 8: EVALUATION")
print("="*70)

eval_results = trainer.evaluate()

print("\nüéØ Validation Results:")
print(f"  Validation Loss: {eval_results['eval_loss']:.4f}")

# Simple visualization
log_history = trainer.state.log_history
train_logs = [log for log in log_history if 'loss' in log and 'eval_loss' not in log]
eval_logs = [log for log in log_history if 'eval_loss' in log]

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
if train_logs and eval_logs:
    plt.plot([log['step'] for log in train_logs], [log['loss'] for log in train_logs], label='Train')
    plt.plot([log['step'] for log in eval_logs], [log['eval_loss'] for log in eval_logs], 'o-', label='Val')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
if eval_logs:
    epochs = [log['epoch'] for log in eval_logs]
    losses = [log['eval_loss'] for log in eval_logs]
    plt.plot(epochs, losses, 'o-', color='red', linewidth=2, markersize=8)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss Progress')
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=200)
plt.close()
print("‚úÖ Saved: training_curves.png")

# ============================================================================
# STEP 9: SAVE MODEL
# ============================================================================
print("\n" + "="*70)
print("STEP 9: SAVING MODEL")
print("="*70)

model.save_pretrained(config.OUTPUT_DIR)
tokenizer.save_pretrained(config.OUTPUT_DIR)
print(f"‚úÖ Model saved to: {config.OUTPUT_DIR}")

aggressive_memory_cleanup()

# ============================================================================
# STEP 10: INFERENCE (MEMORY EFFICIENT)
# ============================================================================
print("\n" + "="*70)
print("STEP 10: GENERATING SAMPLE SUMMARIES")
print("="*70)

def generate_summary(document):
    prompt = create_prompt(document)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config.MAX_LENGTH)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=config.MAX_NEW_TOKENS,
            num_beams=config.NUM_BEAMS,
            temperature=config.TEMPERATURE,
            do_sample=True,
            top_p=config.TOP_P,
            repetition_penalty=config.REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id
        )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    summary = text.split("Summary:")[-1].strip()
    return summary

print("\nGenerating 3 example summaries:\n")

for i in range(3):
    example = dataset['test'][i]
    doc_short = example['document'][:200] + "..."

    print(f"Example {i+1}:")
    print(f"Document: {doc_short}")
    print(f"Reference: {example['summary']}")
    print(f"Generated: {generate_summary(example['document'])}")
    print()

    aggressive_memory_cleanup()

# ============================================================================
# STEP 11: TEST EVALUATION (BATCH PROCESSING)
# ============================================================================
print("\n" + "="*70)
print("STEP 11: TEST SET EVALUATION")
print("="*70)

references = []
predictions = []

print(f"Generating {len(dataset['test'])} summaries...")

for i, example in enumerate(dataset['test']):
    if (i + 1) % 10 == 0:
        print(f"Progress: {i+1}/{len(dataset['test'])}")
        aggressive_memory_cleanup()

    generated = generate_summary(example['document'])
    references.append(example['summary'])
    predictions.append(generated)

test_rouge = rouge.compute(predictions=predictions, references=references, use_stemmer=True)

print("\nüéØ FINAL TEST RESULTS:")
print("="*70)
print(f"  ROUGE-1: {test_rouge['rouge1']:.4f}")
print(f"  ROUGE-2: {test_rouge['rouge2']:.4f}")
print(f"  ROUGE-L: {test_rouge['rougeL']:.4f}")
print("="*70)

# Save results
results = pd.DataFrame({
    'Metric': ['ROUGE-1', 'ROUGE-2', 'ROUGE-L'],
    'Score': [test_rouge['rouge1'], test_rouge['rouge2'], test_rouge['rougeL']]
})
results.to_csv('test_results.csv', index=False)
print("\n‚úÖ Results saved to: test_results.csv")

# ============================================================================
# FINAL SUMMARY
# ============================================================================
print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETE!")
print("="*70)
print(f"\nConfiguration:")
print(f"  Samples: {config.TRAIN_SIZE} train, {config.TEST_SIZE} test")
print(f"  Max length: {config.MAX_LENGTH} tokens")
print(f"  Batch size: {config.BATCH_SIZE} (effective: {config.BATCH_SIZE * config.GRAD_ACCUMULATION})")
print(f"\nResults:")
print(f"  ROUGE-L: {test_rouge['rougeL']:.4f}")
print(f"\nFiles saved:")
print(f"  ‚úì {config.OUTPUT_DIR}/")
print(f"  ‚úì training_curves.png")
print(f"  ‚úì xsum_lengths.png")
print(f"  ‚úì test_results.csv")
print("="*70)

aggressive_memory_cleanup()

MEMORY-OPTIMIZED PHI-2 TEXT SUMMARIZATION
‚úÖ PyTorch version: 2.9.0+cu126
‚úÖ CUDA available: True
‚úÖ GPU: Tesla T4
‚úÖ GPU Memory: 15.83 GB

‚ö†Ô∏è  MEMORY-OPTIMIZED CONFIGURATION:
  Max sequence length: 256 (50% reduction)
  Batch size: 1
  Effective batch: 16
  Training samples: 1000 (smaller dataset)

STEP 1: LOADING XSUM DATASET
‚úÖ Training: 1000 samples
‚úÖ Validation: 50 samples
‚úÖ Test: 50 samples

STEP 2: DATA ANALYSIS
Document length: 372 words (avg)
Summary length: 21 words (avg)
‚úÖ Saved: xsum_lengths.png

STEP 3: LOADING PHI-2 (4-BIT QUANTIZATION)
Loading tokenizer...
Loading model with 4-bit quantization...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

‚úÖ Model loaded with 4-bit quantization
‚úÖ Memory footprint: ~4.21 GB

STEP 4: CONFIGURING LORA
‚úÖ Trainable params: 2,621,440 (0.17%)

STEP 5: TOKENIZATION
Tokenizing (this may take a moment)...


Tokenizing:   0%|          | 0/1000 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/50 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/50 [00:00<?, ? examples/s]

‚úÖ Tokenization complete

STEP 6: LOADING METRICS
‚úÖ ROUGE metric loaded

STEP 7: TRAINING CONFIGURATION
‚úÖ Trainer ready
   Effective batch size: 16

üöÄ STARTING TRAINING (Estimated: 8-12 minutes)



Epoch,Training Loss,Validation Loss


Epoch,Training Loss,Validation Loss
1,2.372,2.280224
2,2.3692,2.261337



‚úÖ Training completed successfully!

STEP 8: EVALUATION



üéØ Validation Results:
  Validation Loss: 2.2613
‚úÖ Saved: training_curves.png

STEP 9: SAVING MODEL
‚úÖ Model saved to: ./fine-tuned-phi2-summarization

STEP 10: GENERATING SAMPLE SUMMARIES

Generating 3 example summaries:

Example 1:
Document: Sarah Johnson was one of 21 women heading to Liverpool when their minibus was hit by a lorry on the M62.
Her friend Bethany Jones, 18, was killed while Ms Johnson and several others were badly hurt.
M...
Reference: A woman who was seriously hurt in a fatal hen party motorway crash is now helping other major trauma victims rebuild their lives.
Generated: A woman who survived a bus crash that killed her friend has set up a charity to help other victims of road accidents. Sarah Johnson, 20, was one of 21 women heading to Liverpool when their min

Example 2:
Document: A total of 1,400 tickets have sold out for the opening weekend at Bramall Hall in Stockport, Greater Manchester after renovation work began in 2014.
Stained glass windows and ceil

In [1]:
!pip install datasets==3.6.0

Collecting datasets==3.6.0
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m491.5/491.5 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: datasets
  Attempting uninstall: datasets
    Found existing installation: datasets 4.0.0
    Uninstalling datasets-4.0.0:
      Successfully uninstalled datasets-4.0.0
Successfully installed datasets-3.6.0


In [2]:
# 1. Upgrade bitsandbytes first
!pip install -U bitsandbytes

# 2. Then install other packages
!pip install -q transformers datasets torch accelerate peft rouge-score sentencepiece evaluate

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m84.1/84.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h