In [None]:
import torch
from torch.nn import Module
from transformers import BertConfig, GPT2Config, EncoderDecoderConfig, EncoderDecoderModel, GPTNeoXJapaneseConfig

In [None]:
# Seed PyTorch's global RNG before the model is built. The decoder checkpoint
# has no encoder-decoder cross-attention weights, so those layers are freshly
# (randomly) initialised; without a fixed seed the weights -- and therefore
# everything generated from this model -- change on every kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Warm-start an encoder-decoder model from two separate checkpoints:
# "bert-base-uncased" as the encoder and "rinna/japanese-gpt2-xsmall" as the
# decoder. The `encoder_` prefix forwards `add_pooling_layer=False` to the
# encoder's loading call only.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "rinna/japanese-gpt2-xsmall",
    encoder_add_pooling_layer=False,
)

In [None]:
# Dump the module tree of the combined encoder-decoder model as text.
print(f"{model}")

In [None]:
# Element count summed over every tensor yielded by model.parameters().
print(f"Number of parameters: {sum(weight.numel() for weight in model.parameters())}")

# For each decoder block, count the elements of its `crossattention` module
# and of the accompanying `ln_cross_attn` layer, then total across blocks.
c_attn_pars = sum(
    weight.numel()
    for block in model.decoder.transformer.h
    for submodule in (block.crossattention, block.ln_cross_attn)
    for weight in submodule.parameters()
)

print(f"Number of cross-attention parameters: {c_attn_pars}")

In [None]:
# Bytes occupied by parameters: element count times bytes-per-element, per tensor.
param_size = sum(tensor.nelement() * tensor.element_size() for tensor in model.parameters())
# Same computation for the registered buffers (tensors exposed by model.buffers()).
buffer_size = sum(tensor.nelement() * tensor.element_size() for tensor in model.buffers())

# Convert the combined byte count to mebibytes (1024**2 bytes each).
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

In [None]:
# Generation-related options stored on the model config. They are assigned one
# by one, in this order, exactly as plain attribute assignments would be.
for option, value in (
    ("no_repeat_ngram_size", 3),
    ("early_stopping", True),
    ("length_penalty", 2.0),
    ("num_beams", 4),
):
    setattr(model.config, option, value)

In [None]:
from transformers import BertTokenizerFast, AutoTokenizer
# Tokenizer for the encoder side, loaded from the same checkpoint name as the encoder.
encoder_tokenizer = BertTokenizerFast.from_pretrained(
    "bert-base-uncased",
)
# Tokenizer for the decoder side; use_fast=False requests the non-fast
# tokenizer implementation for this checkpoint.
decoder_tokenizer = AutoTokenizer.from_pretrained(
    "rinna/japanese-gpt2-xsmall",
    use_fast=False,
)

In [None]:
from datasets import load_dataset

# Fetch the SNOW simplified Japanese corpus by its dataset identifier.
dataset = load_dataset(path="snow_simplified_japanese_corpus")

In [None]:
# Drop the "ID" and "simplified_ja" columns; the remaining columns are kept as-is.
test_data = dataset.remove_columns(column_names=["ID", "simplified_ja"])

In [None]:
# Take the "train" split directly from `dataset` (without the "ID" and
# "simplified_ja" columns) instead of indexing `test_data` with itself.
# The self-referential form fails when the cell is run a second time, because
# `test_data` is then already a single split with no "train" entry; this form
# produces the same result however many times the cell is re-run.
test_data = dataset["train"].remove_columns(["ID", "simplified_ja"])

In [None]:
# Keep only the rows at indices 0..3 for a quick smoke test.
test_data = test_data.select(indices=range(4))

In [None]:
# Show the dataset summary as text.
print(str(test_data))

In [None]:
# Show the object held in the dataset's `.data` attribute as text.
print(str(test_data.data))

In [None]:
# Tokenize the "original_en" sentences for the encoder and return PyTorch tensors.
# padding="longest" pads only up to the longest sentence in this batch rather
# than always to 512 tokens; padded positions are excluded through the
# attention mask either way, so the encoder simply stops processing padding
# beyond the longest sentence. truncation=True with max_length=512 still caps
# every input at 512 tokens.
inputs = encoder_tokenizer(
    list(test_data["original_en"]),  # hand the tokenizer a plain list of strings
    padding="longest",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

In [None]:
# Run generation on the tokenized batch; the attention mask is passed along
# with the token ids so padded positions can be told apart from real tokens.
output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)

In [None]:
# Show the raw result of generate() (token ids, before decoding).
print(str(output))

In [None]:
# Turn the generated token ids back into strings with the decoder's tokenizer.
# skip_special_tokens=True removes special tokens (such as start, end and
# padding markers) so only the generated text itself is printed.
print(decoder_tokenizer.batch_decode(output, skip_special_tokens=True))