In [1]:
import os

# Point the Hugging Face cache at a project-local directory. This is set before the
# transformers/datasets imports below so those libraries see it when they load.
os.environ["HF_HOME"] = r"./.cache"
from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset
from tokenizers import processors


Candidate pretrained checkpoints (Hugging Face Hub IDs):

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decoders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`, `rinna/japanese-gpt2-small`
    - GPT_EN : `gpt2`

In [2]:
# Translation direction. "ja": Japanese BERT encoder -> English GPT-2 decoder;
# "en": English BERT encoder -> Japanese GPT-2 decoder.
source_lng = "ja"

if source_lng == "en":
    target_lng = "ja"
    encoder = "bert-base-uncased"
    decoder = "rinna/japanese-gpt2-small"
elif source_lng == "ja":
    target_lng = "en"
    encoder = "cl-tohoku/bert-base-japanese-v3"
    decoder = "gpt2"
else:
    # Fail fast on a typo instead of silently falling back to the ja->en setup.
    raise ValueError(f"source_lng must be 'en' or 'ja', got {source_lng!r}")

# Warm-start a seq2seq model from two pretrained checkpoints. The decoder's
# cross-attention weights are not in the checkpoint and start untrained.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, encoder_add_pooling_layer=False
)
model.cuda()

encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)
# GPT-2 style tokenizers may have no pad token; reuse EOS for padding in that case.
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = decoder_tokenizer.eos_token_id

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.7.crossattention.q_attn.bias', 'h.2.crossattention.q_attn.bias', 'h.7.crossattention.c_attn.weight', 'h.1.ln_cross_attn.bias', 'h.0.crossattention.q_attn.bias', 'h.4.crossattention.c_attn.bias', 'h.10.crossattention.c_proj.weight', 'h.0.crossattention.c_attn.weight', 'h.6.ln_cross_attn.weight', 'h.5.crossattention.q_attn.bias', 'h.9.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.weight', 'h.2.ln_cross_attn.weight', 'h.5.ln_cross_attn.weight', 'h.8.crossattention.c_attn.bias', 'h.5.ln_cross_attn.bias', 'h.0.ln_cross_attn.bias', 'h.8.crossattention.q_attn.bias', 'h.3.crossattention.q_attn.bias', 'h.7.crossattention.c_attn.bias', 'h.0.crossattention.c_proj.weight', 'h.8.crossattention.c_proj.weight', 'h.6.crossattention.q_attn.bias', 'h.6.crossattention.c_attn.bias', 'h.0.crossattention.c_proj.bias', 'h.11.crossattention.c_proj.weight', 'h.6.crossattention.q_att

In [3]:
print(model)

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [4]:
# Wire the decoder tokenizer's special tokens into the model config: decoding
# starts at its BOS, ends at its EOS, and EOS also serves as the padding id.
cfg = model.config
dec_tok = decoder_tokenizer
cfg.decoder_start_token_id = dec_tok.bos_token_id
cfg.eos_token_id = dec_tok.eos_token_id
cfg.pad_token_id = dec_tok.eos_token_id
print(cfg)

EncoderDecoderConfig {
  "_commit_hash": null,
  "decoder": {
    "_name_or_path": "gpt2",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": [
      "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 50256,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embd_pdrop": 0.1,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 50256,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_epsilon": 1

In [5]:
def print_model_parameters(module=None):
    """Print the total parameter count and the cross-attention parameter count.

    Cross-attention parameters are those whose dotted parameter name contains a
    ``crossattention`` or ``ln_cross_attn`` component (for the GPT-2 decoder:
    ``h.<i>.crossattention.*`` and ``h.<i>.ln_cross_attn.*``). Matching by name
    instead of walking ``decoder.transformer.h`` also works for other decoder
    architectures that name their cross-attention submodule ``crossattention``.

    Args:
        module: the model to inspect; defaults to the notebook-level ``model``.
    """
    net = model if module is None else module
    named = list(net.named_parameters())

    total = sum(p.numel() for _, p in named)
    print(f"Number of parameters: {total}")

    cross_parts = {"crossattention", "ln_cross_attn"}
    c_attn_pars = sum(
        p.numel() for name, p in named if cross_parts & set(name.split("."))
    )
    print(f"Number of cross-attention parameters: {c_attn_pars}")


# Report parameter counts for the EncoderDecoderModel built above.
print_model_parameters()

Number of parameters: 263423232
Number of cross-attention parameters: 28366848


In [6]:
def print_model_size(module=None):
    """Print how much memory a model's parameters and buffers occupy, in MiB.

    The size is ``nelement * element_size`` summed over every parameter and
    every registered buffer, so it reflects the tensors' current dtypes.

    Args:
        module: the model to measure; defaults to the notebook-level ``model``.
    """
    net = model if module is None else module
    param_size = sum(p.nelement() * p.element_size() for p in net.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in net.buffers())

    size_all_mb = (param_size + buffer_size) / 1024**2
    print("model size: {:.1f}MB".format(size_all_mb))


# Report the memory footprint of the EncoderDecoderModel built above.
print_model_size()

model size: 1028.9MB


In [7]:
# Small, unshuffled subset for a quick experiment: the first N_TRAIN rows are
# used for training and the following N_VAL rows for validation.
N_TRAIN, N_VAL = 128, 32

dataset = load_dataset("csv", data_files=r"./data-csv/snow_simplified.csv")
data_sample = dataset["train"]
train_sample = data_sample.select(range(N_TRAIN))
val_sample = data_sample.select(range(N_TRAIN, N_TRAIN + N_VAL))

In [8]:
# Make every single-sequence encoding from the decoder tokenizer end with EOS,
# so target labels teach the decoder where to stop. Uses the private
# `_tokenizer` of the fast (Rust-backed) tokenizer loaded with use_fast=True.
eos_token = decoder_tokenizer.eos_token
eos_id = decoder_tokenizer.eos_token_id
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single=f"$A {eos_token}",
    special_tokens=[(eos_token, eos_id)],
)

In [9]:
def preprocess_data(batch):
    """Tokenize a batch of sentence pairs into encoder inputs and decoder labels.

    Source sentences (``f"{source_lng}_sentence"``) go through the encoder
    tokenizer and target sentences (``f"{target_lng}_sentence"``) through the
    decoder tokenizer; both are padded/truncated to 128 tokens. Label positions
    whose attention mask is 0 (padding) are set to -100 so the loss skips them.

    Args:
        batch: dict-like batch passed by ``Dataset.map(batched=True)``.

    Returns:
        The same batch with ``input_ids``, ``attention_mask`` and ``labels`` set.
    """
    tokenize_kwargs = dict(
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    source = encoder_tokenizer(batch[f"{source_lng}_sentence"], **tokenize_kwargs)
    target = decoder_tokenizer(batch[f"{target_lng}_sentence"], **tokenize_kwargs)

    label_ids = target.input_ids
    # Mask by attention mask, not by token id: the pad id may equal the EOS id,
    # and the EOS appended to each target must stay in the labels.
    label_ids[target["attention_mask"] == 0] = -100

    batch["input_ids"] = source.input_ids
    batch["attention_mask"] = source.attention_mask
    batch["labels"] = label_ids
    return batch

In [10]:
# Tokenize the training subset, drop the raw text columns, and expose only
# the tensors the model consumes.
train_data = train_sample.map(
    preprocess_data,
    batched=True,
    batch_size=32,
    remove_columns=["en_sentence", "ja_sentence"],
)
train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_data

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 128
})

In [11]:
# Same preprocessing for the validation subset.
val_data = val_sample.map(
    preprocess_data,
    batched=True,
    batch_size=32,
    remove_columns=["en_sentence", "ja_sentence"],
)
val_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_data

Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 32
})

In [12]:
# Start from library defaults; set_decoder_configuration (next) fills in the decoding settings.
gen_config = GenerationConfig()

def set_decoder_configuration(gen_config, tokenizer=None):
    """Fill ``gen_config`` with the beam-search settings used for translation.

    The object is modified in place and also returned, so it can be used as
    ``gen_config = set_decoder_configuration(gen_config)``.

    Args:
        gen_config: a ``GenerationConfig`` (or any object accepting these attributes).
        tokenizer: decoder tokenizer supplying the special-token ids; defaults
            to the notebook-level ``decoder_tokenizer``.

    Returns:
        The same ``gen_config`` object.
    """
    tok = decoder_tokenizer if tokenizer is None else tokenizer

    # Beam search with 3 beams; block repeated 3-grams; length_penalty > 0
    # favours longer hypotheses; stop once every beam has finished.
    gen_config.no_repeat_ngram_size = 3
    gen_config.length_penalty = 2.0
    gen_config.num_beams = 3
    gen_config.early_stopping = True

    # max_length bounds the whole generated sequence (max_new_tokens left unset).
    gen_config.min_length = 0
    gen_config.max_length = 128

    # Special tokens: EOS doubles as padding, matching model.config.pad_token_id.
    gen_config.pad_token_id = tok.eos_token_id
    gen_config.bos_token_id = tok.bos_token_id
    gen_config.eos_token_id = tok.eos_token_id
    # Start decoding from the same token the model config uses
    # (model.config.decoder_start_token_id = BOS) instead of relying on the
    # BOS fallback inside generate().
    gen_config.decoder_start_token_id = tok.bos_token_id
    return gen_config


gen_config = set_decoder_configuration(gen_config)
# Display the resulting configuration.
gen_config

GenerationConfig {
  "bos_token_id": 50256,
  "early_stopping": true,
  "eos_token_id": 50256,
  "length_penalty": 2.0,
  "max_length": 128,
  "no_repeat_ngram_size": 3,
  "num_beams": 3,
  "pad_token_id": 50256,
  "transformers_version": "4.31.0"
}

In [13]:
from evaluate import load

In [14]:
# sacreBLEU via the `evaluate` library; used by compute_metrics below.
metric = load("sacrebleu")

def _as_numpy_ids(ids):
    """Return a batch of token ids as a NumPy array, copying torch/CUDA tensors to host."""
    import numpy as np

    if hasattr(ids, "detach"):  # torch.Tensor, possibly on the GPU
        ids = ids.detach().cpu().numpy()
    return np.asarray(ids)


def compute_metrics(preds, tokenizer=None, bleu_metric=None, lang=None):
    """Decode predicted and reference token ids and score them with sacreBLEU.

    Works both as the ``Seq2SeqTrainer`` ``compute_metrics`` hook (called with
    an ``EvalPrediction``) and when called directly on ``(generated_ids, labels)``.

    Args:
        preds: ``(prediction_ids, label_ids)`` pair; each may be a NumPy array
            or a (possibly CUDA) torch tensor.
        tokenizer: tokenizer used to decode both sides; defaults to ``decoder_tokenizer``.
        bleu_metric: object with a sacreBLEU-style ``compute``; defaults to ``metric``.
        lang: target language code; ``"ja"`` selects MeCab tokenization.
            Defaults to ``target_lng``.

    Returns:
        The dict returned by ``bleu_metric.compute`` (``score``, ``counts``, ...).
    """
    import numpy as np

    tok = decoder_tokenizer if tokenizer is None else tokenizer
    bleu = metric if bleu_metric is None else bleu_metric
    tgt_lang = target_lng if lang is None else lang

    preds_ids, labels_ids = preds
    # -100 is the ignore value in labels and may also pad predictions when
    # evaluation batches are concatenated. It is not a valid token id, so map it
    # to EOS, which skip_special_tokens drops. np.where builds new arrays, so
    # the caller's arrays/tensors are left unmodified.
    preds_ids = _as_numpy_ids(preds_ids)
    labels_ids = _as_numpy_ids(labels_ids)
    preds_ids = np.where(preds_ids == -100, tok.eos_token_id, preds_ids)
    labels_ids = np.where(labels_ids == -100, tok.eos_token_id, labels_ids)

    predictions = tok.batch_decode(preds_ids, skip_special_tokens=True)
    # sacreBLEU expects a list of reference lists (one reference per sentence here).
    references = [[ref] for ref in tok.batch_decode(labels_ids, skip_special_tokens=True)]

    # The default sacreBLEU tokenizer does not segment Japanese; use MeCab for ja targets.
    extra = {"tokenize": "ja-mecab"} if tgt_lang == "ja" else {}
    return bleu.compute(references=references, predictions=predictions, **extra)


In [15]:
# Evaluate (with generation, so BLEU is computed) and log once per epoch.
# Optional: torch_compile=True to compile the model; left disabled here.
# NOTE(review): output_dir="./" means anything the Trainer writes lands next to the notebook.
train_args = Seq2SeqTrainingArguments(
    output_dir="./",
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    predict_with_generate=True,
    generation_config=gen_config,
    # bfloat16 mixed precision; requires bf16-capable hardware.
    bf16=True,
)

In [16]:
# Seq2SeqTrainer on the tokenized subsets, scoring each evaluation with BLEU.
trainer = Seq2SeqTrainer(
    model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

In [17]:
# The Trainer puts the model in training mode itself; model.train() is redundant but harmless.
# NOTE: in the log below, training loss keeps falling while validation loss rises after
# epoch 1 (3.22 -> 4.64) — the model overfits the 128 training pairs; consider more
# data or fewer epochs.
model.train()
trainer.train()



Epoch,Training Loss,Validation Loss,Score,Counts,Totals,Precisions,Bp,Sys Len,Ref Len
1,3.6911,3.223726,1.105588,"[48, 4, 0, 0]","[222, 190, 158, 126]","[21.62162162162162, 2.1052631578947367, 0.31645569620253167, 0.1984126984126984]",0.850303,222,258
2,2.3073,3.299203,0.861063,"[47, 2, 0, 0]","[266, 234, 202, 170]","[17.669172932330827, 0.8547008547008547, 0.24752475247524752, 0.14705882352941177]",1.0,266,258
3,1.5475,3.469179,0.935582,"[46, 2, 0, 0]","[195, 163, 131, 99]","[23.58974358974359, 1.2269938650306749, 0.3816793893129771, 0.25252525252525254]",0.723918,195,258
4,1.0834,3.645945,0.92823,"[48, 2, 0, 0]","[224, 192, 160, 128]","[21.428571428571427, 1.0416666666666667, 0.3125, 0.1953125]",0.859172,224,258
5,0.7893,4.052249,0.925515,"[47, 2, 0, 0]","[221, 189, 157, 125]","[21.266968325791854, 1.0582010582010581, 0.3184713375796178, 0.2]",0.845844,221,258
6,0.5771,4.278708,1.062228,"[45, 4, 0, 0]","[249, 217, 185, 153]","[18.072289156626507, 1.8433179723502304, 0.2702702702702703, 0.16339869281045752]",0.964501,249,258
7,0.4536,4.429223,0.92767,"[49, 2, 0, 0]","[231, 199, 167, 135]","[21.21212121212121, 1.0050251256281406, 0.2994011976047904, 0.18518518518518517]",0.889689,231,258
8,0.4015,4.621641,0.819274,"[48, 2, 0, 0]","[278, 246, 214, 182]","[17.26618705035971, 0.8130081300813008, 0.2336448598130841, 0.13736263736263737]",1.0,278,258
9,0.4461,4.623802,0.870615,"[51, 2, 0, 0]","[268, 236, 204, 172]","[19.029850746268657, 0.847457627118644, 0.24509803921568626, 0.14534883720930233]",1.0,268,258
10,0.3327,4.641215,0.85043,"[50, 2, 0, 0]","[272, 240, 208, 176]","[18.38235294117647, 0.8333333333333334, 0.2403846153846154, 0.14204545454545456]",1.0,272,258


TrainOutput(global_step=160, training_loss=1.1629755705595017, metrics={'train_runtime': 56.2751, 'train_samples_per_second': 22.745, 'train_steps_per_second': 2.843, 'total_flos': 195112646737920.0, 'train_loss': 1.1629755705595017, 'epoch': 10.0})

In [18]:
# Translate the whole validation subset in a single batch, in inference mode.
model.cuda()
model.eval()
val_input_ids = val_data["input_ids"].cuda()
val_attention_mask = val_data["attention_mask"].cuda()
output = model.generate(
    val_input_ids,
    attention_mask=val_attention_mask,
    generation_config=gen_config,
)



In [19]:
# BLEU for the generated ids (on the GPU) against the validation labels (CPU tensor).
compute_metrics((output, val_data["labels"]))

{'score': 0.7700901195336971,
 'counts': [59, 2, 0, 0],
 'totals': [305, 273, 241, 209],
 'precisions': [19.34426229508197,
  0.7326007326007326,
  0.2074688796680498,
  0.11961722488038277],
 'bp': 1.0,
 'sys_len': 305,
 'ref_len': 258}

In [20]:
# Decode generated ids to text, dropping special tokens (e.g. the start/EOS/pad id).
clean_output = decoder_tokenizer.batch_decode(output, skip_special_tokens=True)

In [21]:
# NOTE: many predictions repeat the same few sentences regardless of the input
# (see output) — presumably memorised from the small training subset; verify.
clean_output

['the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'he always speaks ill of his father behind his back.',
 'he always speaks ill of his father behind his back.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 'i have nothing to live for.',
 'i have nothing to live for.',
 'the number of boys in our class is thirty.',
 'the number of boys in our class is thirty.',
 "you have eaten lunch, haven 't you?",
 "you have eaten lunch, haven 't you?",
 'i have nothing to live for.',
 'i have half a mind to undertake the work.',
 'i have half a mind to undertake the work.',
 'i have half a mind to undertake the work.',
 "it's a c

In [24]:
# Reference translations for the validation rows, to compare against clean_output.
# NOTE: consecutive rows appear duplicated in pairs (see output).
val_sample[f"{target_lng}_sentence"]

['what are you talking about ?',
 'what are you talking about ?',
 'he is doing it with my help .',
 'he is doing it with my help .',
 'he is married to an american lady .',
 'he is married to an american lady .',
 'the cat ran up the tree .',
 'the cat ran up the tree .',
 'he pretends to know everything .',
 'he pretends to know everything .',
 'john is walking in the direction of the station .',
 'john is walking in the direction of the station .',
 "there 's no need to get so angry .",
 "there 's no need to get so angry .",
 'let me call you back later , ok ?',
 'let me call you back later , ok ?',
 'please help yourself to more cake .',
 'please help yourself to more cake .',
 'i have two foreign friends .',
 'i have two foreign friends .',
 'you have to make efforts if you are to succeed .',
 'you have to make efforts if you are to succeed .',
 'choose between these two .',
 'choose between these two .',
 'the house is quite run down .',
 'the house is quite run down .',
 'i was 