In [1]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import EncoderDecoderModel, AutoTokenizer, GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from tokenizers import processors
import evaluate

- Encoders
    - BERT_JA : `cl-tohoku/bert-base-japanese-v3`
    - BERT_EN : `bert-base-uncased`, `prajjwal1/bert-tiny`
- Decorders
    - GPT_JA : `rinna/japanese-gpt2-xsmall`
    - GPT_EN : `gpt2`

In [2]:
source_lng = "ja"

if source_lng == "en":
    target_lng = "ja"
    encoder = "bert-base-uncased"
    decoder = "rinna/japanese-gpt2-small"
else: 
    target_lng = "en"
    encoder = "cl-tohoku/bert-base-japanese-v3"
    decoder = "gpt2"

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder, decoder, encoder_add_pooling_layer=False
)
model.cuda();

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.7.crossattention.q_attn.bias', 'h.2.crossattention.c_proj.bias', 'h.0.crossattention.q_attn.weight', 'h.9.crossattention.c_attn.bias', 'h.9.crossattention.c_proj.bias', 'h.0.crossattention.c_attn.bias', 'h.5.crossattention.q_attn.bias', 'h.9.crossattention.c_attn.weight', 'h.5.crossattention.c_proj.weight', 'h.5.crossattention.c_proj.bias', 'h.9.crossattention.c_proj.weight', 'h.10.ln_cross_attn.weight', 'h.1.ln_cross_attn.weight', 'h.7.ln_cross_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.3.ln_cross_attn.weight', 'h.7.crossattention.q_attn.weight', 'h.1.crossattention.c_proj.weight', 'h.4.crossattention.q_attn.weight', 'h.6.crossattention.q_attn.weight', 'h.6.crossattention.c_attn.weight', 'h.0.ln_cross_attn.bias', 'h.1.ln_cross_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.bias', 'h.5.ln_cross_attn.bias', 'h.11.crossattention.q_attn.

In [3]:
def print_model_parameters():
    t_pars, t_bytes = 0, 0
    for p in model.parameters():
        t_pars += p.nelement()
        t_bytes += p.nelement() * p.element_size()

    c_attn_pars, c_attn_bytes = 0, 0
    for layer in model.decoder.transformer.h:
        for p in layer.crossattention.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()
        for p in layer.ln_cross_attn.parameters():
            c_attn_pars += p.nelement()
            c_attn_bytes += p.nelement() * p.element_size()

    print(f"Total number of parameters: {t_pars:12,} ({(t_bytes / 1024**2):7,.1f}MB)")
    print(f"Cross-attention parameters: {c_attn_pars:12,} ({(c_attn_bytes / 1024**2):7,.1f}MB)")

print_model_parameters()

Total number of parameters:  263,423,232 (1,004.9MB)
Cross-attention parameters:   28,366,848 (  108.2MB)


In [4]:
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder, use_fast=True)
if decoder_tokenizer.pad_token_id is None:
    decoder_tokenizer.pad_token_id = decoder_tokenizer.eos_token_id

model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.pad_token_id = decoder_tokenizer.eos_token_id

# add EOS token at the end of each sentence
decoder_tokenizer._tokenizer.post_processor = processors.TemplateProcessing(
    single="$A " + decoder_tokenizer.eos_token,
    special_tokens=[(decoder_tokenizer.eos_token, decoder_tokenizer.eos_token_id)],
)

In [5]:
from utils.dataset import SnowSimplifiedDataset
dataset = SnowSimplifiedDataset.load()
if source_lng == "ja":
    SnowSimplifiedDataset.stats(en_tokenizer=decoder_tokenizer, ja_tokenizer=encoder_tokenizer)
else:
    SnowSimplifiedDataset.stats(en_tokenizer=encoder_tokenizer, ja_tokenizer=decoder_tokenizer)

dataset = dataset["train"]

# sample first chuck of data for testing
train = dataset.select(range(1000))
valid = dataset.select(range(1000, 1300))

Showing statistic for 100,000 sentences:
	en[tokens] : Avg.  9.10 | Min.     5 | Max.    20 | >32.       0 | >64.     0 | >128.     0 | >256.     0
	ja[tokens] : Avg. 12.39 | Min.     4 | Max.    31 | >32.       0 | >64.     0 | >128.     0 | >256.     0


In [6]:
MAX_LENGHT = 32

def preprocess_data(batch):
    inputs = encoder_tokenizer(
        batch[f"{source_lng}_sentence"],
        padding="max_length",
        max_length=MAX_LENGHT,
        truncation=True,
        return_tensors="pt",
    )

    labels = decoder_tokenizer(
        batch[f"{target_lng}_sentence"],
        padding="max_length",
        max_length=MAX_LENGHT,
        truncation=True,
        return_tensors="pt",
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["labels"] = labels.input_ids
    batch["labels"][labels["attention_mask"]==0] = -100
    return batch

In [7]:
# preprocess data
train_data = train.map(
    preprocess_data, 
    batched=True, batch_size=64, 
    remove_columns=["en_sentence", "ja_sentence"]
)
train_data.set_format(
    type="torch", 
    columns=["input_ids", "attention_mask", "labels"]
)

valid_data = valid.map(
    preprocess_data, 
    batched=True, batch_size=64, 
    remove_columns=["en_sentence", "ja_sentence"]
)
valid_data.set_format(
    type="torch", 
    columns=["input_ids", "attention_mask", "labels"]
)

In [8]:
metric = evaluate.load("sacrebleu")

def compute_metrics(preds):
    preds_ids, labels_ids = preds

    labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
    references = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    references = [[reference] for reference in references]

    predictions = decoder_tokenizer.batch_decode(preds_ids, skip_special_tokens=True)

    if target_lng == "ja":
        bleu_output = metric.compute(references=references, predictions=predictions, tokenize="ja-mecab")
    else:
        bleu_output = metric.compute(references=references, predictions=predictions)
    return bleu_output


In [9]:
def set_decoder_configuration(gc: GenerationConfig):
    gc.no_repeat_ngram_size = 3
    gc.length_penalty = 2.0
    gc.num_beams = 3
    #gen_config.max_new_tokens = MAX_LENGHT
    gc.max_length = MAX_LENGHT * 2
    gc.min_length = 0
    gc.early_stopping = True
    gc.pad_token_id = decoder_tokenizer.eos_token_id
    gc.bos_token_id = decoder_tokenizer.bos_token_id
    gc.eos_token_id = decoder_tokenizer.eos_token_id
    return gc

gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [10]:
train_args = Seq2SeqTrainingArguments(
    report_to="wandb",
    run_name="eval-testing-1",
    num_train_epochs=5,

    logging_strategy="steps",
    logging_steps=10,

    evaluation_strategy="epoch",

    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=4,

    optim="adamw_torch",
    bf16=True,

    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,

    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_config=gen_config,
    # torch_compile=True,
    # label_smoothing_factor=0,
    # auto_find_batch_size=True,
)

In [11]:
trainer = Seq2SeqTrainer(
    model, 
    args=train_args, 
    train_dataset=train_data, 
    eval_dataset=valid_data, 
    compute_metrics=compute_metrics
)

In [12]:
model.train()
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavidboening[0m ([33mdandd[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/75 [00:00<?, ?it/s]



{'loss': 3.8869, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.64}


  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 2.9126644134521484, 'eval_score': 0.5865561135849071, 'eval_counts': [461, 14, 4, 0], 'eval_totals': [2285, 1985, 1685, 1385], 'eval_precisions': [20.175054704595187, 0.7052896725440806, 0.23738872403560832, 0.036101083032490974], 'eval_bp': 0.9925877820461596, 'eval_sys_len': 2285, 'eval_ref_len': 2302, 'eval_runtime': 16.351, 'eval_samples_per_second': 18.348, 'eval_steps_per_second': 2.324, 'epoch': 0.96}
{'loss': 2.963, 'learning_rate': 3.6666666666666666e-05, 'epoch': 1.28}
{'loss': 2.6026, 'learning_rate': 3e-05, 'epoch': 1.92}


  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 2.7908589839935303, 'eval_score': 0.45424080556392626, 'eval_counts': [441, 10, 2, 0], 'eval_totals': [2178, 1878, 1578, 1278], 'eval_precisions': [20.24793388429752, 0.5324813631522897, 0.1267427122940431, 0.03912363067292645], 'eval_bp': 0.9446573913851566, 'eval_sys_len': 2178, 'eval_ref_len': 2302, 'eval_runtime': 12.94, 'eval_samples_per_second': 23.184, 'eval_steps_per_second': 2.937, 'epoch': 1.98}
{'loss': 2.3403, 'learning_rate': 2.3333333333333336e-05, 'epoch': 2.56}


  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 2.757530927658081, 'eval_score': 1.0492544269491462, 'eval_counts': [488, 21, 6, 2], 'eval_totals': [2133, 1833, 1533, 1233], 'eval_precisions': [22.878574777308955, 1.1456628477905073, 0.3913894324853229, 0.16220600162206], 'eval_bp': 0.9238263759026545, 'eval_sys_len': 2133, 'eval_ref_len': 2302, 'eval_runtime': 15.562, 'eval_samples_per_second': 19.278, 'eval_steps_per_second': 2.442, 'epoch': 2.94}
{'loss': 2.1607, 'learning_rate': 1.6666666666666667e-05, 'epoch': 3.2}
{'loss': 2.0094, 'learning_rate': 1e-05, 'epoch': 3.84}


  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 2.75868821144104, 'eval_score': 1.25598793081679, 'eval_counts': [495, 31, 8, 2], 'eval_totals': [2070, 1770, 1470, 1170], 'eval_precisions': [23.91304347826087, 1.7514124293785311, 0.54421768707483, 0.17094017094017094], 'eval_bp': 0.8939751553008631, 'eval_sys_len': 2070, 'eval_ref_len': 2302, 'eval_runtime': 15.367, 'eval_samples_per_second': 19.522, 'eval_steps_per_second': 2.473, 'epoch': 3.97}
{'loss': 1.9146, 'learning_rate': 3.3333333333333333e-06, 'epoch': 4.48}


  0%|          | 0/38 [00:00<?, ?it/s]

{'eval_loss': 2.7598557472229004, 'eval_score': 0.6597507712672607, 'eval_counts': [485, 26, 3, 0], 'eval_totals': [2097, 1797, 1497, 1197], 'eval_precisions': [23.12827849308536, 1.4468558708959376, 0.20040080160320642, 0.04177109440267335], 'eval_bp': 0.9068677018936455, 'eval_sys_len': 2097, 'eval_ref_len': 2302, 'eval_runtime': 16.298, 'eval_samples_per_second': 18.407, 'eval_steps_per_second': 2.332, 'epoch': 4.8}
{'train_runtime': 174.7395, 'train_samples_per_second': 28.614, 'train_steps_per_second': 0.429, 'train_loss': 2.50732053120931, 'epoch': 4.8}


TrainOutput(global_step=75, training_loss=2.50732053120931, metrics={'train_runtime': 174.7395, 'train_samples_per_second': 28.614, 'train_steps_per_second': 0.429, 'train_loss': 2.50732053120931, 'epoch': 4.8})

In [16]:
def generate_ids(batch):
    pass
    batch["gen_ids"] = model.generate(
        batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda()
    )
    return batch

In [17]:
model.cuda()
model.eval()
train_out = train_data.map(
    generate_ids, 
    batched=True, batch_size=16,
    remove_columns=["input_ids", "attention_mask"]
)
valid_out = valid_data.map(
    generate_ids, 
    batched=True, batch_size=16,
    remove_columns=["input_ids", "attention_mask"]
)

In [14]:
print(compute_metrics((train_out["gen_ids"], train_out["labels"])))

train_output = decoder_tokenizer.batch_decode(train_out["gen_ids"], skip_special_tokens=True)

{'score': 1.1970624690859315, 'counts': [1577, 88, 22, 9], 'totals': [7609, 6609, 5609, 4609], 'precisions': [20.725456696017872, 1.3315176274776819, 0.3922267783918702, 0.19527012367107832], 'bp': 0.9927977789048245, 'sys_len': 7609, 'ref_len': 7664}


In [15]:
compute_metrics((valid_out["gen_ids"], valid_out["labels"]))

valid_output = decoder_tokenizer.batch_decode(valid_out["gen_ids"], skip_special_tokens=True)

KeyError: "Column gen_ids not in the dataset. Current columns in the dataset: ['input_ids', 'attention_mask', 'labels', 'gen_sentence']"

In [None]:
def print_pairs(dataset, generation, sample=10):
    assert len(dataset) == len(generation), "Invalid combination!"
    target = dataset[f"{target_lng}_sentence"]
    source = dataset[f"{source_lng}_sentence"]

    sample_ids = random.sample(len(dataset), sample)
    for i, tpl in enumerate(zip(source[sample_ids], target[sample_ids], generation[sample_ids])):
        print(f"Sentence #{i} [id={sample_ids[i]}]\n")
        print(f"Original:  {tpl[0]}\n\tTarget:    {tpl[1]}\n\tGenerated: {tpl[2]}\n")
    return

print_pairs(valid[f"{target_lng}_sentence"], valid_output)