# mbart-large-cc25 finetuning
An example notebook for fine-tuning [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) for Japanese-to-English (ja-en) translation.

## Install PyTorch

In [2]:
!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

## Install dependencies

In [8]:
!pip install --upgrade pip
!pip install "transformers[ja]" numpy pandas sentencepiece fairseq
!pip install -U jupyter ipywidgets 

## Download training dataset
This notebook uses the [JESC dataset](https://nlp.stanford.edu/projects/jesc/index_ja.html).  
Many thanks to its authors for publishing such a large dataset.
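
Each line of the `split/train`, `split/dev`, and `split/test` files holds one sentence pair separated by a tab, English first and Japanese second; the training code later in this notebook relies on that column order (`text[0]` is English, `text[1]` is Japanese). A minimal sketch of the parsing, using a hypothetical helper `parse_jesc_line` that assumes exactly one tab per line:

```python
from typing import Tuple


def parse_jesc_line(line: str) -> Tuple[str, str]:
    """Split one JESC line into an (english, japanese) pair.

    Hypothetical helper: assumes exactly one tab per line, with English in
    the first column and Japanese in the second (raises ValueError otherwise).
    """
    en, ja = line.rstrip("\n").split("\t")
    return en, ja


print(parse_jesc_line("good morning\tおはよう\n"))  # ('good morning', 'おはよう')
```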

In [5]:
!wget "https://nlp.stanford.edu/projects/jesc/data/split.tar.gz"
!tar -zxvf split.tar.gz
!ls split

## Create training data for tokenizer
Create the corpus used to train the SentencePiece tokenizer. It contains both the English and the Japanese side of every sentence pair in the train, dev, and test splits.

In [8]:
# Collect both sides of every sentence pair from all three splits into one corpus
res = []
for split in ['train', 'dev', 'test']:
    with open(f'split/{split}', 'r', encoding='utf-8') as fin:
        for line in fin:
            # Each line is "<English>\t<Japanese>"
            res.extend(t.rstrip('\n') for t in line.split('\t'))

print(len(res))
# Write one sentence per line; explicit UTF-8 so Japanese text is saved correctly on any OS
with open('tmp.txt', 'w', encoding='utf-8') as f:
    for d in res:
        f.write("%s\n" % d)

In [3]:
!tail tmp.txt

## Training tokenizer
Train a SentencePiece BPE model on this corpus.  

In [5]:
import sentencepiece as spm

# @NOTE
# Adjust the vocabulary size as needed.
spm.SentencePieceTrainer.Train("--input=tmp.txt --model_prefix=new_spm_model --vocab_size=64000 --vocabulary_output_piece_score=false --model_type=bpe")

## Download pre-trained model
Download the pre-trained model directly from fairseq rather than from Hugging Face, because its `dict.txt` vocabulary file is needed in the steps below.
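
The fairseq `dict.txt` matters because it records where each token sits in the pre-trained embedding matrix. As far as we understand fairseq's default `Dictionary`, indices 0 to 3 are reserved for `<s>`, `<pad>`, `</s>`, and `<unk>`, the file's `token count` lines follow in order, and looking up an unknown token returns the `<unk>` index. The sketch below mimics that layout in plain Python; `ToyDictionary` is a hypothetical stand-in, not fairseq's class.

```python
from typing import Dict, Iterable, List


class ToyDictionary:
    """Hypothetical stand-in that mimics fairseq's Dictionary index layout."""

    def __init__(self) -> None:
        self.symbols: List[str] = []
        self.indices: Dict[str, int] = {}
        # Default special symbols occupy indices 0..3.
        for sym in ["<s>", "<pad>", "</s>", "<unk>"]:
            self.add_symbol(sym)
        self.unk_index = self.indices["<unk>"]

    def add_symbol(self, sym: str) -> int:
        # Adding an existing symbol returns its index instead of duplicating it.
        # (Simplification: fairseq's file loader rejects duplicate lines.)
        if sym not in self.indices:
            self.indices[sym] = len(self.symbols)
            self.symbols.append(sym)
        return self.indices[sym]

    def add_from_lines(self, lines: Iterable[str]) -> None:
        # Each line is "<token> <count>"; only the token is used here.
        for line in lines:
            token = line.rstrip("\n").rsplit(" ", 1)[0]
            self.add_symbol(token)

    def index(self, sym: str) -> int:
        # Unknown tokens fall back to the <unk> index.
        return self.indices.get(sym, self.unk_index)

    def __len__(self) -> int:
        return len(self.symbols)


d = ToyDictionary()
d.add_from_lines(["▁the 1\n", "▁of 1\n"])
print(d.index("▁the"), d.index("▁zzz"), len(d))  # 4 3 6
```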

In [21]:
!wget "https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz"
!tar -zxvf mbart.cc25.v2.tar.gz
!ls mbart.cc25.v2

## Reducing the size of the pre-trained model
Next, we shrink the base model ([facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25)) by pruning its vocabulary.  
This step mainly addresses the problem that the base model is so large that it cannot be trained even with a batch size of 1.  
This problem is discussed in [this issue](https://github.com/pytorch/fairseq/issues/2120).
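
To get a feel for the savings, the arithmetic below counts the parameters in one vocabulary-sized embedding matrix before and after the reduction. It assumes a hidden size (`d_model`) of 1024 for mbart-large-cc25, 250,027 rows for the original Hugging Face model, and 64,027 rows for a 64,000-piece SentencePiece vocabulary (64,000 pieces, minus 3 SentencePiece specials, plus 4 fairseq specials, 25 language codes, and `<mask>`). Assuming the encoder, decoder, and LM-head embeddings share one matrix, this is roughly the number of parameters removed; the memory estimate additionally assumes fp32 training with AdamW.

```python
# Assumed sizes (see the lead-in); adjust if your config differs.
d_model = 1024
old_vocab = 250_027
# spm pieces - spm specials + fairseq specials + language codes + <mask>
new_vocab = 64_000 - 3 + 4 + 25 + 1

old_params = old_vocab * d_model
new_params = new_vocab * d_model
saved = old_params - new_params

# Assumption: fp32 weight + gradient + two AdamW moment buffers = 16 bytes/param
bytes_per_param = 4 * 4

print(f"new vocab rows:   {new_vocab:,}")          # 64,027
print(f"embedding params: {old_params:,} -> {new_params:,}")
print(f"saved:            {saved:,} (~{saved / 1e6:.0f}M)")  # 190,464,000 (~190M)
print(f"approx. memory saved: {saved * bytes_per_param / 1e9:.1f} GB")  # 3.0 GB
```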

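### Note
The approach used here takes the vocabulary file created earlier as its basis: it keeps the embeddings of the tokens we need and deletes those of the tokens we do not.  
The base model uses a vocabulary of roughly 250,000 tokens covering 25 languages.  
Fine-tuning usually targets only a few languages (here, just Japanese and English), so the idea is to cut away every token outside the vocabulary we actually need.  
The technique was [proposed here](https://github.com/pytorch/fairseq/issues/2120#issuecomment-633460071), and the code below is mainly based on [the code in this comment](https://github.com/pytorch/fairseq/issues/2120#issuecomment-647429120).  
Many thanks to [fansiawang](https://github.com/fansiawang) for proposing this technique and to [ddaspit](https://github.com/ddaspit) for posting the sample code.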
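
The core of the technique is an index mapping: for every token in the new (fine-tuning) dictionary, look up its index in the pre-trained dictionary, then copy that row of each embedding matrix into a new, smaller matrix. Tokens missing from the pre-trained dictionary fall back to the `<unk>` row. The sketch below uses plain lists instead of tensors; the toy vocabularies, vectors, and helper names are made up for illustration.

```python
from typing import Dict, List


def build_mapping(new_vocab: List[str], old_index: Dict[str, int], unk_id: int) -> List[int]:
    # For each new token, find its row in the pre-trained matrix (or <unk>).
    return [old_index.get(tok, unk_id) for tok in new_vocab]


def gather_rows(matrix: List[List[float]], mapping: List[int]) -> List[List[float]]:
    # Row i of the result is a copy of pre-trained row mapping[i].
    return [list(matrix[j]) for j in mapping]


# Toy pre-trained vocabulary and embedding matrix (one 2-d row per token).
old_vocab = ["<s>", "<pad>", "</s>", "<unk>", "▁hello", "▁bonjour", "▁こんにちは"]
old_index = {tok: i for i, tok in enumerate(old_vocab)}
old_matrix = [[float(i), float(i) * 10] for i in range(len(old_vocab))]

# New vocabulary keeps only English/Japanese tokens; "▁新語" is unseen.
new_vocab = ["<s>", "<pad>", "</s>", "<unk>", "▁hello", "▁こんにちは", "▁新語"]
mapping = build_mapping(new_vocab, old_index, unk_id=old_index["<unk>"])
new_matrix = gather_rows(old_matrix, mapping)

print(mapping)        # [0, 1, 2, 3, 4, 6, 3]
print(new_matrix[5])  # [6.0, 60.0]
```

With PyTorch tensors, the same gather can be written as `pre_tensor[mapping]`; the reduction code below performs it row by row.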

### Formatting vocab
The vocab file created earlier cannot be used as-is in the next step, so we convert it into the `token count` format that fairseq's dictionary loader expects, dropping the special tokens `<unk>`, `<s>`, and `</s>`.
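
With `--vocabulary_output_piece_score=false`, SentencePiece writes one piece per line to `new_spm_model.vocab`, whereas fairseq's dictionary file holds `<token> <count>` lines and fairseq adds `<s>`, `<pad>`, `</s>`, and `<unk>` by itself. The conversion therefore drops the three special pieces SentencePiece creates by default and appends a dummy count of 1. A pure-Python sketch of the same transformation; the function name `spm_vocab_to_fairseq` is hypothetical, and it also tolerates the default `piece<TAB>score` format by keeping only the first column:

```python
from typing import Iterable, List

# Pieces SentencePiece adds by default; fairseq's Dictionary adds its own versions.
SPM_SPECIALS = {"<unk>", "<s>", "</s>"}


def spm_vocab_to_fairseq(lines: Iterable[str]) -> List[str]:
    out = []
    for line in lines:
        # Keep only the piece (drops a score column if one is present).
        piece = line.rstrip("\n").split("\t")[0]
        if piece in SPM_SPECIALS:
            continue
        out.append(f"{piece} 1\n")  # dummy count in fairseq's "<token> <count>" format
    return out


print(spm_vocab_to_fairseq(["<unk>\n", "<s>\n", "</s>\n", "▁the\n", "ing\t-3.5\n"]))
# ['▁the 1\n', 'ing 1\n']
```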

In [17]:
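edited = []
for line in open("new_spm_model.vocab", 'r', encoding='utf-8'):
    # Skip SentencePiece's special pieces; fairseq's Dictionary adds <s>, <pad>, </s>, <unk> itself
    if line in ["<unk>\n", "<s>\n", "</s>\n"]:
        continue
    # fairseq dictionary files use "<token> <count>" lines; the count is a dummy value
    new_line = line.rstrip('\n') + " 1\n"
    edited.append(new_line)

with open('new_dict.txt', 'w', encoding='utf-8') as f:
    for e in edited:
        f.write(e)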

In [18]:
!ls

### Create the reduced model
Build a new, smaller model whose embeddings cover only the new vocabulary.  
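
The reduced model only works with the new tokenizer if both agree on the token IDs. Under our reading of fairseq's `Dictionary` and Hugging Face's `MBartTokenizer` (SentencePiece IDs shifted by an offset of 1, the 25 language codes placed right after the pieces, then `<mask>`), the two layouts line up for any SentencePiece vocabulary size `V`. The arithmetic sketch below checks this for `V = 64000`; the helper names are hypothetical, and the language-code order is assumed to match the `langs` list below.

```python
N_LANGS = 25            # language codes in mbart-large-cc25 (the `langs` list)
EN_IDX, JA_IDX = 3, 11  # positions of en_XX and ja_XX in that list


def fairseq_layout(v: int) -> dict:
    """IDs in the reduced fairseq dictionary for a SentencePiece vocab of size v."""
    n_pieces = v - 3            # <unk>, <s>, </s> are dropped from new_dict.txt
    first_lang = 4 + n_pieces   # after <s>, <pad>, </s>, <unk> and the pieces
    return {
        "first_piece": 4,
        "en_XX": first_lang + EN_IDX,
        "ja_XX": first_lang + JA_IDX,
        "<mask>": first_lang + N_LANGS,
        "size": first_lang + N_LANGS + 1,
    }


def hf_tokenizer_layout(v: int, fairseq_offset: int = 1) -> dict:
    """IDs as MBartTokenizer computes them (our reading of its source)."""
    return {
        "first_piece": 3 + fairseq_offset,  # spm id 3 is the first ordinary piece
        "en_XX": v + EN_IDX + fairseq_offset,
        "ja_XX": v + JA_IDX + fairseq_offset,
        "<mask>": v + N_LANGS + fairseq_offset,
        "size": v + N_LANGS + fairseq_offset + 1,
    }


print(fairseq_layout(64000))
# {'first_piece': 4, 'en_XX': 64004, 'ja_XX': 64012, '<mask>': 64026, 'size': 64027}
print(fairseq_layout(64000) == hf_tokenizer_layout(64000))  # True
```

If the two layouts disagreed (for example, because the SentencePiece model reserved a different number of special pieces), `config.vocab_size` and the tokenizer's IDs would drift apart.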

In [29]:
!mkdir reduced_model
!ls

In [31]:
from fairseq.data import Dictionary
from transformers import (
    MBartForConditionalGeneration, MBartTokenizer, MBartConfig
)
from typing import List
import torch

In [32]:
langs = [
    "ar_AR",
    "cs_CZ",
    "de_DE",
    "en_XX",
    "es_XX",
    "et_EE",
    "fi_FI",
    "fr_XX",
    "gu_IN",
    "hi_IN",
    "it_IT",
    "ja_XX",
    "kk_KZ",
    "ko_KR",
    "lt_LT",
    "lv_LV",
    "my_MM",
    "ne_NP",
    "nl_XX",
    "ro_RO",
    "ru_RU",
    "si_LK",
    "tr_TR",
    "vi_VN",
    "zh_CN"
]

def load_dict(langs: List[str], path: str) -> Dictionary:
    d = Dictionary.load(path)
    for ll in langs:
        d.add_symbol(f"[{ll}]")
    d.add_symbol("<mask>")
    d.add_symbol("<pad>")
    return d


# Index layouts of the original (pre-trained) and new (fine-tuning) vocabularies
pre_dict = load_dict(langs, "./mbart.cc25.v2/dict.txt")
ft_dict = load_dict(langs, "./new_dict.txt")

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
org_sd = model.state_dict()
resized_sd = model.state_dict()

# mapping[i] = row of the pre-trained embeddings for token i of the new vocabulary
# (tokens missing from the pre-trained vocabulary map to <unk>)
mapping: List[int] = []
for i in range(len(ft_dict)):
    word = ft_dict[i]
    mapping.append(pre_dict.index(word))

# Copy the selected rows into new, smaller embedding / output-projection matrices
for name in ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "model.shared.weight", "lm_head.weight"]:
    pre_tensor: torch.Tensor = org_sd[name]
    ft_tensor = torch.zeros(
        [len(ft_dict), pre_tensor.size(1)], dtype=pre_tensor.dtype, layout=pre_tensor.layout, device=pre_tensor.device,
    )
    for ft_i, pre_i in enumerate(mapping):
        ft_tensor[ft_i] = pre_tensor[pre_i]
    resized_sd[name] = ft_tensor
# Select the output bias entries with the same mapping as the embedding rows
resized_sd["final_logits_bias"] = resized_sd["final_logits_bias"][:, mapping]

# Same architecture, smaller vocabulary
config = MBartConfig.from_pretrained("facebook/mbart-large-cc25")
config.vocab_size = len(ft_dict)
print(config)
new_model = MBartForConditionalGeneration.from_pretrained(None, config=config, state_dict=resized_sd)
new_model.save_pretrained("./reduced_model")

In [33]:
!ls reduced_model

This completes the reduction of the base model.  
From here on, the `reduced_model` directory is used as the pre-trained model.

## Preparation of Tokenizer
The `reduced_model` directory does not contain any tokenizer files yet, so first save the original mbart-large-cc25 tokenizer there.

In [35]:
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
tokenizer.save_pretrained("./reduced_model")

Then replace its SentencePiece model with the one trained earlier.

In [36]:
!mv ./new_spm_model.model ./reduced_model/sentencepiece.bpe.model

In [38]:
!ls -al ./reduced_model

Now both the model and the tokenizer can be loaded from the `reduced_model` directory.

In [39]:
model = MBartForConditionalGeneration.from_pretrained("./reduced_model")
tokenizer = MBartTokenizer.from_pretrained("./reduced_model")

## Training
Now that the model and tokenizer are ready, we can run the training.  
The training code is based on [this Kaggle notebook](https://www.kaggle.com/ajax0564/mbart-finetuning-hintoenglish-translation).
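
The `data_collator` below pads or truncates every source sentence to 32 tokens and every target to 48. Worth knowing: PyTorch's `CrossEntropyLoss` ignores label positions equal to -100 by default, so padding positions in `labels` are commonly replaced with -100 to keep them out of the loss. A pure-Python sketch of fixed-length padding and label masking; the pad ID of 1 follows the fairseq/mBART convention, and the helper names and token IDs are made up for illustration:

```python
from typing import List

PAD_ID = 1           # <pad> index in fairseq / mBART vocabularies
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_or_truncate(ids: List[int], max_length: int, pad_id: int = PAD_ID) -> List[int]:
    # Cut long sequences and right-pad short ones to exactly max_length.
    # (Simplified: the real tokenizer truncates the text but still appends
    # </s> and the language code.)
    ids = ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))


def mask_padding(labels: List[int], pad_id: int = PAD_ID) -> List[int]:
    # Padding positions should not contribute to the loss.
    return [IGNORE_INDEX if t == pad_id else t for t in labels]


labels = pad_or_truncate([250, 87, 2, 64004], max_length=6)
print(labels)                # [250, 87, 2, 64004, 1, 1]
print(mask_padding(labels))  # [250, 87, 2, 64004, -100, -100]
```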

In [43]:
!mkdir output

In [44]:
from transformers import (
    Seq2SeqTrainingArguments, Seq2SeqTrainer
)
import numpy as np
import re

result_dir = "./output"

In [45]:
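def data_collator(features: list):
    # Source (Japanese) and target (English) sentences of the batch
    x = [f["translation"]["ja"] for f in features]
    y = [f["translation"]["en"] for f in features]
    inputs = tokenizer(x, return_tensors="pt", padding='max_length', truncation=True, max_length=32)
    # Tokenize the targets in target-language mode (appends </s> and the en_XX code)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(y, return_tensors="pt", padding='max_length', truncation=True, max_length=48)['input_ids']
    # Exclude padding positions from the loss (-100 is ignored by CrossEntropyLoss)
    labels[labels == tokenizer.pad_token_id] = -100
    inputs['labels'] = labels
    return inputs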

tokenizer = MBartTokenizer.from_pretrained("./reduced_model", src_lang="ja_XX", tgt_lang="en_XX")
tokenizer.save_pretrained(result_dir)

In [48]:
train_data = []
eval_data = []

for line in open("./split/train", "r", encoding='utf-8'):
    text = line.split('\t')
    train_data.append(
        {"translation": {
            "ja": text[1].rstrip('\n'),
            "en": text[0].rstrip('\n')
        }}
    )
print(f"train_data size: {len(train_data)}")

for line in open("./split/dev", "r", encoding='utf-8'):
    text = line.split('\t')
    eval_data.append(
        {"translation": {
            "ja": text[1].rstrip('\n'),
            "en": text[0].rstrip('\n')
        }}
    )
print(f"eval_data size: {len(eval_data)}")

In [49]:
# Hyperparameters
batch_size = 1
learning_rate = 3e-5
epochs = 1


In [50]:
model = MBartForConditionalGeneration.from_pretrained("./reduced_model")

args = Seq2SeqTrainingArguments(output_dir=result_dir,
                                do_train=True,
                                do_eval=True,
                                per_device_train_batch_size=batch_size,
                                per_device_eval_batch_size=batch_size,
                                learning_rate=learning_rate,
                                num_train_epochs=epochs,
                                evaluation_strategy="epoch",
                                # Keep the raw "translation" field; data_collator tokenizes it itself
                                remove_unused_columns=False,
                                )

trainer = Seq2SeqTrainer(model=model,
                         args=args,
                         data_collator=data_collator,
                         train_dataset=train_data,
                         eval_dataset=eval_data,
                         )

In [51]:
trainer.train()
trainer.save_model(result_dir)

## Inference
Let's run inference with the fine-tuned model.
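
For mbart-large-cc25, decoding starts from the target language code, which is why `decoder_start_token_id` is set to the ID of `en_XX`. The toy loop below sketches greedy decoding under that convention: start from the language-code ID, repeatedly append the highest-scoring next token, and stop at `</s>` or `max_length`. `toy_scores` is a made-up stand-in for the model, and `generate` itself may use beam search depending on its arguments and the model config; only the greedy case is shown.

```python
from typing import Callable, List

EOS_ID = 2  # </s>


def greedy_decode(
    next_token_scores: Callable[[List[int]], List[float]],
    start_id: int,
    max_length: int,
) -> List[int]:
    # Begin with the decoder start token (the target language code for mBART-cc25).
    tokens = [start_id]
    while len(tokens) < max_length:
        scores = next_token_scores(tokens)
        # Pick the highest-scoring next token.
        best = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(best)
        if best == EOS_ID:
            break
    return tokens


# Toy "model": emits token 5, then token 7, then </s>, regardless of the input.
def toy_scores(prefix: List[int]) -> List[float]:
    script = [5, 7, EOS_ID]
    target = script[min(len(prefix) - 1, len(script) - 1)]
    return [1.0 if i == target else 0.0 for i in range(10)]


# start_id=9 stands in for the en_XX language-code ID.
print(greedy_decode(toy_scores, start_id=9, max_length=48))  # [9, 5, 7, 2]
```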

In [52]:
model = MBartForConditionalGeneration.from_pretrained("./output")
tokenizer = MBartTokenizer.from_pretrained("./output")

In [53]:
sentence = "おはよう"
inputs = tokenizer(sentence, return_tensors="pt")
translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"], early_stopping=True, max_length=48)
pred = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
print(f"日本語 - {sentence}: English - {pred}")