In [None]:
import os
import random

# Set before importing the Hugging Face libraries so their cache is placed under ./.cache.
os.environ["HF_HOME"] = r"./.cache"

import evaluate
import numpy as np
from tokenizers import processors
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EncoderDecoderModel,
    GenerationConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decoders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`, `rinna/japanese-gpt2-small`
    - GPT_EN : `gpt2`

In [None]:
# Translation direction: the source language selects the pretrained encoder (BERT),
# the target language selects the pretrained decoder (GPT-2).
SOURCE_LANG = "en"
RESUME = False
SEED = 42  # seeds model construction so the newly added cross-attention init is reproducible

# source language -> (target language, encoder checkpoint, decoder checkpoint)
_LANG_PAIRS = {
    "en": ("ja", "bert-base-uncased", "rinna/japanese-gpt2-small"),
    "ja": ("en", "cl-tohoku/bert-base-japanese-v3", "gpt2"),
}
if SOURCE_LANG not in _LANG_PAIRS:
    raise ValueError(
        f"Unsupported SOURCE_LANG {SOURCE_LANG!r}; expected one of {sorted(_LANG_PAIRS)}"
    )
TARGET_LANG, encoder, decoder = _LANG_PAIRS[SOURCE_LANG]

# from_encoder_decoder_pretrained adds decoder cross-attention layers that the GPT-2
# checkpoints do not contain; they start from a random init, so seed first.
set_seed(SEED)
model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder, decoder)
model.cuda();

In [None]:
def print_model_parameters(model=None):
    """Print the total parameter count/size and the decoder cross-attention share.

    Args:
        model: Encoder-decoder model to inspect. Defaults to the notebook-level
            ``model`` so the original zero-argument call keeps working. Its decoder
            must expose blocks at ``model.decoder.transformer.h``, each with
            ``crossattention`` and ``ln_cross_attn`` submodules.

    Sizes are reported in MiB from each tensor's element count and element size.
    """
    if model is None:
        # The parameter shadows the global name, so look it up explicitly.
        model = globals()["model"]

    def _count(params):
        # (number of elements, number of bytes) summed over an iterable of tensors.
        n_elems = n_bytes = 0
        for p in params:
            n_elems += p.nelement()
            n_bytes += p.nelement() * p.element_size()
        return n_elems, n_bytes

    t_pars, t_bytes = _count(model.parameters())
    c_attn_pars, c_attn_bytes = _count(
        p
        for layer in model.decoder.transformer.h
        for sub in (layer.crossattention, layer.ln_cross_attn)
        for p in sub.parameters()
    )

    print(f"Total number of parameters: {t_pars:12,} ({(t_bytes / 1024**2):7,.1f}MB)")
    print(f"Cross-attention parameters: {c_attn_pars:12,} ({(c_attn_bytes / 1024**2):7,.1f}MB)")

# Report the model's total size and how much of it the decoder cross-attention accounts for.
print_model_parameters()

In [None]:
def set_cross_attention_only(model):
    """Freeze every parameter of ``model`` except the decoder cross-attention.

    Only the ``crossattention`` and ``ln_cross_attn`` submodules of each block in
    ``model.decoder.transformer.h`` remain trainable. ``model`` is modified in
    place; nothing is returned.
    """
    model.requires_grad_(False)
    for block in model.decoder.transformer.h:
        block.crossattention.requires_grad_(True)
        block.ln_cross_attn.requires_grad_(True)
# Freeze everything except the decoder's cross-attention blocks before training.
set_cross_attention_only(model)

In [None]:
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)

# The EOS post-processor below rewires the Rust-backed `_tokenizer`, which only fast
# tokenizers have; `use_fast=True` falls back to a slow tokenizer when none exists.
if not decoder_tokenizer.is_fast:
    raise TypeError(
        f"{decoder!r} did not load a fast tokenizer; the EOS post-processor requires one."
    )

# GPT-2 vocabularies may ship without a pad token; reuse EOS in that case.
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = decoder_tokenizer.eos_token_id

# Decoder special tokens. EOS doubles as the model-level pad id, consistent with the
# generation config and with the -100 -> EOS substitution used when scoring BLEU.
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.pad_token_id = decoder_tokenizer.eos_token_id

# Add EOS token at the end of each sentence so the decoder learns where to stop.
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single="$A " + decoder_tokenizer.eos_token,
    special_tokens=[(decoder_tokenizer.eos_token, decoder_tokenizer.eos_token_id)],
)

In [None]:
# Display the encoder tokenizer's repr as a quick sanity check of its configuration.
encoder_tokenizer

In [None]:
from utils.dataset import EnJaDatasetMaker, EnJaDatasetSample, \
    OPUS100, JESC, MassiveTranslation, SnowSimplified, Tatoeba, IWSLT2017
from utils.dataset.dataset_base import EnJaDataset

def get_csv_path(cls):
    """Return the processed-CSV path for a dataset class.

    Args:
        cls: A subclass of ``EnJaDataset`` exposing an ``OUT_NAME`` attribute.

    Returns:
        ``"<EnJaDataset.DATASET_PROCESSED_DIR>/<cls.OUT_NAME>"``.

    Raises:
        TypeError: if ``cls`` is not a class deriving from ``EnJaDataset``.
    """
    # An explicit check (unlike `assert`) is not stripped under `python -O`, and it
    # rejects non-class arguments with the same error instead of an issubclass() crash.
    if not (isinstance(cls, type) and issubclass(cls, EnJaDataset)):
        raise TypeError(f"Invalid class passed: expected an EnJaDataset subclass, got {cls!r}")
    return f"{EnJaDataset.DATASET_PROCESSED_DIR}/{cls.OUT_NAME}"

# Corpus mix: (dataset class, number of sampled pairs). Every corpus uses ntokens=(0, 128),
# presumably a token-length window for kept pairs — confirm in EnJaDatasetMaker.
dataset = EnJaDatasetMaker.prepare_dataset(
    f"BERT-GPT2-{SOURCE_LANG}-{TARGET_LANG}",
    [
        EnJaDatasetSample(dataset=get_csv_path(corpus), nsample=n_pairs, ntokens=(0, 128))
        for corpus, n_pairs in (
            (OPUS100, 50_000),
            (JESC, 150_000),
            (MassiveTranslation, 20_000),
            (SnowSimplified, 30_000),
            (Tatoeba, 125_000),
            (IWSLT2017, 175_000),
        )
    ],
    source_language=SOURCE_LANG,
    model_type="BERT-GPT2",
    encoder_tokenizer=encoder_tokenizer,
    decoder_tokenizer=decoder_tokenizer,
    num_proc=8,
    seed=123,
    # Split weights, rescaled to sum to 1. NOTE(review): if the order is
    # (train, valid, test), the third split receives about half of the data — confirm.
    splits=(1, 0.002, 1),
)

# Drop the raw text columns from the train/valid splits.
train_data, valid_data = (
    dataset[split].remove_columns(["source", "target"]) for split in ("train", "valid")
)

# Encoder inputs are padded with the encoder tokenizer; the model is passed so the
# collator can derive decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=encoder_tokenizer, model=model)

In [None]:
metric = evaluate.load("sacrebleu")

# sacrebleu's default tokenization targets space-delimited text; Japanese output is
# segmented with MeCab instead. Everything else is shared between directions.
_BLEU_KWARGS = {"tokenize": "ja-mecab"} if TARGET_LANG == "ja" else {}


def _ids_to_text(ids):
    """Decode a batch of token ids after mapping the -100 ignore/pad index to EOS.

    Labels are padded with -100, and predictions can be too when the Trainer
    concatenates generated batches of different lengths; the tokenizer cannot decode
    that id. EOS is a special token, so ``skip_special_tokens=True`` strips it again.
    Uses ``np.where`` so the caller's arrays are not mutated.
    """
    ids = np.where(ids != -100, ids, decoder_tokenizer.eos_token_id)
    return decoder_tokenizer.batch_decode(ids, skip_special_tokens=True)


def compute_metrics(preds):
    """Score generated translations against the reference labels with sacreBLEU.

    Args:
        preds: ``(predictions, label_ids)`` pair from the Trainer; predictions are
            generated token ids because ``predict_with_generate=True``.

    Returns:
        The dict returned by the sacrebleu metric; its ``"score"`` entry is the value
        selected via ``metric_for_best_model="eval_score"``.
    """
    preds_ids, labels_ids = preds
    if isinstance(preds_ids, tuple):  # some models return (ids, extra outputs)
        preds_ids = preds_ids[0]

    predictions = _ids_to_text(preds_ids)
    references = [[reference] for reference in _ids_to_text(labels_ids)]
    return metric.compute(predictions=predictions, references=references, **_BLEU_KWARGS)

In [None]:
def set_decoder_configuration(gc: "GenerationConfig", tokenizer=None):
    """Fill ``gc`` with the beam-search settings and decoder special tokens for evaluation.

    Args:
        gc: Generation config to modify in place.
        tokenizer: Tokenizer providing the BOS/EOS ids. Defaults to the notebook-level
            ``decoder_tokenizer`` so the original one-argument call keeps working.

    Returns:
        ``gc`` itself, for chaining.
    """
    if tokenizer is None:
        tokenizer = decoder_tokenizer

    # Decoding strategy.
    gc.num_beams = 3
    gc.early_stopping = True
    gc.length_penalty = 2.0
    gc.no_repeat_ngram_size = 4
    gc.min_length = 0
    gc.max_length = 256

    # Special tokens. Pad token is set to EOS since GPT-2 has no pad token.
    gc.pad_token_id = tokenizer.eos_token_id
    gc.bos_token_id = tokenizer.bos_token_id
    gc.eos_token_id = tokenizer.eos_token_id
    # Start decoding from BOS explicitly (as model.config.decoder_start_token_id does)
    # rather than relying on generate() falling back to bos_token_id.
    gc.decoder_start_token_id = tokenizer.bos_token_id
    return gc

# Evaluation-time generation settings, built on top of a fresh default config.
gen_config = set_decoder_configuration(GenerationConfig())

In [None]:
# One name for the run, reused for the checkpoint directory.
RUN_NAME = f"BERT-GPT2-{SOURCE_LANG}-{TARGET_LANG}-xattn"
# Shared evaluation/checkpoint interval: load_best_model_at_end requires save_steps to be
# a round multiple of eval_steps, so both use this single value.
EVAL_SAVE_STEPS = 2500  # * 20_000, 10_000, 5_000

train_args = Seq2SeqTrainingArguments(
    report_to="none",
    run_name=RUN_NAME,
    num_train_epochs=3,

    # Logging
    logging_strategy="steps",
    logging_steps=1,  # * 4, 2, 1

    # Evaluation: generate translations with gen_config so compute_metrics can score BLEU.
    # NOTE(review): newer transformers releases name this argument `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=EVAL_SAVE_STEPS,
    prediction_loss_only=False,
    predict_with_generate=True,
    generation_config=gen_config,

    # Checkpointing / model selection: keep the checkpoint with the highest BLEU
    # ("eval_score" is the "score" entry returned by compute_metrics).
    output_dir=f"./.ckp/{RUN_NAME}/",
    save_strategy="steps",
    save_steps=EVAL_SAVE_STEPS,
    save_total_limit=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_score",
    greater_is_better=True,

    # Optimisation
    optim="adamw_torch",
    warmup_steps=400,  # 3500, 1750, 875
    learning_rate=5e-5,  # 3e-5, 5e-5
    bf16=True,  # bf16, qint 8 ???
    label_smoothing_factor=0.2,  # 0.1, 0.2
    # torch_compile=True,

    # Batching: group examples by the dataset's "length" column.
    group_by_length=True,
    length_column_name="length",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # gradient_accumulation_steps=1,  # * 1, 2, 4
    # gradient_checkpointing=True,
    # eval_accumulation_steps=4,  # ???
)

In [None]:
# Wire the model, the pre-tokenized splits, the collator and the BLEU metric together.
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Switch the model to training mode and run fine-tuning of the unfrozen cross-attention.
model.train()
# RESUME (set in the configuration cell) is forwarded as `resume_from_checkpoint`.
trainer.train(resume_from_checkpoint=RESUME)