In [None]:
import os
import random

# Point the Hugging Face cache at a local directory; this is set before the
# Hugging Face libraries below are imported.
os.environ["HF_HOME"] = r"./.cache"

import evaluate
import numpy as np
from tokenizers import processors
from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig, Seq2SeqTrainer, \
    Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decoders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`, `rinna/japanese-gpt2-small`
    - GPT_EN : `gpt2`

In [None]:
# Translation direction: the language of the input sentences.
SOURCE_LANG = "ja"

# (target language, encoder checkpoint, decoder checkpoint) per direction.
_EN_TO_JA = ("ja", "bert-base-uncased", "rinna/japanese-gpt2-small")
_JA_TO_EN = ("en", "cl-tohoku/bert-base-japanese-v3", "gpt2")

# Any source language other than "en" is treated as Japanese -> English.
TARGET_LANG, encoder, decoder = _EN_TO_JA if SOURCE_LANG == "en" else _JA_TO_EN

# Combine the pretrained encoder and decoder checkpoints into one
# encoder-decoder model, built without the encoder's pooling layer.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder,
    decoder,
    encoder_add_pooling_layer=False,
)
# Trailing semicolon keeps the notebook from echoing the module repr.
model.cuda();

In [None]:
def print_model_parameters(net=None):
    """Print the total and cross-attention parameter counts of a model.

    Sizes are derived from each parameter's element count and element size
    and reported in units of 1024**2 bytes (labelled "MB").

    Args:
        net: Model to inspect. It must expose ``parameters()`` and a
            ``decoder.transformer.h`` sequence whose items have
            ``crossattention`` and ``ln_cross_attn`` sub-modules. When
            omitted, the notebook-level ``model`` is used, which keeps the
            original zero-argument call working.

    Returns:
        None. Two summary lines are written to stdout.
    """
    if net is None:
        net = model

    def _tally(params):
        """Return (number of elements, number of bytes) over ``params``."""
        n_elems = n_bytes = 0
        for p in params:
            n = p.nelement()
            n_elems += n
            n_bytes += n * p.element_size()
        return n_elems, n_bytes

    t_pars, t_bytes = _tally(net.parameters())

    # Cross-attention weights plus their layer norm, for every decoder layer.
    cross_attention_params = (
        p
        for layer in net.decoder.transformer.h
        for module in (layer.crossattention, layer.ln_cross_attn)
        for p in module.parameters()
    )
    c_attn_pars, c_attn_bytes = _tally(cross_attention_params)

    print(f"Total number of parameters: {t_pars:12,} ({(t_bytes / 1024**2):7,.1f}MB)")
    print(f"Cross-attention parameters: {c_attn_pars:12,} ({(c_attn_bytes / 1024**2):7,.1f}MB)")

# Report the parameter counts of the model built in the previous cell.
print_model_parameters()

In [None]:
def set_cross_attention_only(model):
    """Leave only the decoder's cross-attention parameters trainable.

    Every parameter of ``model`` is frozen first; then, for each layer in
    ``model.decoder.transformer.h``, the parameters of ``crossattention``
    and ``ln_cross_attn`` are switched back on.

    Args:
        model: Model exposing ``parameters()`` and ``decoder.transformer.h``.

    Returns:
        None. ``model`` is modified in place.
    """
    # Start from a fully frozen network.
    for weight in model.parameters():
        weight.requires_grad = False

    # Re-enable gradients for the cross-attention blocks and their layer norms.
    unfrozen = (
        weight
        for block in model.decoder.transformer.h
        for module in (block.crossattention, block.ln_cross_attn)
        for weight in module.parameters()
    )
    for weight in unfrozen:
        weight.requires_grad = True
# Freeze everything except the decoder's cross-attention parameters.
set_cross_attention_only(model)

In [None]:
# Load the fast tokenizer that matches each half of the model.
encoder_tokenizer, decoder_tokenizer = (
    AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
    for checkpoint in (encoder, decoder)
)

eos_token = decoder_tokenizer.eos_token
eos_id = decoder_tokenizer.eos_token_id

# If the decoder tokenizer has no padding token, reuse EOS for padding.
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = eos_id

# Special-token ids the model config needs; padding is always mapped to EOS here.
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = eos_id
model.config.pad_token_id = eos_id

# Replace the decoder tokenizer's post-processor so that every encoded
# sentence ends with the EOS token.
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single="$A " + eos_token,
    special_tokens=[(eos_token, eos_id)],
)

In [None]:
from utils.dataset import EnJaDatasetMaker

# Load the "ja-en-test-1" dataset and take two small, disjoint slices:
# the first 100 rows for training and the following 50 for validation.
dataset = EnJaDatasetMaker.load_dataset("ja-en-test-1")
train_rows, valid_rows = range(100), range(100, 150)
train_data = dataset.select(train_rows)
valid_data = dataset.select(valid_rows)

# Collator used by the trainer to assemble batches; it is given the encoder
# tokenizer and the model.
data_collator = DataCollatorForSeq2Seq(encoder_tokenizer, model=model)

In [None]:
# Display the first record of the loaded dataset.
dataset[0]

In [None]:
metric = evaluate.load("sacrebleu")

# Extra sacreBLEU arguments: Japanese output is scored with the "ja-mecab"
# tokenizer; any other target language uses the metric's defaults.
_BLEU_KWARGS = {"tokenize": "ja-mecab"} if TARGET_LANG == "ja" else {}

# Marker used for positions to ignore in label arrays. It is not a real
# token id, so it has to be replaced before decoding.
_IGNORE_INDEX = -100


def _decode_token_ids(token_ids):
    """Decode a batch of token-id sequences with the decoder tokenizer.

    Ignore-index entries are replaced by the EOS id, which is then dropped
    by ``skip_special_tokens=True``. The input is never modified.

    Args:
        token_ids: Either a NumPy array of ids or an iterable of id sequences
            (for example a list of lists with differing lengths).

    Returns:
        list[str]: One decoded string per sequence.
    """
    fill_id = decoder_tokenizer.eos_token_id
    if isinstance(token_ids, np.ndarray):
        # np.where builds a new array, leaving the caller's data untouched.
        cleaned = np.where(token_ids != _IGNORE_INDEX, token_ids, fill_id)
    else:
        cleaned = [
            [fill_id if int(tok) == _IGNORE_INDEX else int(tok) for tok in seq]
            for seq in token_ids
        ]
    return decoder_tokenizer.batch_decode(cleaned, skip_special_tokens=True)


def compute_metrics(preds):
    """Compute sacreBLEU for a ``(predictions, labels)`` pair of token ids.

    Args:
        preds: Two-item iterable ``(prediction_ids, label_ids)``.

    Returns:
        The result of ``metric.compute`` for the decoded predictions, with
        each decoded label used as the single reference for its prediction.
    """
    preds_ids, labels_ids = preds
    # If predictions arrive as a tuple, the first element is taken as the ids.
    if isinstance(preds_ids, tuple):
        preds_ids = preds_ids[0]

    # sacreBLEU expects a list of references per prediction.
    references = [[reference] for reference in _decode_token_ids(labels_ids)]
    predictions = _decode_token_ids(preds_ids)

    return metric.compute(
        references=references,
        predictions=predictions,
        **_BLEU_KWARGS,
    )

In [None]:
MAX_LENGHT = 128


def set_decoder_configuration(gc: GenerationConfig):
    """Apply the notebook's generation settings to ``gc`` and return it.

    The overall length limit is set through ``max_length`` (twice
    ``MAX_LENGHT``) rather than ``max_new_tokens``. Padding and EOS both use
    the decoder tokenizer's EOS id; BOS uses its BOS id.

    Args:
        gc: Generation config to update in place.

    Returns:
        The same ``gc`` object.
    """
    eos_id = decoder_tokenizer.eos_token_id
    settings = {
        "no_repeat_ngram_size": 3,
        "length_penalty": 2.0,
        "num_beams": 3,
        "max_length": MAX_LENGHT * 2,
        "min_length": 0,
        "early_stopping": True,
        "pad_token_id": eos_id,
        "bos_token_id": decoder_tokenizer.bos_token_id,
        "eos_token_id": eos_id,
    }
    for attribute, value in settings.items():
        setattr(gc, attribute, value)
    return gc


gen_config = set_decoder_configuration(GenerationConfig())

In [None]:
train_args = Seq2SeqTrainingArguments(
    # --- run bookkeeping ---
    output_dir="./.ckp/",
    run_name="testing-data-maker-2",
    report_to="wandb",

    # --- optimisation ---
    num_train_epochs=50,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    optim="adamw_torch",
    bf16=True,

    # --- batching ---
    # assumes the dataset provides a "length" column -- TODO confirm
    group_by_length=True,
    length_column_name="length",

    # --- logging ---
    logging_strategy="steps",
    logging_steps=10,

    # --- evaluation: once per epoch, scoring generated sequences ---
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="epoch",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_config=gen_config,

    # --- checkpoints ---
    # NOTE(review): with a 100-example training subset, batch size 8 and 50
    # epochs, a 1000-step save interval is probably never reached -- confirm
    # this is intended.
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=4,
)

In [None]:
# Wire the model, data splits, collator and metric function into the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Put the model in training mode, then run the trainer.
model.train()
trainer.train()

In [None]:
# Move the model to the GPU and switch to evaluation mode before predicting.
model.cuda()
model.eval()
train_out = trainer.predict(train_data)
valid_out = trainer.predict(valid_data)

# Score against the label arrays returned by predict() rather than the raw
# dataset column: compute_metrics masks the -100 ignore index with array
# operations, which a plain (and possibly ragged) column does not support.
print("Train:", compute_metrics((train_out.predictions, train_out.label_ids)))
print("Valid:", compute_metrics((valid_out.predictions, valid_out.label_ids)))

In [None]:
# Turn the predicted token ids of both splits back into text.
train_decode, valid_decode = (
    decoder_tokenizer.batch_decode(output.predictions, skip_special_tokens=True)
    for output in (train_out, valid_out)
)

In [None]:
def print_pairs(dataset, generation, sample=5):
    """Print randomly chosen (source, target, generated) triples.

    Args:
        dataset: Object supporting ``len()`` and column access through
            ``dataset['source']`` and ``dataset['target']``.
        generation: Generated strings, aligned index-by-index with ``dataset``.
        sample: Maximum number of examples to print. Values larger than the
            dataset are clamped to its size instead of raising.

    Returns:
        None. Output is written to stdout.

    Raises:
        AssertionError: If ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    # Look the columns up once instead of on every loop iteration.
    sources, targets = dataset['source'], dataset['target']

    # random.sample raises ValueError when asked for more items than exist,
    # so never request more than the dataset holds.
    sample_size = max(0, min(sample, len(dataset)))
    sample_ids = random.sample(range(len(dataset)), sample_size)

    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {sources[sid]}\n"
            f"\tTarget:    {targets[sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )

# Show three random training examples alongside their generated output.
print_pairs(train_data, train_decode, sample=3)