In [1]:
import torch
import numpy as np
from transformers import set_seed

# random seed
seed = 42

# set random seed
if seed is not None:
    print(f'random seed: {seed}')
    set_seed(seed)

random seed: 42


In [2]:
transformer_name = 't5-small'
dataset_name = 'wmt16'
dataset_config_name = 'ro-en'
source_lang = 'ro'
target_lang = 'en'
max_source_length = 1024
max_target_length = 128
task_prefix = 'translate Romanian to English: '
batch_size = 4
label_pad_token_id = -100
save_steps = 25_000
num_beams = 1
learning_rate = 1e-3
num_train_epochs = 3
output_dir = '/media/data2/t5-translation-example'

In [3]:
from datasets import load_dataset

wmt16 = load_dataset(dataset_name, dataset_config_name)

Reusing dataset wmt16 (/home/marco/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained(transformer_name)
tokenizer = AutoTokenizer.from_pretrained(transformer_name)
model = AutoModelForSeq2SeqLM.from_pretrained(transformer_name, config=config)

In [5]:
def tokenize(batch):
    # get source sentences and prepend task prefix
    sources = [x[source_lang] for x in batch["translation"]]
    sources = [task_prefix + x for x in sources]
    # tokenize source sentences
    output = tokenizer(
        sources,
        max_length=max_source_length,
        truncation=True,
    )

    # get target sentences
    targets = [x[target_lang] for x in batch["translation"]]
    # tokenize target sentences
    labels = tokenizer(
        targets,
        max_length=max_target_length,
        truncation=True,
    )
    # add targets to output
    output["labels"] = labels["input_ids"]

    return output

In [6]:
train_dataset = wmt16['train']
eval_dataset = wmt16['validation']
column_names = train_dataset.column_names

train_dataset = train_dataset.map(
    tokenize,
    batched=True,
    remove_columns=column_names,
)

eval_dataset = eval_dataset.map(
    tokenize,
    batched=True,
    remove_columns=column_names,
)

Loading cached processed dataset at /home/marco/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0/cache-23cc4847e3a6788f.arrow
Loading cached processed dataset at /home/marco/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0/cache-2cd8d8c9a5ee0dc9.arrow


In [7]:
train_dataset.to_pandas()

Unnamed: 0,input_ids,attention_mask,labels
0,"[13959, 3871, 29, 12, 1566, 10, 4961, 106, 204...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[19428, 13, 12876, 10, 217, 13687, 7, 1]"
1,"[13959, 3871, 29, 12, 1566, 10, 5085, 5840, 49...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2276, 8843, 138, 13, 13687, 7, 13, 1767, 3823..."
2,"[13959, 3871, 29, 12, 1566, 10, 4961, 106, 204...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[19428, 13, 12876, 10, 217, 13687, 7, 1]"
3,"[13959, 3871, 29, 12, 1566, 10, 781, 8750, 9, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[781, 2420, 13, 17500, 10, 217, 13687, 7, 1]"
4,"[13959, 3871, 29, 12, 1566, 10, 374, 6225, 49,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[11167, 7, 1204, 10, 217, 13687, 7, 1]"
...,...,...,...
610315,"[13959, 3871, 29, 12, 1566, 10, 4540, 4031, 9,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[4540, 4031, 9, 7, 1672, 7, 2262, 900, 17, 38,..."
610316,"[13959, 3871, 29, 12, 1566, 10, 2364, 4540, 40...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[242, 4540, 4031, 9, 7, 6, 8, 516, 65, 66, 8, ..."
610317,"[13959, 3871, 29, 12, 1566, 10, 2262, 900, 17,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2262, 900, 17, 641, 65, 46, 3761, 6, 1069, 31..."
610318,"[13959, 3871, 29, 12, 1566, 10, 3, 25882, 759,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[9810, 157, 31, 7, 516, 92, 3088, 21, 46, 3839..."


In [8]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
)

In [9]:
from datasets import load_metric

metric = load_metric('sacrebleu')

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # get text for predictions
    predictions = tokenizer.batch_decode(
        preds,
        skip_special_tokens=True,
    )
    # replace -100 in labels with pad token
    labels = np.where(
        labels != -100,
        labels,
        tokenizer.pad_token_id,
    )
    # get text for gold labels
    references = tokenizer.batch_decode(
        labels,
        skip_special_tokens=True,
    )
    # metric expects list of references for each prediction
    references = [[ref] for ref in references]
    
    # compute bleu score
    results = metric.compute(
        predictions=predictions,
        references=references,
    )
    results = {'bleu': results['score']}
    
    return results

In [10]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_steps=save_steps,
    predict_with_generate=True,
    evaluation_strategy='steps',
    eval_steps=save_steps,
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
)

In [11]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [12]:
import os
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if os.path.isdir(output_dir):
    last_checkpoint = get_last_checkpoint(output_dir)

if last_checkpoint is not None:
    print(f'Checkpoint detected, resuming training at {last_checkpoint}.')

In [13]:
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()

***** Running training *****
  Num examples = 610320
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 457740


Step,Training Loss,Validation Loss,Bleu
25000,1.2485,1.831403,13.156476
50000,1.1744,1.743581,13.979321
75000,1.118,1.697292,14.718017
100000,1.0688,1.650289,14.942166
125000,1.0382,1.628099,14.984296
150000,0.9867,1.621915,15.411908
175000,0.9497,1.593482,15.777268
200000,0.9471,1.566495,15.372236
225000,0.9168,1.545556,15.787117
250000,0.9081,1.537921,15.864213


***** Running Evaluation *****
  Num examples = 1999
  Batch size = 4
Saving model checkpoint to /media/data2/t5-translation-example/checkpoint-25000
Configuration saved in /media/data2/t5-translation-example/checkpoint-25000/config.json
Model weights saved in /media/data2/t5-translation-example/checkpoint-25000/pytorch_model.bin
tokenizer config file saved in /media/data2/t5-translation-example/checkpoint-25000/tokenizer_config.json
Special tokens file saved in /media/data2/t5-translation-example/checkpoint-25000/special_tokens_map.json
Copy vocab file to /media/data2/t5-translation-example/checkpoint-25000/spiece.model
***** Running Evaluation *****
  Num examples = 1999
  Batch size = 4
Saving model checkpoint to /media/data2/t5-translation-example/checkpoint-50000
Configuration saved in /media/data2/t5-translation-example/checkpoint-50000/config.json
Model weights saved in /media/data2/t5-translation-example/checkpoint-50000/pytorch_model.bin
tokenizer config file saved in /media/d

***** Running Evaluation *****
  Num examples = 1999
  Batch size = 4
Saving model checkpoint to /media/data2/t5-translation-example/checkpoint-350000
Configuration saved in /media/data2/t5-translation-example/checkpoint-350000/config.json
Model weights saved in /media/data2/t5-translation-example/checkpoint-350000/pytorch_model.bin
tokenizer config file saved in /media/data2/t5-translation-example/checkpoint-350000/tokenizer_config.json
Special tokens file saved in /media/data2/t5-translation-example/checkpoint-350000/special_tokens_map.json
Copy vocab file to /media/data2/t5-translation-example/checkpoint-350000/spiece.model
***** Running Evaluation *****
  Num examples = 1999
  Batch size = 4
Saving model checkpoint to /media/data2/t5-translation-example/checkpoint-375000
Configuration saved in /media/data2/t5-translation-example/checkpoint-375000/config.json
Model weights saved in /media/data2/t5-translation-example/checkpoint-375000/pytorch_model.bin
tokenizer config file saved in

In [14]:
metrics = train_result.metrics
metrics['train_samples'] = len(train_dataset)

trainer.log_metrics('train', metrics)
trainer.save_metrics('train', metrics)
trainer.save_state()

***** train metrics *****
  epoch                    =        3.0
  total_flos               = 33926705GF
  train_loss               =     0.9658
  train_runtime            = 5:14:15.83
  train_samples            =     610320
  train_samples_per_second =     97.103
  train_steps_per_second   =     24.276


In [15]:
# https://discuss.huggingface.co/t/evaluation-results-metric-during-training-is-different-from-the-evaluation-results-at-the-end/15401

metrics = trainer.evaluate(
    max_length=max_target_length,
    num_beams=num_beams,
    metric_key_prefix='eval',
)

metrics['eval_samples'] = len(eval_dataset)

trainer.log_metrics('eval', metrics)
trainer.save_metrics('eval', metrics)

***** Running Evaluation *****
  Num examples = 1999
  Batch size = 4


***** eval metrics *****
  epoch                   =        3.0
  eval_bleu               =    35.1923
  eval_loss               =     1.4452
  eval_runtime            = 0:01:50.71
  eval_samples            =       1999
  eval_samples_per_second =     18.055
  eval_steps_per_second   =      4.516


In [16]:
kwargs = {
    'finetuned_from': transformer_name,
    'tasks': 'translation',
    'dataset_tags': dataset_name,
    'dataset_args': dataset_config_name,
    'dataset': f'{dataset_name} {dataset_config_name}',
    'language': [source_lang, target_lang],
}
trainer.create_model_card(**kwargs)