In [2]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/media/dcl/7D03F77D0BDB12B9/HTR-Firat/Transformers/cache'

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast,\
                        Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

from datasets import load_dataset, load_metric
import numpy as np

In [15]:
TRAIN = True

## Model

In [4]:
model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')

In [5]:
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", 
                                                src_lang="en_XX", tgt_lang="tr_TR")

In [6]:
metric = load_metric("sacrebleu")

## Data

In [7]:
prefix = ""
max_input_length = 90
max_target_length = 90
source_lang = "en"
target_lang = "tr"

def preprocess_function(examples):
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding = True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True, padding = True)
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [8]:
dataset = load_dataset("wmt16", "tr-en")

Reusing dataset wmt16 (/home/dcl/.cache/huggingface/datasets/wmt16/tr-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f)


  0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
preprocessed_dataset = dataset.map(preprocess_function, batched=True)

Loading cached processed dataset at /home/dcl/.cache/huggingface/datasets/wmt16/tr-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f/cache-e95d1086952cf61a.arrow
Loading cached processed dataset at /home/dcl/.cache/huggingface/datasets/wmt16/tr-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f/cache-2d86a52f49d9f875.arrow


  0%|          | 0/3 [00:00<?, ?ba/s]

In [10]:
checkpoint_path = '/media/dcl/7D03F77D0BDB12B9/HTR-Firat/Transformers/'
batch_size = 4

args = Seq2SeqTrainingArguments(
    output_dir = checkpoint_path,
    evaluation_strategy = 'epoch',
    learning_rate = 1e-4,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    weight_decay = 0.01,
    save_total_limit = 5,
    num_train_epochs = 30,
    predict_with_generate = True,
)

In [11]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [12]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [13]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)



In [16]:
if TRAIN:
    trainer.train()

The following columns in the training set  don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: translation.
***** Running training *****
  Num examples = 205756
  Num Epochs = 20
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 1028780


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## Test

In [30]:
finetuned_model_path = '/media/dcl/7D03F77D0BDB12B9/HTR-Firat/Transformers/models/en-tr/checkpoint-160000'
finetuned_model = MBartForConditionalGeneration.from_pretrained(finetuned_model_path)
tokenizer = MBart50TokenizerFast.from_pretrained(finetuned_model_path, 
                                                src_lang="en_XX", tgt_lang="tr_TR");

loading configuration file /media/dcl/7D03F77D0BDB12B9/HTR-Firat/Transformers/models/en-tr/checkpoint-160000/config.json
Model config MBartConfig {
  "_name_or_path": "facebook/mbart-large-50",
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": true,
  "architectures": [
    "MBartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.

In [31]:
test_trainer = Seq2SeqTrainer(
    model=finetuned_model,
    args=args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
test_trainer.predict(preprocessed_dataset['test'])

The following columns in the test set  don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: translation.
***** Running Prediction *****
  Num examples = 3000
  Batch size = 4


PredictionOutput(predictions=array([[     2, 250023,    990, ...,      1,      1,      1],
       [     2, 250023, 139948, ...,      1,      1,      1],
       [     2, 250023, 122930, ...,      1,      1,      1],
       ...,
       [     2, 250023,   2902, ...,      1,      1,      1],
       [     2, 250023,  61806, ...,      1,      1,      1],
       [     2, 250023,   6565, ...,      1,      1,      1]]), label_ids=array([[250023,      6, 166964, ...,      1,      1,      1],
       [250023, 120138,     25, ...,      1,      1,      1],
       [250023, 122930,      4, ...,      1,      1,      1],
       ...,
       [250023,  24627,  71799, ...,      1,      1,      1],
       [250023,  46850,  14588, ...,      1,      1,      1],
       [250023,   6565,   2544, ...,      1,      1,      1]]), metrics={'eval_loss': 0.8444002270698547, 'eval_bleu': 13.5572, 'eval_gen_len': 28.66, 'eval_runtime': 711.7295, 'eval_samples_per_second': 4.215, 'eval_steps_per_second': 1.054})