In [1]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from tokenizers import processors
import evaluate

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decorders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [2]:
source_lng = "ja"

if source_lng == "en":
    target_lng = "ja"
    encoder = "bert-base-uncased"
    decoder = "rinna/japanese-gpt2-small"
else: 
    target_lng = "en"
    encoder = "cl-tohoku/bert-base-japanese-v3"
    decoder = "gpt2"

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, encoder_add_pooling_layer=False
)
model.cuda();

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.11.ln_cross_attn.bias', 'h.8.crossattention.c_proj.weight', 'h.5.crossattention.q_attn.weight', 'h.2.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.bias', 'h.3.crossattention.c_proj.weight', 'h.4.ln_cross_attn.weight', 'h.3.crossattention.c_proj.bias', 'h.2.ln_cross_attn.weight', 'h.11.crossattention.c_proj.weight', 'h.10.crossattention.c_proj.weight', 'h.2.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.crossattention.c_attn.weight', 'h.10.ln_cross_attn.weight', 'h.3.ln_cross_attn.bias', 'h.5.ln_cross_attn.weight', 'h.8.crossattention.c_attn.bias', 'h.4.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.bias', 'h.10.crossattention.c_proj.bias', 'h.1.crossattention.c_attn.weight', 'h.10.crossattention.q_attn.bias', 'h.7.crossattention.c_proj.weight', 'h.11.ln_cross_attn.weight', 'h.6.crossattention.c_proj.bias', 'h.0.crossattention.

In [3]:
def print_model_parameters():
    t_pars, t_bytes = 0, 0
    for p in model.parameters():
        t_pars += p.nelement()
        t_bytes += p.nelement() * p.element_size()

    c_attn_pars, c_attn_bytes = 0, 0
    for layer in model.decoder.transformer.h:
        for p in layer.crossattention.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()
        for p in layer.ln_cross_attn.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()

    print(f"Total number of parameters: {t_pars:12,} ({(t_bytes / 1024**2):7,.1f}MB)")
    print(f"Cross-attention parameters: {c_attn_pars:12,} ({(c_attn_bytes / 1024**2):7,.1f}MB)")

print_model_parameters()

Total number of parameters:  263,423,232 (1,004.9MB)
Cross-attention parameters:   28,366,848 (  108.2MB)


In [4]:
def set_cross_attention_only(model):
    for p in model.parameters():
        p.requires_grad = False
    for layer in model.decoder.transformer.h:
        for p in layer.crossattention.parameters():
            p.requires_grad = True
        for p in layer.ln_cross_attn.parameters():
            p.requires_grad = True
# set_cross_attention_only(model)

In [5]:
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = decoder_tokenizer.eos_token_id

model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.pad_token_id = decoder_tokenizer.eos_token_id

# add EOS token at the end of each sentence
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single="$A " + decoder_tokenizer.eos_token,
    special_tokens=[(decoder_tokenizer.eos_token, decoder_tokenizer.eos_token_id)],
)

In [6]:
from utils.dataset import EnJaDatasetMaker
from transformers import DataCollatorForSeq2Seq

dataset = EnJaDatasetMaker.load_dataset("ja-en-test-1")
train_data = dataset.select(range(100))
valid_data = dataset.select(range(100, 150))

data_collator = DataCollatorForSeq2Seq(encoder_tokenizer, model=model)

In [7]:
metric = evaluate.load("sacrebleu")

def compute_metrics(preds):
    preds_ids, labels_ids = preds

    labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
    references = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    references = [[reference] for reference in references]

    predictions = decoder_tokenizer.batch_decode(preds_ids, skip_special_tokens=True)

    if target_lng == "ja":
        bleu_output = metric.compute(
            references=references, 
            predictions=predictions, 
            tokenize="ja-mecab"
        )
    else:
        bleu_output = metric.compute(
            references=references, 
            predictions=predictions
        )
    return bleu_output


In [8]:
MAX_LENGHT = 128
def set_decoder_configuration(gc: GenerationConfig):
    gc.no_repeat_ngram_size = 3
    gc.length_penalty = 2.0
    gc.num_beams = 3
    #gen_config.max_new_tokens = MAX_LENGHT
    gc.max_length = MAX_LENGHT * 2
    gc.min_length = 0
    gc.early_stopping = True
    gc.pad_token_id = decoder_tokenizer.eos_token_id
    gc.bos_token_id = decoder_tokenizer.bos_token_id
    gc.eos_token_id = decoder_tokenizer.eos_token_id
    return gc

gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [9]:
train_args = Seq2SeqTrainingArguments(
    report_to="wandb",
    run_name="testing-data-maker-1",
    num_train_epochs=10,

    logging_strategy="steps",
    logging_steps=10,

    evaluation_strategy="epoch",

    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=4,

    optim="adamw_torch",
    bf16=True,

    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    
    group_by_length=True,
    length_column_name="length",

    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_config=gen_config,
    # torch_compile=True,
    # label_smoothing_factor=0,
    # auto_find_batch_size=False,
)

In [10]:
trainer = Seq2SeqTrainer(
    model, 
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_data, 
    eval_dataset=valid_data, 
    compute_metrics=compute_metrics
)

In [11]:
model.train()
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavidboening[0m ([33mdandd[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/650 [00:00<?, ?it/s]



{'loss': 4.1641, 'learning_rate': 4.923076923076924e-05, 'epoch': 0.77}
{'loss': 3.0033, 'learning_rate': 4.846153846153846e-05, 'epoch': 1.54}
{'loss': 2.3679, 'learning_rate': 4.76923076923077e-05, 'epoch': 2.31}
{'loss': 2.0273, 'learning_rate': 4.692307692307693e-05, 'epoch': 3.08}
{'loss': 1.626, 'learning_rate': 4.615384615384616e-05, 'epoch': 3.85}
{'loss': 1.3038, 'learning_rate': 4.538461538461539e-05, 'epoch': 4.62}
{'loss': 1.183, 'learning_rate': 4.461538461538462e-05, 'epoch': 5.38}
{'loss': 0.9856, 'learning_rate': 4.384615384615385e-05, 'epoch': 6.15}
{'loss': 0.8227, 'learning_rate': 4.3076923076923084e-05, 'epoch': 6.92}
{'loss': 0.7118, 'learning_rate': 4.230769230769231e-05, 'epoch': 7.69}
{'loss': 0.6145, 'learning_rate': 4.1538461538461544e-05, 'epoch': 8.46}
{'loss': 0.6536, 'learning_rate': 4.0769230769230773e-05, 'epoch': 9.23}
{'loss': 0.531, 'learning_rate': 4e-05, 'epoch': 10.0}
{'loss': 0.415, 'learning_rate': 3.923076923076923e-05, 'epoch': 10.77}
{'loss': 

TrainOutput(global_step=650, training_loss=0.44654285747271316, metrics={'train_runtime': 242.7068, 'train_samples_per_second': 20.601, 'train_steps_per_second': 2.678, 'train_loss': 0.44654285747271316, 'epoch': 50.0})

In [12]:
model.cuda()
model.eval()
train_out = trainer.predict(train_data)
valid_out = trainer.predict(valid_data)

print("Train:", compute_metrics((train_out.predictions, train_data["labels"])))
print("Valid:", compute_metrics((valid_out.predictions, valid_data["labels"])))

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Train: {'score': 98.90981656809547, 'counts': [740, 641, 542, 443], 'totals': [748, 648, 548, 448], 'precisions': [98.93048128342247, 98.91975308641975, 98.9051094890511, 98.88392857142857], 'bp': 1.0, 'sys_len': 748, 'ref_len': 740}
Valid: {'score': 1.6850151200126198, 'counts': [62, 7, 2, 1], 'totals': [372, 322, 272, 222], 'precisions': [16.666666666666668, 2.1739130434782608, 0.7352941176470589, 0.45045045045045046], 'bp': 0.905324020561496, 'sys_len': 372, 'ref_len': 409}


In [13]:
train_decode = decoder_tokenizer.batch_decode(train_out.predictions, skip_special_tokens=True)
valid_decode = decoder_tokenizer.batch_decode(valid_out.predictions, skip_special_tokens=True)

In [14]:
def print_pairs(dataset, generation, sample=5):
    assert len(dataset) == len(generation), "Invalid combination!"

    sample_ids = random.sample(range(len(dataset)), sample)
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {dataset['source'][sid]}\n"
            f"\tTarget:    {dataset['target'][sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )
    return

print_pairs(train_data, train_decode, sample=3)

Sentence #0 [id=81]
	Original:  花を入れるものには何本の花が入っていますか。
	Target:    how many flowers are there in the vase ?
	Generated: how many flowers are there in the vase?

Sentence #1 [id=14]
	Original:  私たちはローマで楽しく過ごしてます。
	Target:    we are having a nice time in rome .
	Generated: we are having a nice time in rome.

Sentence #2 [id=3]
	Original:  彼女はとても美しい。その上、とても賢い。
	Target:    she is very beautiful , and what is more , very wise .
	Generated: she is very beautiful, and what is more, very wise.

