# 04 - Training / Fine-Tuning

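This notebook fine-tunes a sequence-to-sequence model (default: `google/mt5-small`) for text simplification. The data-loading, training, evaluation and export code is commented out so the notebook runs without data or a GPU; uncomment those cells to train.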

## Overview
1. Load preprocessed data from `data/processed/`
2. Configure training hyperparameters
3. Fine-tune with HuggingFace Trainer
4. Evaluate on test set (qualitative spot check of sample outputs)
5. Export model for inference


In [None]:
# Setup
import json
from pathlib import Path

# Uncomment for training
# import torch
# from transformers import (
#     AutoTokenizer,
#     AutoModelForSeq2SeqLM,
#     Seq2SeqTrainer,
#     Seq2SeqTrainingArguments,
#     DataCollatorForSeq2Seq,
# )
# from datasets import Dataset

# Project directories, resolved relative to the current working directory:
# its parent is treated as the project root.
PROJECT_ROOT = Path("..").resolve()
DATA_PROCESSED = PROJECT_ROOT.joinpath("data", "processed")
MODELS_DIR = PROJECT_ROOT.joinpath("models")

# Echo the resolved locations so a wrong working directory is easy to spot.
print(f"Data: {DATA_PROCESSED}")
print(f"Models will be saved to: {MODELS_DIR}")


## 1. Configuration


In [None]:
# Hyperparameters and output locations for the fine-tuning run.
CONFIG = dict(
    # --- Model ---
    base_model="google/mt5-small",  # or mt5-base, flan-t5-base
    max_source_length=512,
    max_target_length=256,

    # --- Training ---
    batch_size=8,
    learning_rate=5e-5,
    num_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,

    # --- Outputs ---
    output_dir=str(MODELS_DIR / "klartext-mt5-small"),
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
)

# Echo every setting so the configuration is recorded in the cell output.
print("Training config:")
for setting_name, setting_value in CONFIG.items():
    print(f"  {setting_name}: {setting_value}")


## 2. Load Data


In [None]:
def load_jsonl(path: Path) -> list:
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of aborting the whole load.

    Args:
        path: Location of the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        The parsed objects in file order. An empty file yields ``[]``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the 1-based line number in the file.
    """
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            # json.loads raises on empty input, so skip blank lines explicitly.
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type as before, but say where the bad record
                # is; the position reported by json is relative to this line.
                raise json.JSONDecodeError(
                    f"{exc.msg} [{path}, file line {line_number}]",
                    exc.doc,
                    exc.pos,
                ) from exc
    return records

# Load the train / validation / test splits (uncomment when you have processed data).
# The later cells read the "source" and "target" fields of each record.
# train_data = load_jsonl(DATA_PROCESSED / "train.jsonl")
# val_data = load_jsonl(DATA_PROCESSED / "val.jsonl")
# test_data = load_jsonl(DATA_PROCESSED / "test.jsonl")

# print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

# Placeholder output so the cell runs cleanly while loading is disabled.
print("Data loading ready (uncomment when processed data exists)")


## 3. Preprocessing for Training


In [None]:
# Uncomment for actual training

# tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_model"])

# def preprocess_function(examples):
#     """Tokenize a batch of source/target pairs for seq2seq training.
#
#     Sources get the "simplify: " task prefix and are truncated to
#     CONFIG["max_source_length"]; targets are truncated to
#     CONFIG["max_target_length"] and stored as "labels".
#
#     No padding is applied here on purpose. Padding labels to max_length
#     would leave pad token ids in "labels", and those positions would then
#     count towards the loss. The DataCollatorForSeq2Seq created in the
#     training cell pads each batch dynamically and fills label padding with
#     -100, which the loss ignores.
#     """
#     # Add task prefix
#     inputs = [f"simplify: {text}" for text in examples["source"]]
#     targets = examples["target"]
#
#     model_inputs = tokenizer(
#         inputs,
#         max_length=CONFIG["max_source_length"],
#         truncation=True,
#     )
#
#     labels = tokenizer(
#         targets,
#         max_length=CONFIG["max_target_length"],
#         truncation=True,
#     )
#
#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs

# # Create datasets
# train_dataset = Dataset.from_list(train_data).map(preprocess_function, batched=True)
# val_dataset = Dataset.from_list(val_data).map(preprocess_function, batched=True)

print("Preprocessing functions defined")


## 4. Training


In [None]:
# Uncomment for actual training

# model = AutoModelForSeq2SeqLM.from_pretrained(CONFIG["base_model"])

# # T5-family checkpoints such as mT5 are prone to numeric overflow (NaN loss)
# # under fp16 mixed precision. Use bf16 when the GPU supports it and fall back
# # to full fp32 otherwise.
# use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# training_args = Seq2SeqTrainingArguments(
#     output_dir=CONFIG["output_dir"],
#     # NOTE(review): recent transformers releases name this argument
#     # `eval_strategy` - confirm against the installed version.
#     evaluation_strategy="steps",
#     eval_steps=CONFIG["eval_steps"],
#     learning_rate=CONFIG["learning_rate"],
#     per_device_train_batch_size=CONFIG["batch_size"],
#     per_device_eval_batch_size=CONFIG["batch_size"],
#     weight_decay=CONFIG["weight_decay"],
#     save_total_limit=3,
#     num_train_epochs=CONFIG["num_epochs"],
#     predict_with_generate=True,
#     logging_steps=CONFIG["logging_steps"],
#     save_steps=CONFIG["save_steps"],
#     warmup_steps=CONFIG["warmup_steps"],
#     bf16=use_bf16,  # Mixed precision only where bf16 is available
# )

# # Pads inputs and labels per batch; label padding is set to -100 by default
# data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# trainer = Seq2SeqTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     tokenizer=tokenizer,
#     data_collator=data_collator,
# )

# # Train!
# trainer.train()

# # Save final model together with its tokenizer so it can be reloaded from
# # CONFIG["output_dir"] alone
# trainer.save_model(CONFIG["output_dir"])
# tokenizer.save_pretrained(CONFIG["output_dir"])

print("Training code ready (uncomment to run)")


## 5. Evaluation


In [None]:
# Evaluate on test set (uncomment after training)
# This is a qualitative spot check: it prints generations for a few examples.

# def generate_simplification(text: str, model, tokenizer, max_length=256):
#     """Generate a simplified version of `text` using beam search (4 beams).
#
#     The input gets the same "simplify: " prefix and the same source length
#     limit (CONFIG["max_source_length"]) as in preprocessing, and is moved to
#     the model's device so the function also works when the model is on a GPU.
#
#     Args:
#         text: Source text to simplify.
#         model: Fine-tuned seq2seq model.
#         tokenizer: Tokenizer matching `model`.
#         max_length: Maximum length of the generated sequence, in tokens.
#
#     Returns:
#         The decoded output with special tokens removed.
#     """
#     inputs = tokenizer(
#         f"simplify: {text}",
#         return_tensors="pt",
#         truncation=True,
#         max_length=CONFIG["max_source_length"],
#     ).to(model.device)
#     outputs = model.generate(**inputs, max_length=max_length, num_beams=4)
#     return tokenizer.decode(outputs[0], skip_special_tokens=True)

# # Load trained model
# model = AutoModelForSeq2SeqLM.from_pretrained(CONFIG["output_dir"])
# tokenizer = AutoTokenizer.from_pretrained(CONFIG["output_dir"])

# # Test on a few examples
# for example in test_data[:5]:
#     output = generate_simplification(example["source"], model, tokenizer)
#     print(f"Source: {example['source'][:100]}...")
#     print(f"Target: {example['target'][:100]}...")
#     print(f"Output: {output[:100]}...")
#     print("-" * 50)

print("Evaluation code ready")


## 6. Export for Production

After training, export the model for use in the KlarText API.


In [None]:
# Export options (all disabled; uncomment exactly one after training):

# 1. HuggingFace Hub (recommended for sharing)
# Uses the `model` and `tokenizer` objects from the cells above.
# model.push_to_hub("klartext/klartext-mt5-small")
# tokenizer.push_to_hub("klartext/klartext-mt5-small")

# 2. ONNX export (for faster inference)
# Reads the saved model from CONFIG["output_dir"] and writes the export next
# to it, in a directory with an "-onnx" suffix.
# NOTE(review): `optimum` is not imported anywhere else in this notebook -
# presumably an extra dependency; confirm it is installed.
# from optimum.onnxruntime import ORTModelForSeq2SeqLM
# ort_model = ORTModelForSeq2SeqLM.from_pretrained(CONFIG["output_dir"], export=True)
# ort_model.save_pretrained(CONFIG["output_dir"] + "-onnx")

# 3. Quantization (for smaller model size)
# NOTE(review): this snippet uses `torch`, whose import is only present as a
# commented-out line in the setup cell, and the quantized model is not saved
# anywhere here - add both before relying on this option.
# from transformers import AutoModelForSeq2SeqLM
# model = AutoModelForSeq2SeqLM.from_pretrained(CONFIG["output_dir"])
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Placeholder output so the cell runs cleanly while export is disabled.
print("Export options documented. Uncomment preferred method after training.")
