# 🚨 EMERGENCY FIX - MT5 Training Loss = 0

## Fixes Applied:
1. ✅ Disable FP16 (numerical instability)
2. ✅ Explicit model unfreezing
3. ✅ Force labels to not be all -100
4. ✅ Test EVERY step before training
5. ✅ Higher learning rate

---

## ⚠️ RUN THIS FIRST: Diagnostic Cell

In [None]:
# RUN THIS CELL FIRST to verify that everything is OK.
!python diagnostic_script.py

: 

## Step 1: Install & Import

In [None]:
# Install the third-party packages used by this notebook (-q keeps pip quiet).
!pip install -q transformers datasets evaluate rouge-score sentencepiece

In [None]:
import torch
import gc
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate

# Free GPU memory left over from earlier runs of this notebook.
# The garbage collector runs first so that tensors kept alive only by
# reference cycles are destroyed *before* the CUDA caching allocator is asked
# to release its unused blocks; in the opposite order those blocks would
# simply stay cached.
gc.collect()
torch.cuda.empty_cache()

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 2: Load Data

In [None]:
# Read the three CSV splits from disk, in train / validation / test order.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ("data/train.csv", "data/validation.csv", "data/test.csv")
)

# Keep only the document/summary pair and drop rows where either is missing.
wanted_columns = ['document', 'summary']
train_df = train_df[wanted_columns].dropna()
val_df = val_df[wanted_columns].dropna()
test_df = test_df[wanted_columns].dropna()

# Wrap every frame as a Dataset; the pandas index is not carried over.
frames_by_split = {'train': train_df, 'validation': val_df, 'test': test_df}
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame, preserve_index=False)
    for split_name, frame in frames_by_split.items()
})

print(f"‚úÖ Data loaded: {len(train_df)} train, {len(val_df)} val, {len(test_df)} test")

## Step 3: Load Model - WITH EXPLICIT CHECKS

In [None]:
# Checkpoint that is fine-tuned by this notebook.
MODEL_NAME = "google/mt5-small"
print(f"\nLoading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

print(f"‚úÖ Model loaded: {model.__class__.__name__}")
print(f"   Parameters: {model.num_parameters():,}")
print(f"   Vocab size: {tokenizer.vocab_size:,}")

# Mark every weight as trainable so that a frozen parameter cannot be the
# reason for a zero training loss.
for param in model.parameters():
    param.requires_grad = True
print("‚úÖ All parameters unfrozen")

# Move to device
model = model.to(device)
print(f"‚úÖ Model on {device}")

# Sanity check 1: a single forward pass must yield a finite, non-zero loss.
print("\nüîç Testing forward pass...")
test_input = tokenizer("t√≥m t·∫Øt: Test sentence", return_tensors="pt").to(device)
# Tokenize the target through `text_target`, the same way the training labels
# are produced, so this check exercises the real label path.
test_label = tokenizer(text_target="Test output", return_tensors="pt").to(device)

with torch.no_grad():
    # Unpack the whole encoding so the attention mask reaches the model too,
    # exactly as it will during training.
    test_output = model(**test_input, labels=test_label['input_ids'])
test_loss = test_output.loss.item()

print(f"Test loss: {test_loss:.4f}")

if test_loss == 0.0:
    print("‚ùå‚ùå‚ùå CRITICAL ERROR: Test loss is 0!")
    print("Model is broken. DO NOT CONTINUE.")
    raise RuntimeError("Model test loss is 0")
elif torch.isnan(test_output.loss):
    # Check the loss tensor directly instead of re-wrapping the Python float.
    print("‚ùå‚ùå‚ùå CRITICAL ERROR: Test loss is NaN!")
    raise RuntimeError("Model test loss is NaN")
else:
    print(f"‚úÖ Test loss is normal: {test_loss:.4f}")

# Sanity check 2: generation must run; the outcome is only reported, never
# fatal, because the checkpoint has not been fine-tuned yet.
print("\nüîç Testing generation...")
with torch.no_grad():
    test_gen = model.generate(**test_input, max_length=20)
    test_gen_text = tokenizer.decode(test_gen[0], skip_special_tokens=True)

print(f"Generated: '{test_gen_text}'")
if len(test_gen_text.strip()) == 0:
    print("‚ùå WARNING: Generated empty text")
elif '<' in test_gen_text and '>' in test_gen_text:
    # Angle brackets in the decoded text are taken as a sign of sentinel tokens.
    print("‚ùå WARNING: Generated sentinel tokens")
else:
    print("‚úÖ Generation works")

## Step 4: Tokenize - WITH VERIFICATION

In [None]:
def preprocess_function(examples, max_input_length=512, max_target_length=128):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        examples: Batch mapping with a "document" column and a "summary"
            column (this function is used with a batched ``dataset.map``).
        max_input_length: Token limit for the prefixed document; longer
            inputs are truncated. Defaults to 512.
        max_target_length: Token limit for the summary; longer targets are
            truncated. Defaults to 128.

    Returns:
        The tokenizer output for the prefixed documents, with an extra
        "labels" entry holding the token ids of the summaries. Nothing is
        padded here (``padding=False``).
    """
    # Every document gets the task prefix prepended before tokenization.
    inputs = ["t√≥m t·∫Øt: " + doc for doc in examples["document"]]

    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True,
        padding=False
    )

    # Summaries go through `text_target` so they are tokenized as targets.
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
        padding=False
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("Tokenizing dataset...")
# Tokenize every split in batches; the raw text columns are dropped so only
# the tokenizer output remains.
tokenized_datasets = dataset.map(
    preprocess_function,
    remove_columns=dataset["train"].column_names,
    batched=True,
)

# Inspect the first training example to make sure it carries real label ids.
sample = tokenized_datasets["train"][0]
sample_labels = sample['labels']
print(f"\n‚úÖ Sample tokenized:")
print(f"   Input length: {len(sample['input_ids'])}")
print(f"   Label length: {len(sample_labels)}")
print(f"   Labels (first 20): {sample_labels[:20]}")

# -100 marks a label position that carries no real token.
real_token_ids = [token_id for token_id in sample_labels if token_id != -100]
if not real_token_ids:
    print("‚ùå‚ùå‚ùå CRITICAL: ALL LABELS ARE -100!")
    raise RuntimeError("All labels are -100")

print(f"‚úÖ Labels OK: {len(real_token_ids)}/{len(sample_labels)} valid tokens")
# Decode only the real tokens found among the first 30 label positions.
preview_ids = [token_id for token_id in sample_labels[:30] if token_id != -100]
print(f"   Decoded: {tokenizer.decode(preview_ids)}")

## Step 5: Metrics

In [None]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Decode predictions and references and score them with ROUGE.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` may be
            token ids (2-D), logits (3-D, reduced with argmax) or a tuple
            whose first element is one of those. ``labels`` holds the
            reference token ids, with -100 on padded positions.

    Returns:
        Dict with the keys "rouge1", "rouge2", "rougeL" and "rougeLsum".
        All four are 0.0 when every prediction decodes to an empty string or
        when the ROUGE computation raises.
    """
    predictions, labels = eval_pred

    # Some models hand back a tuple; the tokens/logits are its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    # Handle 3D predictions (logits)
    if predictions.ndim == 3:
        predictions = np.argmax(predictions, axis=-1)

    # Generated batches of different lengths are padded with -100 when they
    # are gathered; the tokenizer cannot decode -100, so swap in the pad id
    # for predictions just like for labels.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Debug print (skipped on an empty evaluation set, where index 0 is absent)
    if decoded_preds and decoded_labels:
        print(f"\n[EVAL] Sample prediction: {decoded_preds[0][:100]}")
        print(f"[EVAL] Sample reference: {decoded_labels[0][:100]}")

    # Collapse all runs of whitespace to single spaces.
    decoded_preds = [" ".join(pred.strip().split()) for pred in decoded_preds]
    decoded_labels = [" ".join(label.strip().split()) for label in decoded_labels]

    # Nothing to score when every prediction is empty.
    if all(len(p.strip()) == 0 for p in decoded_preds):
        print("‚ö†Ô∏è  All predictions empty!")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "rougeLsum": 0.0}

    # NOTE(review): the metric runs with its default tokenizer; for
    # non-English text it may be worth passing a whitespace tokenizer -
    # confirm against the metric's documentation.
    try:
        result = rouge.compute(
            predictions=decoded_preds,
            references=decoded_labels,
            use_stemmer=False
        )
    except Exception as e:
        # Deliberate best effort: a scoring failure is reported and turned
        # into zero scores instead of aborting the run.
        print(f"‚ö†Ô∏è  ROUGE error: {e}")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "rougeLsum": 0.0}

    return {
        "rouge1": result["rouge1"],
        "rouge2": result["rouge2"],
        "rougeL": result["rougeL"],
        "rougeLsum": result["rougeLsum"],
    }

print("‚úÖ Metrics defined")

## Step 6: Training Setup - EMERGENCY FIXES

In [None]:
# Collator that pads each batch and fills padded label positions with -100,
# the value the checks below treat as "not a real token".
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    label_pad_token_id=-100,
)

# Run the collator on the first two training examples before trusting it.
print("\nüîç Testing data collator...")
test_batch = [tokenized_datasets["train"][idx] for idx in (0, 1)]
collated = data_collator(test_batch)

print(f"Collated batch:")
print(f"  Input IDs shape: {collated['input_ids'].shape}")
print(f"  Labels shape: {collated['labels'].shape}")

# Count the positions of the first example that still hold a real token id.
labels_check = collated['labels'][0]
total_labels = len(labels_check)
valid_labels = (labels_check != -100).sum().item()
valid_pct = valid_labels / total_labels * 100
print(f"  Valid labels: {valid_labels}/{total_labels} ({valid_pct:.1f}%)")

if valid_labels == 0:
    print("‚ùå‚ùå‚ùå CRITICAL: Data collator produces all -100 labels!")
    raise RuntimeError("Data collator broken")

print("‚úÖ Data collator OK")

# Hyper-parameters for the "emergency" run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5-emergency-fix",
    report_to="none",

    # Optimisation: raised learning rate with a short warmup.
    learning_rate=1e-4,
    warmup_steps=200,
    weight_decay=0.01,
    num_train_epochs=3,

    # Batching: 4 samples per device, gradients accumulated over 2 steps.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,

    # Precision / memory: FP16 stays off (listed above as a loss=0 suspect).
    fp16=False,
    gradient_checkpointing=True,

    # Evaluation every 500 steps, scoring generated text.
    eval_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,
    generation_max_length=128,
    generation_num_beams=4,

    # Logging: first step, then every 10 steps.
    logging_first_step=True,
    logging_steps=10,

    # Checkpoints every 500 steps, keeping at most two; the checkpoint with
    # the best ROUGE-1 is reloaded when training ends.
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
)

# Wire model, data, collator and metric function into the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

print("\n‚úÖ Trainer created")

# Summarise the settings that differ from a regular run.
emergency_summary = (
    f"\n‚ö†Ô∏è  EMERGENCY MODE ENABLED:",
    f"  - FP16 disabled (avoid numerical issues)",
    f"  - Learning rate: {training_args.learning_rate} (10x normal)",
    f"  - Short warmup: {training_args.warmup_steps} steps",
    f"  - Frequent logging: every {training_args.logging_steps} steps",
)
for summary_line in emergency_summary:
    print(summary_line)

## Step 7: FINAL CHECK Before Training

In [None]:
print("\n" + "="*60)
print("üîç FINAL PRE-TRAINING CHECK")
print("="*60)

# Get a real batch from the dataloader
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=4,
    collate_fn=data_collator,
)

# Take the first collated batch and move every tensor to the training device.
first_batch = next(iter(train_dataloader))
first_batch = {k: v.to(device) for k, v in first_batch.items()}

print("\n1. Batch shape check:")
print(f"   Input IDs: {first_batch['input_ids'].shape}")
print(f"   Labels: {first_batch['labels'].shape}")
print(f"   Attention mask: {first_batch['attention_mask'].shape}")

# Report, for up to two samples, how many label positions hold a real token.
print("\n2. Labels validity check:")
for i in range(min(2, first_batch['labels'].shape[0])):
    labels = first_batch['labels'][i]
    valid = (labels != -100).sum().item()
    print(f"   Sample {i}: {valid}/{len(labels)} valid tokens ({valid/len(labels)*100:.1f}%)")
    if valid == 0:
        print(f"      ‚ùå ALL -100!")

print("\n3. Forward pass with real batch:")
model.train()  # Ensure training mode
outputs = model(**first_batch)
batch_loss = outputs.loss.item()

print(f"   Loss: {batch_loss:.4f}")
print(f"   Loss requires_grad: {outputs.loss.requires_grad}")

if batch_loss == 0.0:
    print("\n‚ùå‚ùå‚ùå CRITICAL ERROR: Loss is 0!")
    print("DO NOT START TRAINING!")
    print("\nPossible issues:")
    print("- All labels are -100")
    print("- Model parameters are frozen")
    print("- Incorrect loss computation")
    raise RuntimeError("Training loss is 0")
elif torch.isnan(outputs.loss):
    print("\n‚ùå‚ùå‚ùå CRITICAL ERROR: Loss is NaN!")
    raise RuntimeError("Training loss is NaN")
else:
    print(f"   ‚úÖ Loss is normal!")

print("\n4. Backward pass test:")
outputs.loss.backward()
print("   ‚úÖ Backward pass successful")

# Global L2 norm over the gradients of all parameters that received one.
grad_norm = sum(
    p.grad.norm().item() ** 2
    for p in model.parameters()
    if p.grad is not None
) ** 0.5

# Clear the probe gradients right away so they never leak into real training,
# even when one of the checks below aborts the cell.
model.zero_grad()

print(f"   Gradient norm: {grad_norm:.4f}")
if grad_norm == 0:
    # A non-zero loss without any gradient means nothing would be learned, so
    # this is fatal like the loss checks above instead of a mere warning
    # followed by "ALL CHECKS PASSED".
    print("   ‚ùå No gradients!")
    raise RuntimeError("Gradient norm is 0")
elif not torch.isfinite(torch.tensor(grad_norm)):
    raise RuntimeError("Gradient norm is not finite")
else:
    print("   ‚úÖ Gradients OK")

print("\n" + "="*60)
print("‚úÖ ALL CHECKS PASSED - READY TO TRAIN")
print("="*60)
print("\n‚ö†Ô∏è  WATCH FOR:")
print("  - First step loss should be 2-8")
print("  - Loss should NOT be 0 or NaN")
print("  - Loss should decrease over time")
print("  - ROUGE should be > 0 after first eval")
print("\n" + "="*60)

## Step 8: TRAIN 🚀

In [None]:
# Announce the run, train, then print the closing banner.
print(
    "\nüöÄ Starting training...",
    "Expected time: ~1-1.5 hours",
    "="*60,
    sep="\n",
)

trainer.train()

print("\n" + "="*60, "‚úÖ Training complete!", sep="\n")

## Step 9: Evaluate

In [None]:
# Score the trained model on the held-out test split.
results = trainer.evaluate(eval_dataset=tokenized_datasets["test"])

banner = "="*50
print("\n" + banner)
print("TEST RESULTS")
print(banner)

# Only the entries whose name mentions "rouge" are reported.
rouge_scores = {name: score for name, score in results.items() if 'rouge' in name}
for name, score in rouge_scores.items():
    print(f"{name.upper()}: {score:.4f}")

## Step 10: Test Generation

In [None]:
def generate_summary(text, max_length=128, num_beams=4):
    """Generate a summary for one document with the notebook's model.

    Args:
        text: Document to summarise; the task prefix is prepended and the
            input is truncated to 512 tokens.
        max_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The first generated sequence decoded to text, special tokens removed.
    """
    # Put the model in eval mode: when this is called straight after training
    # it would otherwise still be in train mode with dropout active.
    model.eval()
    inputs = tokenizer(
        "t√≥m t·∫Øt: " + text,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    # No gradients are needed for inference (same pattern as the generation
    # test in Step 3).
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Show up to three test examples; a test split with fewer rows no longer
# raises an IndexError.
num_examples = min(3, len(dataset['test']))
for i in range(num_examples):
    # Fetch the row once and read both fields from it.
    example = dataset['test'][i]
    test_text = example['document']
    ground_truth = example['summary']

    print(f"\n--- Example {i+1} ---")
    print(f"Document: {test_text[:200]}...")
    print(f"\nGenerated: {generate_summary(test_text)}")
    print(f"\nGround truth: {ground_truth}")
    print("="*60)