In [1]:
import os
os.environ['HF_HOME'] = r"./.cache"
from transformers import EncoderDecoderModel, AutoTokenizer
from datasets import load_dataset

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decorders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [2]:
encoder = "cl-tohoku/bert-base-japanese-v3"
decoder = "gpt2"

model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder, decoder, 
    # encoder_add_pooling_layer=True
)
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_proj.weight', 'h.1.ln_cross_attn.weight', 'h.7.ln_cross_attn.weight', 'h.7.ln_cross_attn.bias', 'h.5.crossattention.c_proj.weight', 'h.9.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.7.crossattention.c_attn.bias', 'h.5.crossattention.c_proj.bias', 'h.11.crossattention.c_attn.bias', 'h.9.crossattention.c_proj.bias', 'h.2.ln_cross_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.2.ln_cross_attn.weight', 'h.6.ln_cross_attn.bias', 'h.4.ln_cross_attn.weight', 'h.9.crossattention.q_attn.weight', 'h.0.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.weight', 'h.2.crossattention.c_proj.bias', 'h.4.crossattention.c_attn.bias', 'h.1.ln_cross_attn.bias', 'h.3.crossattention.c_attn.weight', 'h.5.crossattention.q_attn.bias', 'h.0.ln_cross_attn.bias', 'h.8.crossattention.c_attn.weight', 'h.9.cros

In [3]:
model

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [4]:
model.config

EncoderDecoderConfig {
  "_commit_hash": null,
  "decoder": {
    "_name_or_path": "gpt2",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": [
      "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 50256,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embd_pdrop": 0.1,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 50256,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_epsilon": 1

In [5]:
def print_model_parameters():
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

    c_attn_pars = 0
    for layer in model.decoder.transformer.h:
        c_attn_pars += sum(p.numel() for p in layer.crossattention.parameters())
        c_attn_pars += sum(p.numel() for p in layer.ln_cross_attn.parameters())

    print(f"Number of cross-attention parameters: {c_attn_pars}")
print_model_parameters()

Number of parameters: 264013824
Number of cross-attention parameters: 28366848


In [6]:
def print_model_size():
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.1f}MB'.format(size_all_mb))
print_model_size()

model size: 1031.1MB


In [7]:
def set_decoder_configuration():
    model.config.no_repeat_ngram_size = 3
    model.config.early_stopping = True
    model.config.length_penalty = 2.0
    model.config.num_beams = 4
set_decoder_configuration()

In [8]:
from transformers import BertTokenizerFast, AutoTokenizer
encoder_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall", use_fast=False)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [9]:
dataset = load_dataset("csv", data_files=r"./data-post/snow_simplified.csv")
data_sample = dataset["train"]
data_sample = data_sample.select(range(4))

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [10]:
data_sample

Dataset({
    features: ['en_sentence', 'ja_sentence'],
    num_rows: 4
})

In [11]:
data_sample.data

MemoryMappedTable
en_sentence: string
ja_sentence: string
----
en_sentence: [["i can 't tell who will arrive first .","i can 't tell who will arrive first .","many animals have been destroyed by men .","many animals have been destroyed by men ."]]
ja_sentence: [[" "誰が一番に着くか私には分かりません。""," "誰が一番に着くか私には分かりません。""," "多くの動物が人間によって滅ぼされた。""," "多くの動物が人間によって殺された。""]]

In [12]:
inputs = encoder_tokenizer(data_sample["ja_sentence"], padding="max_length", max_length=512, truncation=True, return_tensors="pt")

In [13]:
output = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [14]:
clean_output = decoder_tokenizer.batch_decode(output, skip_special_tokens=False)

In [15]:
clean_output

['<extra_id_-18257><s>た 2シリーズさビデオ27計画石う植物石ラン世もある書特に・j',
 '<extra_id_-18257><s>た 2シリーズさビデオ27計画石う植物石ラン世もある書特に・j',
 '<extra_id_-18257><s>た物語y最も石代替石旧もある書特に・jた物語ことが最も石',
 '<extra_id_-18257><s>た物語y最も石代替石旧もある書特に・jた物語ことが最も石']