In [1]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from tokenizers import processors
import evaluate

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decorders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [2]:
source_lng = "ja"

if source_lng == "en":
    target_lng = "ja"
    encoder = "bert-base-uncased"
    decoder = "rinna/japanese-gpt2-small"
else: 
    target_lng = "en"
    encoder = "cl-tohoku/bert-base-japanese-v3"
    decoder = "gpt2"

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, encoder_add_pooling_layer=False
)
model.cuda();

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.7.crossattention.c_attn.bias', 'h.7.crossattention.c_proj.weight', 'h.6.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.bias', 'h.7.crossattention.q_attn.weight', 'h.0.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.8.ln_cross_attn.weight', 'h.11.crossattention.c_attn.weight', 'h.1.ln_cross_attn.weight', 'h.5.crossattention.c_proj.weight', 'h.4.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.weight', 'h.5.crossattention.c_attn.bias', 'h.6.crossattention.q_attn.weight', 'h.1.crossattention.q_attn.bias', 'h.7.crossattention.c_proj.bias', 'h.4.ln_cross_attn.bias', 'h.6.crossattention.c_attn.bias', 'h.9.crossattention.c_attn.weight', 'h.2.ln_cross_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.8.crossattention.c_proj.bias', 'h.9.crossattention.c_attn.bias', 'h.8.crossattention.q_attn.weight', 'h.1.

In [3]:
def print_model_parameters():
    t_pars, t_bytes = 0, 0
    for p in model.parameters():
        t_pars += p.nelement()
        t_bytes += p.nelement() * p.element_size()

    c_attn_pars, c_attn_bytes = 0, 0
    for layer in model.decoder.transformer.h:
        for p in layer.crossattention.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()
        for p in layer.ln_cross_attn.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()

    print(f"Total number of parameters: {t_pars:12,} ({(t_bytes / 1024**2):7,.1f}MB)")
    print(f"Cross-attention parameters: {c_attn_pars:12,} ({(c_attn_bytes / 1024**2):7,.1f}MB)")

print_model_parameters()

Total number of parameters:  263,423,232 (1,004.9MB)
Cross-attention parameters:   28,366,848 (  108.2MB)


In [4]:
def set_cross_attention_only(model):
    for p in model.parameters():
        p.requires_grad = False
    for layer in model.decoder.transformer.h:
        for p in layer.crossattention.parameters():
            p.requires_grad = True
        for p in layer.ln_cross_attn.parameters():
            p.requires_grad = True
set_cross_attention_only(model)

In [5]:
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = decoder_tokenizer.eos_token_id

model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.pad_token_id = decoder_tokenizer.eos_token_id

# add EOS token at the end of each sentence
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single="$A " + decoder_tokenizer.eos_token,
    special_tokens=[(decoder_tokenizer.eos_token, decoder_tokenizer.eos_token_id)],
)

In [6]:
from utils.dataset import EnJaDatasetMaker
from transformers import DataCollatorForSeq2Seq

dataset = EnJaDatasetMaker.load_dataset("ja-en-test-1")
train_data = dataset.select(range(100))
valid_data = dataset.select(range(100, 150))

data_collator = DataCollatorForSeq2Seq(encoder_tokenizer, model=model)

In [7]:
metric = evaluate.load("sacrebleu")

def compute_metrics(preds):
    preds_ids, labels_ids = preds

    labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
    references = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    references = [[reference] for reference in references]

    predictions = decoder_tokenizer.batch_decode(preds_ids, skip_special_tokens=True)

    if target_lng == "ja":
        bleu_output = metric.compute(
            references=references, 
            predictions=predictions, 
            tokenize="ja-mecab"
        )
    else:
        bleu_output = metric.compute(
            references=references, 
            predictions=predictions
        )
    return bleu_output


In [8]:
MAX_LENGHT = 128
def set_decoder_configuration(gc: GenerationConfig):
    gc.no_repeat_ngram_size = 3
    gc.length_penalty = 2.0
    gc.num_beams = 3
    #gen_config.max_new_tokens = MAX_LENGHT
    gc.max_length = MAX_LENGHT * 2
    gc.min_length = 0
    gc.early_stopping = True
    gc.pad_token_id = decoder_tokenizer.eos_token_id
    gc.bos_token_id = decoder_tokenizer.bos_token_id
    gc.eos_token_id = decoder_tokenizer.eos_token_id
    return gc

gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [9]:
train_args = Seq2SeqTrainingArguments(
    report_to="wandb",
    run_name="testing-data-maker-2",
    num_train_epochs=50,

    logging_strategy="steps",
    logging_steps=10,

    evaluation_strategy="epoch",

    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=4,

    optim="adamw_torch",
    bf16=True,

    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    
    group_by_length=True,
    length_column_name="length",

    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_config=gen_config,
    # torch_compile=True,
    # label_smoothing_factor=0,
    # auto_find_batch_size=False,
)

In [10]:
trainer = Seq2SeqTrainer(
    model, 
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_data, 
    eval_dataset=valid_data, 
    compute_metrics=compute_metrics
)

In [11]:
model.train()
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavidboening[0m ([33mdandd[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/650 [00:00<?, ?it/s]



{'loss': 5.0748, 'learning_rate': 4.923076923076924e-05, 'epoch': 0.77}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.712499618530273, 'eval_score': 0.26543226091295274, 'eval_counts': [57, 3, 0, 0], 'eval_totals': [887, 837, 787, 737], 'eval_precisions': [6.4261555806087935, 0.35842293906810035, 0.06353240152477764, 0.033921302578018994], 'eval_bp': 1.0, 'eval_sys_len': 887, 'eval_ref_len': 419, 'eval_runtime': 4.626, 'eval_samples_per_second': 10.808, 'eval_steps_per_second': 1.513, 'epoch': 1.0}
{'loss': 4.2984, 'learning_rate': 4.846153846153846e-05, 'epoch': 1.54}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.359025955200195, 'eval_score': 0.25666566501770993, 'eval_counts': [57, 1, 0, 0], 'eval_totals': [714, 664, 614, 564], 'eval_precisions': [7.983193277310924, 0.15060240963855423, 0.08143322475570032, 0.044326241134751775], 'eval_bp': 1.0, 'eval_sys_len': 714, 'eval_ref_len': 419, 'eval_runtime': 4.677, 'eval_samples_per_second': 10.691, 'eval_steps_per_second': 1.497, 'epoch': 2.0}
{'loss': 3.6416, 'learning_rate': 4.76923076923077e-05, 'epoch': 2.31}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.369422435760498, 'eval_score': 0.3682501826619369, 'eval_counts': [59, 1, 0, 0], 'eval_totals': [526, 476, 426, 376], 'eval_precisions': [11.216730038022813, 0.21008403361344538, 0.11737089201877934, 0.06648936170212766], 'eval_bp': 1.0, 'eval_sys_len': 526, 'eval_ref_len': 419, 'eval_runtime': 3.7755, 'eval_samples_per_second': 13.243, 'eval_steps_per_second': 1.854, 'epoch': 3.0}
{'loss': 3.277, 'learning_rate': 4.692307692307693e-05, 'epoch': 3.08}
{'loss': 3.2231, 'learning_rate': 4.615384615384616e-05, 'epoch': 3.85}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.359713554382324, 'eval_score': 0.43002868620010143, 'eval_counts': [54, 1, 0, 0], 'eval_totals': [454, 404, 354, 304], 'eval_precisions': [11.894273127753303, 0.24752475247524752, 0.14124293785310735, 0.08223684210526316], 'eval_bp': 1.0, 'eval_sys_len': 454, 'eval_ref_len': 419, 'eval_runtime': 3.943, 'eval_samples_per_second': 12.681, 'eval_steps_per_second': 1.775, 'epoch': 4.0}
{'loss': 2.9748, 'learning_rate': 4.538461538461539e-05, 'epoch': 4.62}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.358962535858154, 'eval_score': 0.3592707422701168, 'eval_counts': [73, 1, 0, 0], 'eval_totals': [562, 512, 462, 412], 'eval_precisions': [12.98932384341637, 0.1953125, 0.10822510822510822, 0.06067961165048544], 'eval_bp': 1.0, 'eval_sys_len': 562, 'eval_ref_len': 419, 'eval_runtime': 3.684, 'eval_samples_per_second': 13.572, 'eval_steps_per_second': 1.9, 'epoch': 5.0}
{'loss': 2.9757, 'learning_rate': 4.461538461538462e-05, 'epoch': 5.38}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.426562786102295, 'eval_score': 0.39200336299958993, 'eval_counts': [63, 1, 0, 0], 'eval_totals': [506, 456, 406, 356], 'eval_precisions': [12.450592885375494, 0.21929824561403508, 0.12315270935960591, 0.0702247191011236], 'eval_bp': 1.0, 'eval_sys_len': 506, 'eval_ref_len': 419, 'eval_runtime': 3.788, 'eval_samples_per_second': 13.2, 'eval_steps_per_second': 1.848, 'epoch': 6.0}
{'loss': 2.8348, 'learning_rate': 4.384615384615385e-05, 'epoch': 6.15}
{'loss': 2.737, 'learning_rate': 4.3076923076923084e-05, 'epoch': 6.92}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.478742599487305, 'eval_score': 0.3577700826577314, 'eval_counts': [54, 1, 0, 0], 'eval_totals': [529, 479, 429, 379], 'eval_precisions': [10.207939508506616, 0.20876826722338204, 0.11655011655011654, 0.06596306068601583], 'eval_bp': 1.0, 'eval_sys_len': 529, 'eval_ref_len': 419, 'eval_runtime': 3.695, 'eval_samples_per_second': 13.532, 'eval_steps_per_second': 1.894, 'epoch': 7.0}
{'loss': 2.6941, 'learning_rate': 4.230769230769231e-05, 'epoch': 7.69}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.561590194702148, 'eval_score': 0.3867658327098837, 'eval_counts': [61, 2, 0, 0], 'eval_totals': [589, 539, 489, 439], 'eval_precisions': [10.356536502546689, 0.37105751391465674, 0.10224948875255624, 0.05694760820045558], 'eval_bp': 1.0, 'eval_sys_len': 589, 'eval_ref_len': 419, 'eval_runtime': 4.096, 'eval_samples_per_second': 12.207, 'eval_steps_per_second': 1.709, 'epoch': 8.0}
{'loss': 2.5544, 'learning_rate': 4.1538461538461544e-05, 'epoch': 8.46}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.527477264404297, 'eval_score': 0.33367002668537243, 'eval_counts': [66, 1, 0, 0], 'eval_totals': [586, 536, 486, 436], 'eval_precisions': [11.262798634812286, 0.1865671641791045, 0.102880658436214, 0.05733944954128441], 'eval_bp': 1.0, 'eval_sys_len': 586, 'eval_ref_len': 419, 'eval_runtime': 3.829, 'eval_samples_per_second': 13.058, 'eval_steps_per_second': 1.828, 'epoch': 9.0}
{'loss': 2.6248, 'learning_rate': 4.0769230769230773e-05, 'epoch': 9.23}
{'loss': 2.6224, 'learning_rate': 4e-05, 'epoch': 10.0}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.5946855545043945, 'eval_score': 0.35976669822507057, 'eval_counts': [66, 2, 0, 0], 'eval_totals': [638, 588, 538, 488], 'eval_precisions': [10.344827586206897, 0.3401360544217687, 0.09293680297397769, 0.05122950819672131], 'eval_bp': 1.0, 'eval_sys_len': 638, 'eval_ref_len': 419, 'eval_runtime': 4.016, 'eval_samples_per_second': 12.45, 'eval_steps_per_second': 1.743, 'epoch': 10.0}
{'loss': 2.5091, 'learning_rate': 3.923076923076923e-05, 'epoch': 10.77}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.6237030029296875, 'eval_score': 0.3644231454585382, 'eval_counts': [68, 2, 0, 0], 'eval_totals': [635, 585, 535, 485], 'eval_precisions': [10.708661417322835, 0.3418803418803419, 0.09345794392523364, 0.05154639175257732], 'eval_bp': 1.0, 'eval_sys_len': 635, 'eval_ref_len': 419, 'eval_runtime': 3.905, 'eval_samples_per_second': 12.804, 'eval_steps_per_second': 1.793, 'epoch': 11.0}
{'loss': 2.4271, 'learning_rate': 3.846153846153846e-05, 'epoch': 11.54}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.719987869262695, 'eval_score': 0.5090402786824376, 'eval_counts': [62, 3, 0, 0], 'eval_totals': [510, 460, 410, 360], 'eval_precisions': [12.156862745098039, 0.6521739130434783, 0.12195121951219512, 0.06944444444444445], 'eval_bp': 1.0, 'eval_sys_len': 510, 'eval_ref_len': 419, 'eval_runtime': 3.971, 'eval_samples_per_second': 12.591, 'eval_steps_per_second': 1.763, 'epoch': 12.0}
{'loss': 2.4165, 'learning_rate': 3.769230769230769e-05, 'epoch': 12.31}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.7869768142700195, 'eval_score': 0.41593060369359686, 'eval_counts': [71, 3, 0, 0], 'eval_totals': [624, 574, 524, 474], 'eval_precisions': [11.378205128205128, 0.5226480836236934, 0.09541984732824428, 0.052742616033755275], 'eval_bp': 1.0, 'eval_sys_len': 624, 'eval_ref_len': 419, 'eval_runtime': 3.774, 'eval_samples_per_second': 13.248, 'eval_steps_per_second': 1.855, 'epoch': 13.0}
{'loss': 2.3771, 'learning_rate': 3.692307692307693e-05, 'epoch': 13.08}
{'loss': 2.2185, 'learning_rate': 3.615384615384615e-05, 'epoch': 13.85}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.7367682456970215, 'eval_score': 0.33633505655147883, 'eval_counts': [66, 1, 0, 0], 'eval_totals': [582, 532, 482, 432], 'eval_precisions': [11.34020618556701, 0.18796992481203006, 0.1037344398340249, 0.05787037037037037], 'eval_bp': 1.0, 'eval_sys_len': 582, 'eval_ref_len': 419, 'eval_runtime': 3.874, 'eval_samples_per_second': 12.907, 'eval_steps_per_second': 1.807, 'epoch': 14.0}
{'loss': 2.2566, 'learning_rate': 3.538461538461539e-05, 'epoch': 14.62}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.873305797576904, 'eval_score': 0.40574556898209124, 'eval_counts': [67, 1, 0, 0], 'eval_totals': [498, 448, 398, 348], 'eval_precisions': [13.453815261044177, 0.22321428571428573, 0.12562814070351758, 0.07183908045977011], 'eval_bp': 1.0, 'eval_sys_len': 498, 'eval_ref_len': 419, 'eval_runtime': 3.878, 'eval_samples_per_second': 12.893, 'eval_steps_per_second': 1.805, 'epoch': 15.0}
{'loss': 2.1995, 'learning_rate': 3.461538461538462e-05, 'epoch': 15.38}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.812005519866943, 'eval_score': 0.4864327321272828, 'eval_counts': [59, 1, 0, 0], 'eval_totals': [416, 366, 316, 266], 'eval_precisions': [14.182692307692308, 0.273224043715847, 0.15822784810126583, 0.09398496240601503], 'eval_bp': 0.9928144022869276, 'eval_sys_len': 416, 'eval_ref_len': 419, 'eval_runtime': 3.712, 'eval_samples_per_second': 13.47, 'eval_steps_per_second': 1.886, 'epoch': 16.0}
{'loss': 2.1973, 'learning_rate': 3.384615384615385e-05, 'epoch': 16.15}
{'loss': 2.1581, 'learning_rate': 3.307692307692308e-05, 'epoch': 16.92}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.855074405670166, 'eval_score': 0.6219279069640946, 'eval_counts': [66, 3, 0, 0], 'eval_totals': [438, 388, 338, 288], 'eval_precisions': [15.068493150684931, 0.7731958762886598, 0.14792899408284024, 0.08680555555555555], 'eval_bp': 1.0, 'eval_sys_len': 438, 'eval_ref_len': 419, 'eval_runtime': 3.668, 'eval_samples_per_second': 13.631, 'eval_steps_per_second': 1.908, 'epoch': 17.0}
{'loss': 2.0602, 'learning_rate': 3.230769230769231e-05, 'epoch': 17.69}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.951928615570068, 'eval_score': 0.5921035499766193, 'eval_counts': [64, 3, 0, 0], 'eval_totals': [453, 403, 353, 303], 'eval_precisions': [14.1280353200883, 0.7444168734491315, 0.141643059490085, 0.08250825082508251], 'eval_bp': 1.0, 'eval_sys_len': 453, 'eval_ref_len': 419, 'eval_runtime': 4.055, 'eval_samples_per_second': 12.33, 'eval_steps_per_second': 1.726, 'epoch': 18.0}
{'loss': 2.0585, 'learning_rate': 3.153846153846154e-05, 'epoch': 18.46}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.954809665679932, 'eval_score': 0.6003368642515899, 'eval_counts': [75, 4, 0, 0], 'eval_totals': [491, 441, 391, 341], 'eval_precisions': [15.274949083503055, 0.9070294784580499, 0.1278772378516624, 0.07331378299120235], 'eval_bp': 1.0, 'eval_sys_len': 491, 'eval_ref_len': 419, 'eval_runtime': 3.692, 'eval_samples_per_second': 13.543, 'eval_steps_per_second': 1.896, 'epoch': 19.0}
{'loss': 2.0856, 'learning_rate': 3.0769230769230774e-05, 'epoch': 19.23}
{'loss': 1.9323, 'learning_rate': 3e-05, 'epoch': 20.0}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.999368190765381, 'eval_score': 0.6429096749676323, 'eval_counts': [66, 4, 0, 0], 'eval_totals': [452, 402, 352, 302], 'eval_precisions': [14.601769911504425, 0.9950248756218906, 0.14204545454545456, 0.08278145695364239], 'eval_bp': 1.0, 'eval_sys_len': 452, 'eval_ref_len': 419, 'eval_runtime': 3.664, 'eval_samples_per_second': 13.646, 'eval_steps_per_second': 1.91, 'epoch': 20.0}
{'loss': 1.9175, 'learning_rate': 2.9230769230769234e-05, 'epoch': 20.77}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.029788970947266, 'eval_score': 0.624825182680384, 'eval_counts': [75, 4, 0, 0], 'eval_totals': [475, 425, 375, 325], 'eval_precisions': [15.789473684210526, 0.9411764705882353, 0.13333333333333333, 0.07692307692307693], 'eval_bp': 1.0, 'eval_sys_len': 475, 'eval_ref_len': 419, 'eval_runtime': 3.878, 'eval_samples_per_second': 12.893, 'eval_steps_per_second': 1.805, 'epoch': 21.0}
{'loss': 1.8358, 'learning_rate': 2.846153846153846e-05, 'epoch': 21.54}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 4.993334770202637, 'eval_score': 0.47677910954727254, 'eval_counts': [67, 2, 0, 0], 'eval_totals': [503, 453, 403, 353], 'eval_precisions': [13.320079522862823, 0.44150110375275936, 0.12406947890818859, 0.0708215297450425], 'eval_bp': 1.0, 'eval_sys_len': 503, 'eval_ref_len': 419, 'eval_runtime': 3.97, 'eval_samples_per_second': 12.594, 'eval_steps_per_second': 1.763, 'epoch': 22.0}
{'loss': 1.9432, 'learning_rate': 2.7692307692307694e-05, 'epoch': 22.31}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.044647693634033, 'eval_score': 0.4962343838342137, 'eval_counts': [66, 2, 0, 0], 'eval_totals': [485, 435, 385, 335], 'eval_precisions': [13.608247422680412, 0.45977011494252873, 0.12987012987012986, 0.07462686567164178], 'eval_bp': 1.0, 'eval_sys_len': 485, 'eval_ref_len': 419, 'eval_runtime': 3.471, 'eval_samples_per_second': 14.405, 'eval_steps_per_second': 2.017, 'epoch': 23.0}
{'loss': 1.8309, 'learning_rate': 2.6923076923076923e-05, 'epoch': 23.08}
{'loss': 1.8795, 'learning_rate': 2.6153846153846157e-05, 'epoch': 23.85}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.080243110656738, 'eval_score': 0.5710686924930134, 'eval_counts': [55, 2, 0, 0], 'eval_totals': [408, 358, 308, 258], 'eval_precisions': [13.480392156862745, 0.5586592178770949, 0.16233766233766234, 0.09689922480620156], 'eval_bp': 0.9733994133018776, 'eval_sys_len': 408, 'eval_ref_len': 419, 'eval_runtime': 3.518, 'eval_samples_per_second': 14.213, 'eval_steps_per_second': 1.99, 'epoch': 24.0}
{'loss': 1.7446, 'learning_rate': 2.5384615384615383e-05, 'epoch': 24.62}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.142169952392578, 'eval_score': 0.776775845608334, 'eval_counts': [75, 3, 1, 0], 'eval_totals': [498, 448, 398, 348], 'eval_precisions': [15.060240963855422, 0.6696428571428571, 0.25125628140703515, 0.14367816091954022], 'eval_bp': 1.0, 'eval_sys_len': 498, 'eval_ref_len': 419, 'eval_runtime': 3.853, 'eval_samples_per_second': 12.977, 'eval_steps_per_second': 1.817, 'epoch': 25.0}
{'loss': 1.7977, 'learning_rate': 2.461538461538462e-05, 'epoch': 25.38}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.103691577911377, 'eval_score': 0.3801444626106224, 'eval_counts': [67, 1, 0, 0], 'eval_totals': [526, 476, 426, 376], 'eval_precisions': [12.737642585551331, 0.21008403361344538, 0.11737089201877934, 0.06648936170212766], 'eval_bp': 1.0, 'eval_sys_len': 526, 'eval_ref_len': 419, 'eval_runtime': 4.099, 'eval_samples_per_second': 12.198, 'eval_steps_per_second': 1.708, 'epoch': 26.0}
{'loss': 1.658, 'learning_rate': 2.384615384615385e-05, 'epoch': 26.15}
{'loss': 1.6706, 'learning_rate': 2.307692307692308e-05, 'epoch': 26.92}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.151328086853027, 'eval_score': 0.588453926760384, 'eval_counts': [77, 4, 0, 0], 'eval_totals': [502, 452, 402, 352], 'eval_precisions': [15.338645418326694, 0.8849557522123894, 0.12437810945273632, 0.07102272727272728], 'eval_bp': 1.0, 'eval_sys_len': 502, 'eval_ref_len': 419, 'eval_runtime': 4.22, 'eval_samples_per_second': 11.848, 'eval_steps_per_second': 1.659, 'epoch': 27.0}
{'loss': 1.7087, 'learning_rate': 2.230769230769231e-05, 'epoch': 27.69}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.172948837280273, 'eval_score': 1.0116814559184961, 'eval_counts': [67, 5, 1, 0], 'eval_totals': [435, 385, 335, 285], 'eval_precisions': [15.402298850574713, 1.2987012987012987, 0.29850746268656714, 0.17543859649122806], 'eval_bp': 1.0, 'eval_sys_len': 435, 'eval_ref_len': 419, 'eval_runtime': 3.689, 'eval_samples_per_second': 13.554, 'eval_steps_per_second': 1.898, 'epoch': 28.0}
{'loss': 1.5445, 'learning_rate': 2.1538461538461542e-05, 'epoch': 28.46}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.169295787811279, 'eval_score': 0.8354052492120418, 'eval_counts': [71, 4, 1, 0], 'eval_totals': [492, 442, 392, 342], 'eval_precisions': [14.43089430894309, 0.9049773755656109, 0.25510204081632654, 0.14619883040935672], 'eval_bp': 1.0, 'eval_sys_len': 492, 'eval_ref_len': 419, 'eval_runtime': 4.309, 'eval_samples_per_second': 11.604, 'eval_steps_per_second': 1.625, 'epoch': 29.0}
{'loss': 1.7124, 'learning_rate': 2.0769230769230772e-05, 'epoch': 29.23}
{'loss': 1.5766, 'learning_rate': 2e-05, 'epoch': 30.0}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.208347320556641, 'eval_score': 0.6136973681565577, 'eval_counts': [64, 3, 0, 0], 'eval_totals': [440, 390, 340, 290], 'eval_precisions': [14.545454545454545, 0.7692307692307693, 0.14705882352941177, 0.08620689655172414], 'eval_bp': 1.0, 'eval_sys_len': 440, 'eval_ref_len': 419, 'eval_runtime': 3.5, 'eval_samples_per_second': 14.286, 'eval_steps_per_second': 2.0, 'epoch': 30.0}
{'loss': 1.5815, 'learning_rate': 1.923076923076923e-05, 'epoch': 30.77}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.329334735870361, 'eval_score': 0.9317341023831286, 'eval_counts': [72, 4, 1, 0], 'eval_totals': [451, 401, 351, 301], 'eval_precisions': [15.964523281596453, 0.9975062344139651, 0.2849002849002849, 0.16611295681063123], 'eval_bp': 1.0, 'eval_sys_len': 451, 'eval_ref_len': 419, 'eval_runtime': 3.789, 'eval_samples_per_second': 13.196, 'eval_steps_per_second': 1.847, 'epoch': 31.0}
{'loss': 1.5245, 'learning_rate': 1.8461538461538465e-05, 'epoch': 31.54}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.242983818054199, 'eval_score': 0.8664325886803378, 'eval_counts': [72, 6, 1, 0], 'eval_totals': [521, 471, 421, 371], 'eval_precisions': [13.81957773512476, 1.2738853503184713, 0.2375296912114014, 0.1347708894878706], 'eval_bp': 1.0, 'eval_sys_len': 521, 'eval_ref_len': 419, 'eval_runtime': 3.75, 'eval_samples_per_second': 13.333, 'eval_steps_per_second': 1.867, 'epoch': 32.0}
{'loss': 1.5225, 'learning_rate': 1.7692307692307694e-05, 'epoch': 32.31}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.303744316101074, 'eval_score': 0.4887721448238277, 'eval_counts': [74, 2, 0, 0], 'eval_totals': [503, 453, 403, 353], 'eval_precisions': [14.711729622266402, 0.44150110375275936, 0.12406947890818859, 0.0708215297450425], 'eval_bp': 1.0, 'eval_sys_len': 503, 'eval_ref_len': 419, 'eval_runtime': 4.129, 'eval_samples_per_second': 12.109, 'eval_steps_per_second': 1.695, 'epoch': 33.0}
{'loss': 1.4705, 'learning_rate': 1.6923076923076924e-05, 'epoch': 33.08}
{'loss': 1.5087, 'learning_rate': 1.6153846153846154e-05, 'epoch': 33.85}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.3946614265441895, 'eval_score': 0.563288059575354, 'eval_counts': [76, 3, 0, 0], 'eval_totals': [489, 439, 389, 339], 'eval_precisions': [15.541922290388548, 0.683371298405467, 0.12853470437017994, 0.07374631268436578], 'eval_bp': 1.0, 'eval_sys_len': 489, 'eval_ref_len': 419, 'eval_runtime': 3.944, 'eval_samples_per_second': 12.677, 'eval_steps_per_second': 1.775, 'epoch': 34.0}
{'loss': 1.3685, 'learning_rate': 1.5384615384615387e-05, 'epoch': 34.62}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.413300037384033, 'eval_score': 0.5715221081507122, 'eval_counts': [70, 3, 0, 0], 'eval_totals': [475, 425, 375, 325], 'eval_precisions': [14.736842105263158, 0.7058823529411765, 0.13333333333333333, 0.07692307692307693], 'eval_bp': 1.0, 'eval_sys_len': 475, 'eval_ref_len': 419, 'eval_runtime': 3.831, 'eval_samples_per_second': 13.051, 'eval_steps_per_second': 1.827, 'epoch': 35.0}
{'loss': 1.6223, 'learning_rate': 1.4615384615384617e-05, 'epoch': 35.38}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.30155086517334, 'eval_score': 0.8541862524100857, 'eval_counts': [71, 4, 1, 0], 'eval_totals': [483, 433, 383, 333], 'eval_precisions': [14.699792960662526, 0.9237875288683602, 0.26109660574412535, 0.15015015015015015], 'eval_bp': 1.0, 'eval_sys_len': 483, 'eval_ref_len': 419, 'eval_runtime': 3.838, 'eval_samples_per_second': 13.028, 'eval_steps_per_second': 1.824, 'epoch': 36.0}
{'loss': 1.3625, 'learning_rate': 1.3846153846153847e-05, 'epoch': 36.15}
{'loss': 1.4775, 'learning_rate': 1.3076923076923078e-05, 'epoch': 36.92}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.376684188842773, 'eval_score': 0.5634840094424891, 'eval_counts': [69, 2, 0, 0], 'eval_totals': [441, 391, 341, 291], 'eval_precisions': [15.646258503401361, 0.5115089514066496, 0.1466275659824047, 0.0859106529209622], 'eval_bp': 1.0, 'eval_sys_len': 441, 'eval_ref_len': 419, 'eval_runtime': 3.427, 'eval_samples_per_second': 14.59, 'eval_steps_per_second': 2.043, 'epoch': 37.0}
{'loss': 1.449, 'learning_rate': 1.230769230769231e-05, 'epoch': 37.69}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.3497538566589355, 'eval_score': 0.5612505867198834, 'eval_counts': [72, 3, 0, 0], 'eval_totals': [485, 435, 385, 335], 'eval_precisions': [14.845360824742269, 0.6896551724137931, 0.12987012987012986, 0.07462686567164178], 'eval_bp': 1.0, 'eval_sys_len': 485, 'eval_ref_len': 419, 'eval_runtime': 3.798, 'eval_samples_per_second': 13.165, 'eval_steps_per_second': 1.843, 'epoch': 38.0}
{'loss': 1.459, 'learning_rate': 1.153846153846154e-05, 'epoch': 38.46}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.309025764465332, 'eval_score': 0.6173296199143199, 'eval_counts': [74, 3, 0, 0], 'eval_totals': [451, 401, 351, 301], 'eval_precisions': [16.407982261640797, 0.7481296758104738, 0.14245014245014245, 0.08305647840531562], 'eval_bp': 1.0, 'eval_sys_len': 451, 'eval_ref_len': 419, 'eval_runtime': 3.643, 'eval_samples_per_second': 13.725, 'eval_steps_per_second': 1.921, 'epoch': 39.0}
{'loss': 1.4905, 'learning_rate': 1.0769230769230771e-05, 'epoch': 39.23}
{'loss': 1.3818, 'learning_rate': 1e-05, 'epoch': 40.0}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.292601108551025, 'eval_score': 1.0042037941403148, 'eval_counts': [76, 6, 1, 0], 'eval_totals': [466, 416, 366, 316], 'eval_precisions': [16.30901287553648, 1.4423076923076923, 0.273224043715847, 0.15822784810126583], 'eval_bp': 1.0, 'eval_sys_len': 466, 'eval_ref_len': 419, 'eval_runtime': 3.893, 'eval_samples_per_second': 12.844, 'eval_steps_per_second': 1.798, 'epoch': 40.0}
{'loss': 1.3986, 'learning_rate': 9.230769230769232e-06, 'epoch': 40.77}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.379519462585449, 'eval_score': 1.059000280434136, 'eval_counts': [75, 6, 1, 0], 'eval_totals': [445, 395, 345, 295], 'eval_precisions': [16.853932584269664, 1.518987341772152, 0.2898550724637681, 0.1694915254237288], 'eval_bp': 1.0, 'eval_sys_len': 445, 'eval_ref_len': 419, 'eval_runtime': 3.665, 'eval_samples_per_second': 13.642, 'eval_steps_per_second': 1.91, 'epoch': 41.0}
{'loss': 1.3688, 'learning_rate': 8.461538461538462e-06, 'epoch': 41.54}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.3860297203063965, 'eval_score': 0.6629148763833748, 'eval_counts': [73, 4, 0, 0], 'eval_totals': [450, 400, 350, 300], 'eval_precisions': [16.22222222222222, 1.0, 0.14285714285714285, 0.08333333333333333], 'eval_bp': 1.0, 'eval_sys_len': 450, 'eval_ref_len': 419, 'eval_runtime': 3.6515, 'eval_samples_per_second': 13.693, 'eval_steps_per_second': 1.917, 'epoch': 42.0}
{'loss': 1.3938, 'learning_rate': 7.692307692307694e-06, 'epoch': 42.31}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.411128044128418, 'eval_score': 0.9812830089300275, 'eval_counts': [74, 5, 1, 0], 'eval_totals': [455, 405, 355, 305], 'eval_precisions': [16.263736263736263, 1.2345679012345678, 0.28169014084507044, 0.16393442622950818], 'eval_bp': 1.0, 'eval_sys_len': 455, 'eval_ref_len': 419, 'eval_runtime': 3.769, 'eval_samples_per_second': 13.266, 'eval_steps_per_second': 1.857, 'epoch': 43.0}
{'loss': 1.3794, 'learning_rate': 6.923076923076923e-06, 'epoch': 43.08}
{'loss': 1.2948, 'learning_rate': 6.153846153846155e-06, 'epoch': 43.85}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.445441246032715, 'eval_score': 0.965148048755409, 'eval_counts': [70, 5, 1, 0], 'eval_totals': [456, 406, 356, 306], 'eval_precisions': [15.350877192982455, 1.2315270935960592, 0.2808988764044944, 0.16339869281045752], 'eval_bp': 1.0, 'eval_sys_len': 456, 'eval_ref_len': 419, 'eval_runtime': 3.543, 'eval_samples_per_second': 14.112, 'eval_steps_per_second': 1.976, 'epoch': 44.0}
{'loss': 1.3925, 'learning_rate': 5.3846153846153855e-06, 'epoch': 44.62}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.460047721862793, 'eval_score': 0.6291620557007614, 'eval_counts': [71, 4, 0, 0], 'eval_totals': [467, 417, 367, 317], 'eval_precisions': [15.203426124197001, 0.9592326139088729, 0.1362397820163488, 0.07886435331230283], 'eval_bp': 1.0, 'eval_sys_len': 467, 'eval_ref_len': 419, 'eval_runtime': 3.88, 'eval_samples_per_second': 12.886, 'eval_steps_per_second': 1.804, 'epoch': 45.0}
{'loss': 1.2882, 'learning_rate': 4.615384615384616e-06, 'epoch': 45.38}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.439296722412109, 'eval_score': 0.9813310545667665, 'eval_counts': [73, 6, 1, 0], 'eval_totals': [471, 421, 371, 321], 'eval_precisions': [15.498938428874734, 1.4251781472684086, 0.2695417789757412, 0.1557632398753894], 'eval_bp': 1.0, 'eval_sys_len': 471, 'eval_ref_len': 419, 'eval_runtime': 3.902, 'eval_samples_per_second': 12.814, 'eval_steps_per_second': 1.794, 'epoch': 46.0}
{'loss': 1.3109, 'learning_rate': 3.846153846153847e-06, 'epoch': 46.15}
{'loss': 1.3771, 'learning_rate': 3.0769230769230774e-06, 'epoch': 46.92}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.44369649887085, 'eval_score': 0.6357869170559859, 'eval_counts': [71, 4, 0, 0], 'eval_totals': [463, 413, 363, 313], 'eval_precisions': [15.334773218142548, 0.9685230024213075, 0.13774104683195593, 0.07987220447284345], 'eval_bp': 1.0, 'eval_sys_len': 463, 'eval_ref_len': 419, 'eval_runtime': 3.793, 'eval_samples_per_second': 13.182, 'eval_steps_per_second': 1.846, 'epoch': 47.0}
{'loss': 1.2329, 'learning_rate': 2.307692307692308e-06, 'epoch': 47.69}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.443950653076172, 'eval_score': 0.6301119029236261, 'eval_counts': [76, 4, 0, 0], 'eval_totals': [473, 423, 373, 323], 'eval_precisions': [16.067653276955603, 0.9456264775413712, 0.13404825737265416, 0.07739938080495357], 'eval_bp': 1.0, 'eval_sys_len': 473, 'eval_ref_len': 419, 'eval_runtime': 3.858, 'eval_samples_per_second': 12.96, 'eval_steps_per_second': 1.814, 'epoch': 48.0}
{'loss': 1.3539, 'learning_rate': 1.5384615384615387e-06, 'epoch': 48.46}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.447695732116699, 'eval_score': 1.0001238489314928, 'eval_counts': [74, 5, 1, 0], 'eval_totals': [448, 398, 348, 298], 'eval_precisions': [16.517857142857142, 1.256281407035176, 0.28735632183908044, 0.16778523489932887], 'eval_bp': 1.0, 'eval_sys_len': 448, 'eval_ref_len': 419, 'eval_runtime': 3.656, 'eval_samples_per_second': 13.676, 'eval_steps_per_second': 1.915, 'epoch': 49.0}
{'loss': 1.3264, 'learning_rate': 7.692307692307694e-07, 'epoch': 49.23}
{'loss': 1.2709, 'learning_rate': 0.0, 'epoch': 50.0}


  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 5.450464248657227, 'eval_score': 0.5844467086477805, 'eval_counts': [75, 3, 0, 0], 'eval_totals': [473, 423, 373, 323], 'eval_precisions': [15.856236786469344, 0.7092198581560284, 0.13404825737265416, 0.07739938080495357], 'eval_bp': 1.0, 'eval_sys_len': 473, 'eval_ref_len': 419, 'eval_runtime': 3.808, 'eval_samples_per_second': 13.13, 'eval_steps_per_second': 1.838, 'epoch': 50.0}
{'train_runtime': 261.0817, 'train_samples_per_second': 19.151, 'train_steps_per_second': 2.49, 'train_loss': 1.9977916130652795, 'epoch': 50.0}


TrainOutput(global_step=650, training_loss=1.9977916130652795, metrics={'train_runtime': 261.0817, 'train_samples_per_second': 19.151, 'train_steps_per_second': 2.49, 'train_loss': 1.9977916130652795, 'epoch': 50.0})

In [12]:
model.cuda()
model.eval()
train_out = trainer.predict(train_data)
valid_out = trainer.predict(valid_data)

print("Train:", compute_metrics((train_out.predictions, train_data["labels"])))
print("Valid:", compute_metrics((valid_out.predictions, valid_data["labels"])))

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Train: {'score': 9.366630959325185, 'counts': [215, 67, 36, 24], 'totals': [794, 694, 594, 494], 'precisions': [27.07808564231738, 9.654178674351584, 6.0606060606060606, 4.8582995951417], 'bp': 1.0, 'sys_len': 794, 'ref_len': 740}
Valid: {'score': 0.5805108167112687, 'counts': [73, 3, 0, 0], 'totals': [473, 423, 373, 323], 'precisions': [15.433403805496829, 0.7092198581560284, 0.13404825737265416, 0.07739938080495357], 'bp': 1.0, 'sys_len': 473, 'ref_len': 409}


In [13]:
train_decode = decoder_tokenizer.batch_decode(train_out.predictions, skip_special_tokens=True)
valid_decode = decoder_tokenizer.batch_decode(valid_out.predictions, skip_special_tokens=True)

In [14]:
def print_pairs(dataset, generation, sample=5):
    assert len(dataset) == len(generation), "Invalid combination!"

    sample_ids = random.sample(range(len(dataset)), sample)
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {dataset['source'][sid]}\n"
            f"\tTarget:    {dataset['target'][sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )
    return

print_pairs(train_data, train_decode, sample=3)

Sentence #0 [id=81]
	Original:  花を入れるものには何本の花が入っていますか。
	Target:    how many flowers are there in the vase ?
	Generated: the book is full of historical facts.

Sentence #1 [id=14]
	Original:  私たちはローマで楽しく過ごしてます。
	Target:    we are having a nice time in rome .
	Generated: i like coffee.

Sentence #2 [id=3]
	Original:  彼女はとても美しい。その上、とても賢い。
	Target:    she is very beautiful , and what is more , very wise .
	Generated: i am very young.

