In [1]:
import os

import torch
import transformers
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    pipeline,
)
from transformers.trainer_utils import get_last_checkpoint

# Parallel corpus: one CSV per split, decoded as Latin-1.
# NOTE(review): if these CSVs are actually UTF-8, accented Spanish characters
# will be mis-decoded under latin-1 -- confirm the on-disk encoding.
csv_splits = {
    "train": "train1.csv",
    "validation": "validation1.csv",
}
dataset = load_dataset("csv", data_files=csv_splits, encoding='latin-1')

# Pretrained English->Spanish checkpoint together with its own tokenizer.
model_name = "Helsinki-NLP/opus-mt-en-es"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def preprocess_function(examples, max_length=128):
    """Tokenize a batch of parallel sentences for seq2seq fine-tuning.

    Meant to be used with ``Dataset.map(..., batched=True)``, so ``examples``
    is a dict mapping column name -> list of values. It must contain a
    "source" column (English input text) and a "target" column (Spanish
    reference text).

    Args:
        examples: Batch dict with "source" and "target" lists of strings.
        max_length: Token budget for both inputs and labels. Longer texts are
            truncated and shorter ones are padded to exactly this length.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) plus
        ``labels``. ``labels`` holds the target token ids, with padding
        positions set to -100 so the loss ignores them.
    """
    # The tokenizer appends the EOS token itself. The old manual " </s>"
    # suffix was parsed as a special token, so sequences ended with a
    # doubled EOS.
    # ``text_target`` sends the targets through the tokenizer's target-side
    # processing rather than the source-side path. For example, Marian
    # tokenizers keep a separate target-language SentencePiece model.
    # NOTE(review): an empty CSV cell presumably arrives as None and will make
    # the tokenizer raise. Filter such rows before mapping if they exist.
    model_inputs = tokenizer(
        examples["source"],
        text_target=examples["target"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    # Padded label positions must be -100, the ignore index of the
    # cross-entropy loss. Otherwise the model is trained to emit pad tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in label_ids]
        for label_ids in model_inputs["labels"]
    ]
    return model_inputs

# Tokenize every split batch-by-batch; the result keeps the same split names.
tokenized_datasets = dataset.map(function=preprocess_function, batched=True)


In [None]:
# Training arguments.
# Each optimizer step sees per_device_train_batch_size * gradient_accumulation_steps
# = 16 * 4 examples per device.
training_args = TrainingArguments(
    output_dir="./helsinki_finetuned_2e-4_16",
    logging_dir="./logs",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    num_train_epochs=30,
    learning_rate=2e-4,
    warmup_steps=500,
    weight_decay=0.01,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    # Evaluate, save and log once per epoch. load_best_model_at_end requires
    # the save and eval strategies to match.
    # logging_steps is omitted because it is ignored under the "epoch"
    # logging strategy.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=2,
    metric_for_best_model="eval_loss",
    load_best_model_at_end=True,
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    # TrainingArguments.resume_from_checkpoint is not read by Trainer. To
    # resume, pass resume_from_checkpoint to trainer.train() instead.
)

# Create the Trainer and train the model.
# EarlyStoppingCallback stops training after one evaluation without
# improvement in training_args.metric_for_best_model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)],
    processing_class=tokenizer,
)

# Resume from the newest checkpoint in output_dir when one exists; otherwise
# start from scratch. (Passing resume_from_checkpoint=True unconditionally
# raises on a fresh run with no checkpoints.) Delete or rename output_dir to
# force a fresh run.
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)


In [None]:
# Example of using the fine-tuned model for inference.
# Uses the model trained above in this notebook. Because load_best_model_at_end
# is set, this is the best checkpoint by eval_loss. The previous hard-coded
# "./helsinki_finetuned_5e-5_16/checkpoint-902" came from a different run than
# the one this notebook trains. To evaluate a saved checkpoint, pass its
# directory as `model=` instead.
translator = pipeline("translation_en_to_es", model=trainer.model, tokenizer=tokenizer)

text = "What is the name of your store?"

# The pipeline returns one dict per input; index directly so a missing key fails
# loudly instead of silently printing None.
translated_text = translator(text)[0]["translation_text"]
print(translated_text)
