### Dataset

In [1]:
!pip install -q transformers sentencepiece datasets accelerate evaluate sacrebleu

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from datasets import load_dataset

ds = load_dataset('thainq107/iwslt2015-en-vi')

README.md:   0%|          | 0.00/522 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/181k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/133317 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1268 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1268 [00:00<?, ? examples/s]

In [3]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### Tokenizer

In [4]:
from transformers import AutoTokenizer

model_name = 'facebook/mbart-large-50-many-to-many-mmt'
tokenizer = AutoTokenizer.from_pretrained(model_name)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

### Encoding

In [5]:
import torch

MAX_LEN = 75

def preprocess_function(examples):
    input_ids = tokenizer(
        examples['en'], padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
    )['input_ids']

    labels = tokenizer(
        examples['vi'], padding='max_length',
        truncation=True,
        max_length=MAX_LEN
    )['input_ids']
    labels = [
        [-100 if item == tokenizer.pad_token_id else item for item in label]
        for label in labels
    ]

    return {
        'input_ids': torch.tensor(input_ids),
        'labels': torch.tensor(labels)
    }

preprocessed_ds = ds.map(preprocess_function, batched=True)

Map:   0%|          | 0/133317 [00:00<?, ? examples/s]

Map:   0%|          | 0/1268 [00:00<?, ? examples/s]

Map:   0%|          | 0/1268 [00:00<?, ? examples/s]

### Model

In [6]:
from transformers import AutoModelForSeq2SeqLM

model_name = 'facebook/mbart-large-50-many-to-many-mmt'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

In [7]:
import numpy as np
import evaluate
metric = evaluate.load('sacrebleu')

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred  in preds]
    labesl = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(
        preds, skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(
        labels, skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    decoded_preds, decoded_labels = postprocess_text(
        decoded_preds, decoded_labels
    )

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {'bleu': result['score']}

    return result

Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

### Trainer

In [8]:
import os
os.environ['WANDB_DISABLED'] = 'true'

from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir='./en-vi-mbart50',
    logging_dir='logs',
    logging_steps=1000,
    predict_with_generate=True,
    eval_strategy='steps',
    eval_steps=1000,
    save_strategy='steps',
    save_steps=1000,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_total_limit=1,
    num_train_epochs=1,
    load_best_model_at_end=True,
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    training_args,
    train_dataset=preprocessed_ds['train'],
    eval_dataset=preprocessed_ds['validation'],
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss,Validation Loss,Bleu
1000,1.2211,1.329548,33.584446
2000,1.1835,1.285117,34.296266
3000,1.1467,1.255791,34.678876


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

RuntimeError: [enforce fail at inline_container.cc:603] . unexpected pos 3753363712 vs 3753363600

In [None]:
trainer.push_to_hub()

In [None]:
trainer.state.log_history