In [13]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig
from datasets import Dataset

# 1. Fetch the tokenizer and the base instruction-tuned seq2seq checkpoint
#    that will serve as the DPO policy.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")

# 2. Toy preference data. Each row is (prompt, preferred answer, dispreferred
#    answer); they are expanded into the prompt/chosen/rejected columns DPO reads.
_preference_rows = [
    ("Translate 'Hello' to French:", "Bonjour", "Hola"),
    ("What is 2+2?", "4", "5"),
    ("Translate 'Thank you' to Spanish:", "Gracias", "Merci"),
    ("What is the capital of France?", "Paris", "Florida"),
    ("Translate 'Good morning' to German:", "Guten Morgen", "Buenos Dias"),
    ("What is 3x3?", "9", "6"),
    ("Translate 'I love you' to Italian:", "Ti amo", "Te amo"),
    ("What is the largest planet?", "Jupiter", "Mars"),
    ("Translate 'Goodbye' to Japanese:", "Sayonara", "Adios"),
    ("What is 5-2?", "3", "4"),
]
samples = [
    {"prompt": prompt, "chosen": chosen, "rejected": rejected}
    for prompt, chosen, rejected in _preference_rows
]
# The first eight rows train the policy; the remaining two are held out for eval.
train_dataset = Dataset.from_list(samples[:8])
eval_dataset = Dataset.from_list(samples[8:])

# 3. Verify dataset format
def verify_dataset(example):
    """Check that one preference example carries every field DPO needs.

    Meant to be passed to ``Dataset.map``: the example is returned unchanged,
    so mapping with this function leaves the dataset contents as they were.

    Args:
        example: A single dataset row (mapping of column name to value).

    Returns:
        The same ``example`` object, unmodified.

    Raises:
        ValueError: If any of ``prompt``, ``chosen`` or ``rejected`` is absent.
    """
    # An explicit raise (rather than ``assert``) keeps this check active when
    # Python runs with -O, and naming the keys makes the failure actionable.
    missing = [key for key in ("prompt", "chosen", "rejected") if key not in example]
    if missing:
        raise ValueError(f"Missing keys: {missing}")
    return example

# Run the schema check over both splits; map() hands back new Dataset objects.
train_dataset, eval_dataset = (
    split.map(verify_dataset) for split in (train_dataset, eval_dataset)
)

# 4. DPO hyperparameters for a tiny demonstration run.
args = DPOConfig(
    # Where the trainer writes its run artifacts; checkpoints are disabled.
    output_dir="./dpo_flan_t5",
    save_strategy="no",
    report_to=[],
    # Batching and schedule.
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=3e-5,
    # Log every step and evaluate on the held-out split every 2 steps.
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=2,
    # DPO-specific settings: KL strength and token budgets.
    beta=0.1,
    max_length=128,
    max_prompt_length=64,
    # Keep the raw prompt/chosen/rejected columns available to the trainer.
    remove_unused_columns=False,
)

# 5. DPO trainer around the policy `model`. Because ref_model is None,
#    the reference model is created automatically by TRL.
trainer = DPOTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    ref_model=None,
)

# 6. Train
trainer.train()

# 7. Save model
# Trainer.save_model writes the weights together with `processing_class`
# (the tokenizer passed to DPOTrainer) and only on the process that is
# allowed to save, so this stays correct under multi-process launches.
trainer.save_model("./dpo_flan_t5_model")

# 8. Inference with trained model
def generate(prompt, max_new_tokens=30, num_beams=4):
    """Decode a completion for ``prompt`` with the trained policy via beam search.

    Args:
        prompt: Input text fed to the seq2seq model.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width used for decoding.

    Returns:
        The first decoded sequence with special tokens removed and surrounding
        whitespace stripped.
    """
    model = trainer.model
    # After trainer.train() the module is left in whatever mode the last
    # train/eval step set, so force eval mode (dropout off) while decoding and
    # restore the previous mode afterwards so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=True,
            )
    finally:
        model.train(was_training)
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

# Smoke test on a prompt that does not appear in the preference data.
test_prompt = "Translate 'I am going to school' to French:"
response = generate(test_prompt)
print("\nGenerated Response:\n", response)


Map:   0%|          | 0/8 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Extracting prompt in train dataset:   0%|          | 0/8 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/8 [00:00<?, ? examples/s]

Extracting prompt in eval dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/chosen,Logps/rejected,Logits/chosen,Logits/rejected
2,0.7467,0.698322,-0.014705,-0.004424,0.5,-0.010281,-16.59931,-13.265632,-14.172015,-12.658319
4,0.7209,0.705713,-0.029208,-0.004395,0.5,-0.024812,-16.744335,-13.265342,-14.281544,-12.759686
6,0.629,0.709102,-0.046065,-0.014675,0.5,-0.03139,-16.912905,-13.368136,-14.37464,-12.838968
8,0.6446,0.713161,-0.064511,-0.025329,0.5,-0.039182,-17.097366,-13.474674,-14.432551,-12.881265
10,0.4966,0.711789,-0.069058,-0.032512,0.5,-0.036546,-17.142838,-13.546513,-14.501986,-12.926527
12,0.6764,0.714571,-0.084044,-0.04226,0.5,-0.041784,-17.292694,-13.643986,-14.56069,-12.992306
14,0.523,0.717714,-0.085411,-0.037593,0.5,-0.047818,-17.306368,-13.597322,-14.562527,-12.985661
16,0.6575,0.718935,-0.09073,-0.040626,0.5,-0.050104,-17.35956,-13.627646,-14.597715,-13.019164
18,0.6513,0.717553,-0.091259,-0.043726,0.5,-0.047533,-17.364851,-13.658651,-14.610106,-13.02462
20,0.7553,0.720287,-0.096043,-0.043275,0.5,-0.052767,-17.412683,-13.65414,-14.59933,-13.018983



Generated Response:
 Je vais à l'école
