In [None]:
!pip install transformers datasets sentencepiece --quiet

from transformers import MarianMTModel, MarianTokenizer
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import torch

# Load the pretrained Marian checkpoint named below, together with its tokenizer
model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)

# Use CUDA when torch reports it as available, otherwise stay on the CPU,
# and place the model on that device right after loading it
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MarianMTModel.from_pretrained(model_name).to(device)

# Small data set (only 3 examples), written as (English, German) pairs and
# expanded into the {"en": ..., "de": ...} records the rest of the script uses
data = [
    {"en": english, "de": german}
    for english, german in (
        ("The patient has a fever.", "Der Patient hat Fieber."),
        ("Take this medicine twice a day.", "Nehmen Sie dieses Medikament zweimal täglich."),
        ("You need to rest.", "Sie müssen sich ausruhen."),
    )
]
dataset = Dataset.from_list(data)

# Tokenizer
def tokenize(batch):
    """Encode a batch of English/German sentence pairs for seq2seq training.

    Args:
        batch: Mapping with parallel lists of sentences under the keys
            ``'en'`` (source text) and ``'de'`` (target text), as supplied
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The source-side encoding (padded/truncated to 64 tokens) with an
        extra ``"labels"`` entry holding the target token ids, where every
        padding position is replaced by ``-100``.
    """
    inputs = tokenizer(batch['en'], padding="max_length", truncation=True, max_length=64)
    # text_target= makes the tokenizer treat the German text as the target
    # side instead of running it through the source-side processing.
    targets = tokenizer(text_target=batch['de'], padding="max_length", truncation=True, max_length=64)
    # -100 is the index the cross-entropy loss ignores; without this mask the
    # loss would also be computed on the padding that fills each label row.
    inputs["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(targets["input_ids"], targets["attention_mask"])
    ]
    return inputs

tokenized_dataset = dataset.map(tokenize, batched=True)

# Training options, collected as a mapping and unpacked into the arguments object
training_args = Seq2SeqTrainingArguments(
    **{
        # Where results and logs are written
        "output_dir": "./results",
        "logging_dir": './logs',
        "logging_steps": 1,
        "report_to": "none",
        "save_total_limit": 1,
        # Optimisation settings
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "per_device_train_batch_size": 2,
        "num_train_epochs": 1,
        # Mixed precision is requested only when CUDA is available
        "fp16": torch.cuda.is_available(),
        "predict_with_generate": True,
    }
)

# Trainer: ties together the model, the training options and the tokenized data
# NOTE(review): recent transformers releases appear to deprecate `tokenizer=`
# in favour of `processing_class=` — confirm against the installed version.
trainer = Seq2SeqTrainer(
    **{
        "model": model,
        "args": training_args,
        "train_dataset": tokenized_dataset,
        "tokenizer": tokenizer,
    }
)

# Translation helper used for the before/after samples
def translate(text):
    """Translate ``text`` with the current state of the global ``model``.

    Args:
        text: Sentence (or list of sentences) to encode with the global
            ``tokenizer``; only the first generated sequence is returned.

    Returns:
        The decoded first output sequence with special tokens stripped.
    """
    # Follow the model's actual device rather than the global `device`, so
    # the inputs always end up where the weights are.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
    # Training leaves the model in train mode (dropout active); generate in
    # eval mode and then restore whatever mode the model was in before.
    was_training = model.training
    model.eval()
    try:
        translated = model.generate(**inputs)
    finally:
        model.train(was_training)
    return tokenizer.decode(translated[0], skip_special_tokens=True)

# Show the model's translation of the probe sentence before any training
print("Before fine-tuning:", "English: I have a math exam tomorrow.", sep="\n")
print("German:", translate("I have a math exam tomorrow."))

# Training
trainer.train()

# Translate the same probe sentence again once training has finished
print("\nAfter fine-tuning:", "English: I have a math exam tomorrow.", sep="\n")
print("German:", translate("I have a math exam tomorrow."))
