## Loading the Data

In [1]:
from datasets import load_from_disk
dataset = load_from_disk('ape-ds')

In [2]:
dataset['train'][0]

{'bo': 'ཞི་ཁྲོ་སྤྲུལ་པའི་སྐུ་ལ་ཕྱག་འཚལ་ལོ༔',
 'en': 'Nirmāṇakāya peaceful and wrathful: to you I pay homage!',
 'topic': 'Confession, Termas, Tibetan Masters, Nyala Pema Dündul',
 'for-post-edit': 'Homage to the tathāgatas and the Victorious Ones of the Victorious Ones: homage to you!'}

## Load Unfinetuned Tokenizer, Model, and Data Collator

In [5]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM

checkpoint = "google-t5/t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained('my-tokenizer')
model.resize_token_embeddings(len(tokenizer))
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

## Preprocess Data

The dataset can now be tokenized for training.

In [6]:
def preprocess_function(examples):

    # Prepare translation inputs and targets
    inputs = ['Post-Edit Translation: ' + example for example in examples['for-post-edit']]
    targets = [example for example in examples['en']]
    
    # Tokenize translation inputs and targets
    model_inputs = tokenizer(inputs, text_target=targets, 
                                         max_length=256, truncation=False, padding="max_length")
    
    
    return model_inputs


In [7]:
tokenized_dataset = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/193544 [00:00<?, ? examples/s]

Map:   0%|          | 0/21506 [00:00<?, ? examples/s]

## Train the Model

Finally, we can train the model. Note that the optimizer used is Adafactor. This is the optimizer that is preferred for translation tasks and for the T5 model in general. The transformers api includes a built in version of Adafactor, but I define it separately here so that we can optimize it with the 'accelerate' library.

In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, Adafactor, EarlyStoppingCallback
from accelerate import Accelerator

accelerator = Accelerator()

optimizer = Adafactor(
    model.parameters(), 
    scale_parameter=True, 
    relative_step=False, 
    warmup_init=False, 
    lr=3e-4
)

model, optimizer = accelerator.prepare(model, optimizer)

In [9]:
import numpy as np
import evaluate

# Load BLEU and CHRF metrics
bleu_metric = evaluate.load("sacrebleu")
chrf_metric = evaluate.load("chrf")
ter_metric = evaluate.load("ter")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    
    # Decode predictions and labels
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Postprocess text
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # Compute BLEU score
    bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
    bleu_score = bleu_result["score"]

    # Compute CHRF score
    chrf_result = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels)
    chrf_score = chrf_result["score"]

    # Compute TER score
    ter_result = ter_metric.compute(predictions=decoded_preds, references=decoded_labels)
    ter_score = ter_result["score"]

    # Return rounded results
    metrics = {
        "bleu": round(bleu_score, 4),
        "chrf": round(chrf_score, 4),
        "ter": round(ter_score, 4)
    }

    #print("Computed Metrics:", metrics)

    return metrics

Using the latest cached version of the module from /home/j/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--ter/9c9af3214842a93c26c9ac10ddc5d07559df8e38c0ab3b599e8121b9aae196bd (last modified on Tue Jan 28 22:53:15 2025) since it couldn't be found locally at evaluate-metric--ter, or remotely on the Hugging Face Hub.


In [10]:
%env WANDB_PROJECT=ape-experiment

env: WANDB_PROJECT=ape-experiment


In [11]:
training_args = Seq2SeqTrainingArguments(
    output_dir=f"post-edit-model",
    auto_find_batch_size=True,
    predict_with_generate=True,
    fp16=False,
    push_to_hub=False,
    eval_strategy='epoch',
    save_strategy='epoch',
    num_train_epochs=10
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

  trainer = Seq2SeqTrainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0111124526889196, max=1.0))…

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Bleu,Chrf,Ter
1,0.1625,0.165366,21.5452,35.3344,82.9788
2,0.1358,0.138631,22.2272,36.0762,82.0887
3,0.116,0.120754,24.2156,37.3817,79.3348
4,0.1037,0.108298,26.3561,39.0103,77.1086
5,0.0919,0.099292,28.6844,40.6832,74.6957
6,0.0826,0.092779,30.2903,41.8956,73.0387
7,0.078,0.088224,31.8423,43.1804,71.63
8,0.073,0.084906,32.5567,44.1424,71.3292
9,0.0718,0.082991,33.5175,44.6203,70.1556
10,0.0688,0.082397,33.6962,44.8243,70.0194


TrainOutput(global_step=241930, training_loss=0.10358538406284401, metrics={'train_runtime': 48732.8771, 'train_samples_per_second': 39.715, 'train_steps_per_second': 4.964, 'total_flos': 1.3097296812048384e+17, 'train_loss': 0.10358538406284401, 'epoch': 10.0})