In [1]:
# train_en_to_es_sft.py

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
import random
import numpy as np

# Fix the RNG seed up front so repeated runs are comparable
set_seed(42)

# ==========================
# Config
# ==========================
# --- Model, data and output locations ---
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"          # base checkpoint to fine-tune
DATASET_PATH = "exp-data/en-es-train-val.parquet"  # parquet file with EN/ES pairs and a 'split' column
OUTPUT_DIR = "./smollm2-135m-en-es-lora"           # where the trainer writes its outputs

# --- Optimisation schedule ---
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
BATCH_SIZE = 32                   # per-device batch size (train and eval)
GRADIENT_ACCUMULATION_STEPS = 4   # 32 * 4 = 128 sequences per optimizer step, per device
MAX_SEQ_LENGTH = 512              # token cap applied when tokenizing

# --- LoRA hyper-parameters ---
LORA_R = 16          # adapter rank
LORA_ALPHA = 32      # LoRA scaling numerator (alpha / r = 2)
LORA_DROPOUT = 0.05  # dropout applied inside the adapter layers

# ==========================
# Load tokenizer and model (we'll use LoRA so base model weights stay untouched)
# ==========================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # The checkpoint defines no pad token: reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter description; SFTTrainer wraps the model with it via `peft_config`.
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    # Every projection layer listed here receives a low-rank adapter.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

# Base model in bfloat16.
# `dtype` replaces `torch_dtype`, which the installed transformers version
# reports as deprecated in this notebook's own output.
# NOTE(review): with device_map="auto" the captured output says the model ended
# up on multiple devices; for a model this small a single device is probably
# what is wanted - confirm before a long run.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# ==========================
# Load and prepare dataset
# ==========================
def _is_train_row(row):
    """True for rows whose 'split' column marks them as training data."""
    return row["split"] == "train"


def _is_val_row(row):
    """True for rows whose 'split' column marks them as validation data."""
    return row["split"] == "val"


# The parquet loader exposes the whole file under the "train" key.
parquet_splits = load_dataset("parquet", data_files=DATASET_PATH)
dataset = parquet_splits["train"]

# Keep a seeded 50% subsample of all rows (train and val rows alike).
halved = dataset.train_test_split(test_size=0.5, seed=42)
dataset = halved["train"]

# Partition the subsample according to the dataset's own 'split' column.
train_dataset = dataset.filter(_is_train_row)
val_dataset = dataset.filter(_is_val_row)

# Prompt layout; the same template is reused to build prompts at generation time.
INSTRUCTION = "English: {en} Spanish:"


def formatting_prompts_func(example):
    """Build the full training string for a single example.

    Args:
        example: Mapping with an "EN" (source) and an "ES" (target) string.

    Returns:
        ``{"text": "<prompt> <target><eos>"}`` where the prompt comes from
        ``INSTRUCTION`` and ``<eos>`` is the tokenizer's EOS token, appended so
        the model sees where a translation ends.
    """
    prompt = INSTRUCTION.format(en=example["EN"])
    completion = example["ES"] + tokenizer.eos_token
    return {"text": prompt + " " + completion}


# Add the "text" column to both splits (train first, then val).
train_dataset, val_dataset = (
    split.map(formatting_prompts_func) for split in (train_dataset, val_dataset)
)

# Tokenize
def tokenize_function(examples):
    """Tokenize a batch of formatted examples.

    Args:
        examples: Batch mapping with a "text" entry holding the strings to encode.

    Returns:
        The tokenizer output for the batch, truncated to ``MAX_SEQ_LENGTH``
        tokens and left unpadded (padding is expected to happen when batches
        are collated).
    """
    encode_options = dict(
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
    )
    return tokenizer(examples["text"], **encode_options)


# Grab three raw validation rows *before* tokenization strips the EN/ES columns;
# the epoch-end callback prompts with these.
example_samples = val_dataset.select(range(3))
print(example_samples)

# Replace every existing column with the tokenizer output (train first, then val).
train_dataset, val_dataset = (
    split.map(tokenize_function, batched=True, remove_columns=split.column_names)
    for split in (train_dataset, val_dataset)
)

# ==========================
# Custom callback to print 3 samples every epoch
# ==========================
from transformers import TrainerCallback
from peft import PeftModel

class TranslationEvalCallback(TrainerCallback):
    """Trainer callback that prints sample EN->ES translations after each epoch.

    At the end of every epoch the model under training generates a Spanish
    translation for a few raw examples, and each result is printed next to its
    reference. The untouched base model's translation can be printed as well.

    Args:
        tokenizer: Tokenizer used to encode prompts and decode generations.
        val_dataset: Validation dataset. A small leading slice is stored on
            ``self.val_dataset`` but is not used for prompting: the dataset this
            notebook passes in is already tokenized and no longer carries the
            "EN"/"ES" columns.
        num_samples: Maximum number of examples printed per epoch.
        samples: Optional iterable of raw examples (mappings with "EN" and "ES"
            keys). When omitted, the notebook-level ``example_samples`` is
            looked up at epoch end, so it must exist by then.
        show_base: When True, a separate copy of ``MODEL_NAME`` is loaded and
            its translation is printed for comparison. Off by default so that
            no second model is loaded and no generation is spent on output
            that is never displayed.
    """

    def __init__(self, tokenizer, val_dataset, num_samples=3, samples=None, show_base=False):
        self.tokenizer = tokenizer
        self.val_dataset = val_dataset.select(range(min(num_samples * 10, len(val_dataset))))  # small pool
        self.num_samples = num_samples
        self.samples = samples
        self.show_base = show_base
        self.base_model = None
        if show_base:
            # `dtype` replaces the deprecated `torch_dtype` keyword.
            self.base_model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            self.base_model.eval()

    def generate_translation(self, model, prompt):
        """Sample a continuation of ``prompt`` from ``model`` and return it as text.

        Args:
            model: Causal LM exposing ``generate`` and ``device``.
            prompt: Prompt string, normally built from ``INSTRUCTION``.

        Returns:
            The decoded newly generated tokens only, with special tokens
            removed and surrounding whitespace stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # A causal LM returns prompt + continuation. Slicing by prompt length
        # isolates the translation even when the source sentence itself
        # contains the text "Spanish:", which splitting on that marker cannot.
        generated_ids = output[0][prompt_length:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def on_epoch_end(self, args, state, control, **kwargs):
        """Print source, reference and generated translation for each sample."""
        model = kwargs["model"]
        # Prefer samples handed to the constructor; otherwise fall back to the
        # notebook-level `example_samples`.
        samples = self.samples if self.samples is not None else example_samples

        was_training = model.training
        model.eval()
        try:
            print("\n" + "="*80)
            print(f"END OF EPOCH {state.epoch:.1f} - SAMPLE TRANSLATIONS")
            print("="*80)

            for i, sample in enumerate(samples, 1):
                if i > self.num_samples:
                    break
                prompt = INSTRUCTION.format(en=sample["EN"])

                current_pred = self.generate_translation(model, prompt)
                reference = sample["ES"]

                print(f"\nSample {i}:")
                print(f"EN → {sample['EN']}")
                print(f"REF → {reference}")
                if self.base_model is not None:
                    base_pred = self.generate_translation(self.base_model, prompt)
                    print(f"BASE (SmolLM2-135M) → {base_pred}")
                print(f"CURRENT (Fine-tuned) → {current_pred}")
                print("-"*80)
        finally:
            # Restore the previous mode even if generation raised.
            model.train(was_training)

# ==========================
# Training arguments
# ==========================
# Keyword arguments grouped by concern, then handed to TrainingArguments in one go.
training_kwargs = dict(
    # --- output, checkpointing and logging ---
    output_dir=OUTPUT_DIR,
    save_strategy="epoch",
    eval_strategy="no",
    logging_steps=10,
    report_to="none",  # Change to "wandb" if you use it
    # --- batch geometry ---
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # --- optimisation schedule ---
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    # --- numeric precision ---
    bf16=True,
    fp16=False,
    # --- data handling ---
    remove_unused_columns=False,
)
training_args = TrainingArguments(**training_kwargs)

# ==========================
# SFTTrainer
# ==========================
# Callback that prints sample translations at the end of each epoch.
translation_callback = TranslationEvalCallback(tokenizer, val_dataset, num_samples=3)

# The datasets are already tokenized, so tokenizer, max_seq_length,
# dataset_text_field and packing are deliberately left at the trainer's defaults.
# NOTE(review): no trainer.train() / trainer.save_model() call is visible in this
# notebook; the adapter-merge cell relies on an adapter already existing on disk.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    callbacks=[translation_callback],
)


  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
Map: 100%|██████████| 50061/50061 [00:04<00:00, 10425.96 examples/s]
Map: 100%|██████████| 50061/50061 [00:04<00:00, 10425.96 examples/s]
Map: 100%|██████████| 4939/4939 [00:00<00:00, 10461.34 examples/s]



Dataset({
    features: ['dataset', 'split', 'EN', 'ES', 'length', 'text'],
    num_rows: 3
})


Map: 100%|██████████| 50061/50061 [00:03<00:00, 15938.90 examples/s]
Map: 100%|██████████| 50061/50061 [00:03<00:00, 15938.90 examples/s]
Map: 100%|██████████| 4939/4939 [00:00<00:00, 16213.68 examples/s]

Truncating train dataset: 100%|██████████| 50061/50061 [00:00<00:00, 1442386.26 examples/s]
Truncating train dataset: 100%|██████████| 50061/50061 [00:00<00:00, 1442386.26 examples/s]
Truncating eval dataset: 100%|██████████| 4939/4939 [00:00<00:00, 976601.33 examples/s]
The model is already on multiple devices. Skipping the move to device specified in `args`.

The model is already on multiple devices. Skipping the move to device specified in `args`.


In [2]:
# Merge the trained LoRA adapter into a fresh copy of the base model and save the
# result as a standalone checkpoint.
# Relies on names from the training cell: MODEL_NAME, OUTPUT_DIR, tokenizer,
# torch, AutoModelForCausalLM and PeftModel.
from pathlib import Path


def _find_adapter_dir(output_dir):
    """Locate the directory that holds the LoRA adapter to merge.

    Looks for ``adapter_config.json`` directly in ``output_dir`` first. If it is
    not there, falls back to the ``checkpoint-<step>`` sub-directory with the
    highest step that contains one; with ``save_strategy="epoch"`` the trainer
    writes adapters into such sub-directories.

    Args:
        output_dir: Training output directory.

    Returns:
        Path (as ``str``) of the directory containing the adapter.

    Raises:
        FileNotFoundError: If no adapter can be found under ``output_dir``.
    """
    root = Path(output_dir)
    if (root / "adapter_config.json").is_file():
        return str(root)

    checkpoints = [
        path
        for path in root.glob("checkpoint-*")
        if path.name.rsplit("-", 1)[-1].isdigit() and (path / "adapter_config.json").is_file()
    ]
    if not checkpoints:
        raise FileNotFoundError(
            f"No LoRA adapter (adapter_config.json) found in {output_dir!r} or its "
            "checkpoint-* sub-directories; run training and save the adapter first."
        )
    # Highest global step == most recent checkpoint.
    return str(max(checkpoints, key=lambda path: int(path.name.rsplit("-", 1)[-1])))


# Resolve the adapter first so a missing adapter fails before any model is loaded.
adapter_dir = _find_adapter_dir(OUTPUT_DIR)

# Reload base model (on CPU; `dtype` replaces the deprecated `torch_dtype`)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)

# Load LoRA adapter on top
lora_model = PeftModel.from_pretrained(base_model, adapter_dir)

# Merge and unload adapters
merged_model = lora_model.merge_and_unload()

# Save merged full model
MERGED_OUTPUT_DIR = OUTPUT_DIR + "-merged"
merged_model.save_pretrained(MERGED_OUTPUT_DIR)
tokenizer.save_pretrained(MERGED_OUTPUT_DIR)

print(f"Merged full model saved to {MERGED_OUTPUT_DIR}")

Merged full model saved to ./smollm2-135m-en-es-lora-merged
