In [1]:
import os

# Point the Hugging Face cache at a project-local directory. This assignment sits
# before the transformers/datasets imports below -- presumably so those libraries
# see the variable when they are imported; keep this ordering (TODO confirm).
os.environ["HF_HOME"] = r"./.cache"
from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig
from datasets import load_dataset

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decoders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [2]:
# Checkpoint names for the two halves of the sequence-to-sequence model.
encoder = "cl-tohoku/bert-base-japanese-v3"
decoder = "gpt2"

# Extra keyword arguments for the combined loader. The `encoder_` prefix is
# presumably how the loader routes an option to the encoder half -- confirm
# against the transformers documentation.
encoder_options = {"encoder_add_pooling_layer": False}
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, **encoder_options
)

# One tokenizer per checkpoint: the encoder and decoder come from different
# pretrained models, so each text side is handled by its own tokenizer.
encoder_tokenizer, decoder_tokenizer = (
    AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
    for checkpoint in (encoder, decoder)
)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.9.crossattention.c_proj.bias', 'h.10.ln_cross_attn.bias', 'h.5.ln_cross_attn.weight', 'h.8.crossattention.q_attn.weight', 'h.6.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.11.crossattention.c_proj.bias', 'h.2.ln_cross_attn.bias', 'h.2.crossattention.q_attn.weight', 'h.8.crossattention.c_attn.bias', 'h.2.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.weight', 'h.11.crossattention.c_proj.weight', 'h.9.crossattention.c_attn.bias', 'h.11.ln_cross_attn.weight', 'h.6.ln_cross_attn.weight', 'h.11.ln_cross_attn.bias', 'h.4.crossattention.q_attn.bias', 'h.4.crossattention.q_attn.weight', 'h.3.ln_cross_attn.bias', 'h.11.crossattention.c_attn.bias', 'h.6.crossattention.c_proj.weight', 'h.5.ln_cross_attn.bias', 'h.5.crossattention.q_attn.weight', 'h.8.crossattention.q_attn.bias', 'h.2.crossattention.q_attn.bias', 'h.11.crossattention.q_attn.bias', 

In [3]:
print(model)

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [4]:
model.config

EncoderDecoderConfig {
  "_commit_hash": null,
  "decoder": {
    "_name_or_path": "gpt2",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": [
      "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 50256,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embd_pdrop": 0.1,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 50256,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_epsilon": 1

In [5]:
def print_model_parameters():
    """Print parameter counts for the notebook-global ``model``.

    Two lines are printed: the total number of elements over
    ``model.parameters()``, and the number of elements that belong to the
    ``crossattention`` and ``ln_cross_attn`` sub-modules of every layer in
    ``model.decoder.transformer.h``.

    Returns:
        None. The counts are only written to stdout.
    """
    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()
    print(f"Number of parameters: {total_params}")

    # Walk the decoder layers and yield, per layer, the cross-attention module
    # followed by its layer norm; their parameters are summed together.
    cross_attn_modules = (
        module
        for block in model.decoder.transformer.h
        for module in (block.crossattention, block.ln_cross_attn)
    )
    cross_attn_params = sum(
        tensor.numel()
        for module in cross_attn_modules
        for tensor in module.parameters()
    )
    print(f"Number of cross-attention parameters: {cross_attn_params}")


print_model_parameters()

Number of parameters: 263423232
Number of cross-attention parameters: 28366848


In [6]:
def print_model_size():
    """Print the in-memory size of the notebook-global ``model``.

    The size is the sum of ``nelement() * element_size()`` over every tensor in
    ``model.parameters()`` and ``model.buffers()``, divided by ``1024**2`` and
    printed with one decimal place.

    Returns:
        None. The size is only written to stdout.
    """

    def _byte_count(tensors):
        # Bytes occupied by an iterable of tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _byte_count(model.parameters()) + _byte_count(model.buffers())
    print("model size: {:.1f}MB".format(total_bytes / 1024**2))


print_model_size()

model size: 1028.9MB


In [7]:
# Start from a GenerationConfig built with no arguments; the decoding options
# used for generation are filled in by set_decoder_configuration().
gen_config = GenerationConfig()


def set_decoder_configuration(
    gen_config,
    *,
    no_repeat_ngram_size=3,
    num_beams=4,
    length_penalty=2.0,
    max_new_tokens=128,
):
    """Write the notebook's decoding settings onto ``gen_config``.

    The object is modified in place and also returned, so the call can be used
    either as a statement or in an assignment. Calling it with only
    ``gen_config`` applies the notebook's default values.

    Args:
        gen_config: Object that receives the attributes (any object allowing
            attribute assignment; the notebook passes a ``GenerationConfig``).
        no_repeat_ngram_size: Value stored in ``gen_config.no_repeat_ngram_size``.
        num_beams: Value stored in ``gen_config.num_beams``.
        length_penalty: Value stored in ``gen_config.length_penalty``.
        max_new_tokens: Value stored in ``gen_config.max_new_tokens``.

    Returns:
        The same ``gen_config`` object that was passed in.
    """
    # NOTE(review): the previous version assigned num_beams twice; the duplicate
    # was dropped. If a different option was meant on that line, add it here.
    gen_config.no_repeat_ngram_size = no_repeat_ngram_size
    gen_config.num_beams = num_beams
    gen_config.length_penalty = length_penalty
    gen_config.max_new_tokens = max_new_tokens
    return gen_config


# Apply the decoding settings. The function returns the object it was given,
# so this rebinding keeps gen_config pointing at the same configuration.
gen_config = set_decoder_configuration(gen_config)

In [8]:
# Load the CSV through the `datasets` CSV builder; the rows are read from the
# "train" split below.
dataset = load_dataset("csv", data_files=r"./data-post/snow_simplified.csv")

# Keep only the first four rows of the train split as a small sample for
# the cells that follow.
data_sample = dataset["train"].select(range(4))

In [9]:
data_sample

Dataset({
    features: ['ID', 'original_ja', 'simplified_ja', 'original_en'],
    num_rows: 4
})

In [10]:
data_sample.data

MemoryMappedTable
ID: string
original_ja: string
simplified_ja: string
original_en: string
----
ID: [["1","2","3","4"]]
original_ja: [["誰が一番に着くか私には分かりません。","多くの動物が人間によって滅ぼされた。","私はテニス部員です。","エミは幸せそうに見えます。"]]
simplified_ja: [["誰が一番に着くか私には分かりません。","多くの動物が人間によって殺された。","私はテニス部員です。","エミは幸せそうに見えます。"]]
original_en: [["i can 't tell who will arrive first .","many animals have been destroyed by men .","i 'm in the tennis club .","emi looks happy ."]]

In [12]:
# Tokenizer settings for the Japanese source sentences: pad or truncate every
# row to exactly 512 tokens and return PyTorch tensors.
# NOTE(review): with padding="max_length" each row is padded to 512 tokens
# regardless of its real length -- confirm this fixed shape is intended.
tokenizer_options = dict(
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
inputs = encoder_tokenizer(data_sample["original_ja"], **tokenizer_options)

In [14]:
# Most recent output captured by test_hook; reset whenever this cell is run.
test_output = None


def test_hook(module, input_, output):
    """Forward hook that stores the hooked module's output in ``test_output``.

    Args:
        module: The module the hook is attached to (unused).
        input_: The inputs passed to the module's forward call (unused).
        output: The value produced by the forward call; kept as-is.

    Returns:
        None, so the module's output is passed on unchanged.
    """
    global test_output
    test_output = output


# Keep the handle returned by register_forward_hook so the hook can be detached
# later with test_hook_handle.remove(). If this cell is re-run, detach the hook
# registered by the previous run first; otherwise every run would stack one
# more copy of the hook on the same module.
if "test_hook_handle" in globals():
    test_hook_handle.remove()
test_hook_handle = model.decoder.transformer.h[0].crossattention.register_forward_hook(
    test_hook
)

<torch.utils.hooks.RemovableHandle at 0x7f48b67d1e10>

In [15]:
# Generate from the tokenized source batch using the settings in gen_config.
# The attention mask is passed alongside the token ids produced by the
# encoder tokenizer.
source_ids = inputs["input_ids"]
source_mask = inputs["attention_mask"]
output = model.generate(
    source_ids, attention_mask=source_mask, generation_config=gen_config
)

In [19]:
# Turn the generated token ids back into text with the decoder's tokenizer,
# leaving special tokens out of the decoded strings.
clean_output = decoder_tokenizer.batch_decode(
    sequences=output,
    skip_special_tokens=True,
)

In [20]:
clean_output

['"I\'m not going to lie to you," he said. "I\'m going to tell you what I know. I\'m not gonna lie to anybody."\n\n"I don\'t know what you\'re talking about," he added. "You\'re going to have to tell me what you know. You know what I\'m talking about. I know what\'s going to happen. I don\'t care what you think. I just want you to know that I\'m here to help you. I want to make sure that you\'re safe and that you don\'t have to go through this again. I\'ve got a lot of work to',
 '"I\'m not going to lie to you," he said. "I\'m just going to tell you what I think. I\'m going to be honest with you."\n\n"I don\'t know what you\'re talking about," she said.\n\nHe shook his head. "You know what I\'m talking about?"\n\nShe shook her head again. "No, I don\'t think so. I think you\'re going to have to do something about it. You know what? I\'m not saying that you can\'t do it, but I\'m saying that it\'s going to take a lot of work to get it',
 '"I don\'t know if I\'m going to be able to do it