In [None]:
import os

os.environ["HF_HOME"] = r"./.cache"
from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig
from datasets import load_dataset

Candidate pretrained checkpoints:

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decoders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [None]:
# Translation direction: the encoder reads `source_lng`, the decoder writes `target_lng`.
source_lng = "ja"
target_lng = "en"

# Pretrained checkpoint per language for each side of the model.
ENCODER_CHECKPOINTS = {
    "en": "bert-base-uncased",
    "ja": "cl-tohoku/bert-base-japanese-v3",
}
DECODER_CHECKPOINTS = {
    "en": "gpt2",
    "ja": "rinna/japanese-gpt2-xsmall",
}

# Pick each side from its own language; an unsupported code fails loudly instead
# of silently falling back to the Japanese encoder / English decoder.
if source_lng not in ENCODER_CHECKPOINTS:
    raise ValueError(f"no encoder configured for source_lng={source_lng!r}")
if target_lng not in DECODER_CHECKPOINTS:
    raise ValueError(f"no decoder configured for target_lng={target_lng!r}")
encoder = ENCODER_CHECKPOINTS[source_lng]
decoder = DECODER_CHECKPOINTS[target_lng]

# Combine the two pretrained checkpoints into one seq2seq model; the encoder's
# pooling layer is not requested because only the token states are consumed.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, encoder_add_pooling_layer=False
)
model.cuda()  # NOTE: assumes a CUDA device is available

encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)
# Borrow id 0 as the decoder padding id. For `gpt2`, id 0 is the ordinary token
# "!", so padding positions must not be identified by comparing against
# pad_token_id alone.
decoder_tokenizer.pad_token_id = 0

In [None]:
# Show the full module tree, including the decoder's cross-attention layers.
print(model)

In [None]:
# Inspect the composite config; the special-token ids are only set further down, before training.
model.config

In [None]:
def print_model_parameters(target_model=None):
    """Print the total parameter count and the decoder's cross-attention parameter count.

    Args:
        target_model: model to inspect; defaults to the notebook-global ``model``.
            It must expose GPT-2 style decoder blocks at ``decoder.transformer.h``.
    """
    if target_model is None:
        target_model = model

    total = sum(p.numel() for p in target_model.parameters())
    print(f"Number of parameters: {total}")

    # Each decoder block may carry a cross-attention module plus its layer norm;
    # blocks without them contribute nothing instead of raising AttributeError.
    cross_attn_params = 0
    for block in target_model.decoder.transformer.h:
        for sub_name in ("crossattention", "ln_cross_attn"):
            sub_module = getattr(block, sub_name, None)
            if sub_module is not None:
                cross_attn_params += sum(p.numel() for p in sub_module.parameters())

    print(f"Number of cross-attention parameters: {cross_attn_params}")


# Report parameter counts for the freshly assembled model.
print_model_parameters()

In [None]:
def print_model_size(target_model=None):
    """Print the in-memory size of a model's parameters plus buffers.

    Args:
        target_model: module to measure; defaults to the notebook-global ``model``.

    The byte total is divided by 1024**2, so the printed figure is in MiB even
    though the label reads "MB".
    """
    if target_model is None:
        target_model = model

    param_bytes = sum(p.nelement() * p.element_size() for p in target_model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in target_model.buffers())

    size_all_mb = (param_bytes + buffer_bytes) / 1024**2
    print("model size: {:.1f}MB".format(size_all_mb))


# Report in-memory size (parameters + buffers) of the assembled model.
print_model_size()

In [None]:
# Load the CSV and keep only a tiny subset for a quick end-to-end check.
DATA_FILE = r"./data-csv/snow_simplified.csv"
N_SAMPLES = 4  # set to None to use the whole train split

dataset = load_dataset("csv", data_files=DATA_FILE)
train_split = dataset["train"]
# Clamp so a file with fewer than N_SAMPLES rows does not make select() raise.
n_rows = len(train_split) if N_SAMPLES is None else min(N_SAMPLES, len(train_split))
data_sample = train_split.select(range(n_rows))

In [None]:
# Dataset summary: features and number of rows.
data_sample

In [None]:
# Underlying table, to eyeball the raw sentence pairs.
data_sample.data

In [None]:
# Make the decoder tokenizer append EOS to every single-sequence encoding, so
# each target sequence carries an explicit end marker.
# TODO: verify with the Japanese decoder tokenizer.
# TODO: verify on samples containing multiple sentences.

from tokenizers import processors

# `_tokenizer.post_processor` only exists on fast (Rust-backed) tokenizers.
if not decoder_tokenizer.is_fast:
    raise TypeError(
        f"{type(decoder_tokenizer).__name__} is not a fast tokenizer; "
        "cannot install an EOS post-processor"
    )
if decoder_tokenizer.eos_token is None:
    raise ValueError("decoder tokenizer has no eos_token to append")

eos = decoder_tokenizer.eos_token
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single=f"$A {eos}",
    special_tokens=[(eos, decoder_tokenizer.eos_token_id)],
)

In [None]:
def preprocess_data(batch, max_length=128):
    """Tokenise a batch of sentence pairs for encoder-decoder training.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names
    to lists of strings. Reads the ``f"{source_lng}_sentence"`` and
    ``f"{target_lng}_sentence"`` columns and uses the notebook-global
    ``encoder_tokenizer`` / ``decoder_tokenizer``.

    Args:
        batch: mapping holding the source and target sentence columns.
        max_length: pad/truncate length for both sides (default 128).

    Returns:
        ``batch`` with ``input_ids`` and ``attention_mask`` (encoder side) and
        ``labels`` (decoder ids, with padding positions set to -100 so the loss
        ignores them) added.
    """
    inputs = encoder_tokenizer(
        batch[f"{source_lng}_sentence"],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    targets = decoder_tokenizer(
        batch[f"{target_lng}_sentence"],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Identify padding through the tokenizer's attention mask, not by comparing
    # ids with pad_token_id. The pad id can also be a real token (id 0 is "!" in
    # GPT-2) or equal EOS, and masking those would hide them from the loss.
    labels = targets.input_ids.masked_fill(targets.attention_mask == 0, -100)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["labels"] = labels
    return batch

In [None]:
# Tokenise both languages in batches and drop the raw sentence columns.
train_data = data_sample.map(
    preprocess_data,
    batched=True,
    remove_columns=["en_sentence", "ja_sentence"],
)

In [None]:
# Expose only the model inputs, returned as torch tensors when indexed.
train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
# Inspect one encoded target; padding positions should show up as -100.
train_data["labels"][3]

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
# Small smoke-test configuration: log every step, generate when predicting.
train_args = Seq2SeqTrainingArguments(
    output_dir="./",
    num_train_epochs=10,
    logging_steps=1,
    predict_with_generate=True,
)

In [None]:
# The trainer holds the same `model` object, so config changes made to `model`
# after this cell are still visible when training runs.
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=train_data,
)

In [None]:
# Special-token ids the encoder-decoder model needs for training and generation;
# decoding starts from the decoder tokenizer's BOS token.
model.config.pad_token_id = decoder_tokenizer.pad_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id

In [None]:
# Train on the small sample (a quick end-to-end check rather than a real run).
model.train()
trainer.train()

In [None]:
# Start from library defaults; decoding settings are filled in by the helper below.
gen_config = GenerationConfig()

def set_decoder_configuration(gen_config):
    """Fill ``gen_config`` with the beam-search settings used for translation.

    Mutates and returns the same object. Token ids are taken from the
    notebook-global ``decoder_tokenizer``.
    """
    # Beam search with 4 beams, no repeated 3-grams, and length_penalty > 0
    # favouring longer hypotheses.
    gen_config.no_repeat_ngram_size = 3
    gen_config.length_penalty = 2.0
    gen_config.num_beams = 4
    gen_config.max_new_tokens = 128
    gen_config.early_stopping = True
    # Finished sequences are padded with EOS rather than the training pad id (0),
    # which for GPT-2 is the visible token "!".
    gen_config.pad_token_id = decoder_tokenizer.eos_token_id
    gen_config.bos_token_id = decoder_tokenizer.bos_token_id
    gen_config.eos_token_id = decoder_tokenizer.eos_token_id
    # Begin decoding from the same start token training used (BOS of the decoder
    # tokenizer); a fresh GenerationConfig leaves this unset.
    gen_config.decoder_start_token_id = decoder_tokenizer.bos_token_id
    return gen_config


# Apply the decoding settings (the helper mutates and returns the same object).
gen_config = set_decoder_configuration(gen_config)

In [None]:
# Translate the training sample with the configured beam-search settings.
model.cuda()
model.eval()
source_ids = train_data["input_ids"].cuda()
source_mask = train_data["attention_mask"].cuda()
output = model.generate(
    source_ids,
    attention_mask=source_mask,
    generation_config=gen_config,
)

In [None]:
# `output` is a single 2-D tensor, so every generated row reports the same length.
for generated in output:
    print(generated.size())

In [None]:
# Raw generated token ids.
output

In [None]:
# Decode each generated sequence; special tokens are kept, presumably to check
# whether EOS was produced.
clean_output = [
    decoder_tokenizer.decode(sequence, skip_special_tokens=False)
    for sequence in output
]

In [None]:
# Decoded generations (special tokens kept, so EOS/padding stay visible).
clean_output

In [None]:
# English sentences of the sample: the references when target_lng == "en".
data_sample["en_sentence"]