In [1]:
import re
import os
import json
import pickle
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from collections import OrderedDict
from bs4 import BeautifulSoup, element
import torch

from transformer.utils.tokenizer import MecabTokenizer, SpmTokenizer
from transformer.trainer.utils import ModelFilenameConstants
from transformer.trainer.utils import load_state_dict
from transformer.preprocessor.transformer_preprocessor import TransformerPreprocessor, DialogPretrainPreprocessor
from transformer.data.dataset import DatasetFromDir
from transformer.data.transformer_data_loader import DialogPretrainDataLoader
from transformer.models.transformer import Transformer
from transformer.trainer.interface import TrainResult, TrainHistory
from transformer.trainer.transformer_trainer import TransformerDialogPreTrainer
from transformer.preprocessor.utils import split_segment_by_speaker_ids, convert_turn_ids, flatten_sequence
from transformer.utils.common import read_all_files_from_dir, set_seed, get_now_str, is_empty_row_in_dict, reset_path

## Load Filepath

In [2]:
# Per-machine dataset roots. Select the active machine with `environment`
# instead of keeping commented-out copies of this cell.
dataset_dirs = {
    "aibud_dev": "/Users/aibud_dev/_jupyter",
    "picas_server": "/home/picas/_jupyter",
    "korea_server": "/home/guest1",
    # Raw string: "\_" is not a valid escape sequence and triggers a
    # SyntaxWarning on recent Python versions (the resulting value is unchanged).
    "bigshane_local": r"D:\_jupyter",
}
environment = "bigshane_local"
dataset_dir = dataset_dirs[environment]

# Dataset path templates, keyed by dataset name (e.g. file_path[dataset_name]["root_dir"]).
path = "./config/file_path.json"
with open(path, "r", encoding="utf-8") as fp:
    file_path = json.load(fp)

## Set Directories

In [3]:
# Every checkpoint artefact lives under one dated model directory.
model_dir = dataset_dir + "/model/transformer_dialog_pretrain/20210721/"
train_config_path = model_dir + ModelFilenameConstants.TRAIN_CONFIG_FILENAME
model_state_dict_path = model_dir + ModelFilenameConstants.MODEL_STATE_DICT_FILENAME
optimizer_state_dict_path = model_dir + ModelFilenameConstants.OPTIMIZER_STATE_DICT_FILENAME
src_spm_model_path = model_dir + ModelFilenameConstants.SRC_SPM_MODEL_DIR
# The target tokenizer path used to be built from SRC_SPM_MODEL_DIR, which looks
# like a copy-paste slip. Use a dedicated target directory when the constant exists
# and that directory was saved. Otherwise fall back to the source tokenizer
# (e.g. a shared vocabulary).
tgt_spm_model_dir = getattr(ModelFilenameConstants, "TGT_SPM_MODEL_DIR", None)
tgt_spm_model_path = model_dir + tgt_spm_model_dir if tgt_spm_model_dir else src_spm_model_path
if not os.path.exists(tgt_spm_model_path):
    tgt_spm_model_path = src_spm_model_path
history_path = model_dir + ModelFilenameConstants.HISTORY_FILENAME

In [4]:
# Load config
config = None
with open(train_config_path, "r", encoding="utf-8") as fp:
    config = json.load(fp)

## Set Device

In [5]:
# Runtime settings: one process, first GPU when CUDA is available, else CPU.
nprocs = 1
batch_size = 8
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## Load Trainer & Preprocessor

In [6]:
# Load trainer
dialog_pretrainer = TransformerDialogPreTrainer()
# dialog_pretrainer.set_lr_update(d_model=d_model, num_warmup_steps=num_warmup_steps)

# Load prep
trfr_prep = DialogPretrainPreprocessor(src_language=config["data"]["src_language"], src_spm_model_path=src_spm_model_path, tgt_language=config["data"]["tgt_language"], tgt_spm_model_path=tgt_spm_model_path, embedding_dict=config["model"]["embedding_dict"])

'temp_dir' has been set to './20210728_113127/{mode}_{idx}/' to save model while training
Cannot Import konlpy Mecab tagger: <class 'Exception'> - Install MeCab in order to use it: http://konlpy.org/en/latest/install/
Importing MeCab for Windows
Imported MeCab for Windows successfully
loaded spm_model: 'D:\_jupyter/model/transformer_dialog_pretrain/20210721/src_spm_model/'


## Load Dataset & DataLoader

In [7]:
dataset_name = "KaggleConversation"
data_dir_template = "/dataset/conversation/{dataset_name}/{language}/multi_turn/"
data_dir = dataset_dir + data_dir_template.format(dataset_name=dataset_name, language=config["data"]["src_language"])
dataset = DatasetFromDir(
    data_dir=data_dir,
    batch_size=batch_size,
    encoding=config["data"]["encoding"],
    extension=config["data"]["extension"],
    device=device,
    nprocs=nprocs,
)

# Loader kwargs merge runtime settings with the saved "model" and "data_loader" config sections.
dialog_data_loader_params = TransformerDialogPreTrainer.get_data_loader_params(
    dataset=dataset,
    preprocessor=trfr_prep,
    batch_size=batch_size,
    device=device,
    nprocs=nprocs,
    num_workers=dialog_pretrainer.num_workers,
    pin_memory=dialog_pretrainer.pin_memory,
    **config["model"],
    **config["data_loader"],
)
kaggle_conversation_data_loader = dialog_pretrainer.create_data_loader(**dialog_data_loader_params)

##### nrpocs:1, device_index:0, num_workers:2, per_proc:849, worker_id:0, iter_start:0, iter_end:424
##### nrpocs:1, device_index:0, num_workers:2, per_proc:849, worker_id:1, iter_start:424, iter_end:848


In [8]:
# dataset_name = "dialog_pretrain"
# # data_dir = dataset_dir + "/dataset/preprocessed/{dataset_name}/{language}/multi_turn/".format(dataset_name=dataset_name, language=config["data"]["src_language"])
# data_dir = dataset_dir + "/dataset/preprocessed/{dataset_name}/{language}/multi_turn/sample/".format(dataset_name=dataset_name, language=config["data"]["src_language"])
# dataset = DatasetFromDir(data_dir=data_dir, batch_size=batch_size, encoding=config["data"]["encoding"], extension=config["data"]["extension"], device=device, nprocs=nprocs)

# dialog_data_loader_params = TransformerDialogPreTrainer.get_data_loader_params(dataset=dataset, preprocessor=trfr_prep, batch_size=batch_size, 
#                                                                                device=device, nprocs=nprocs, num_workers=dialog_pretrainer.num_workers, pin_memory=dialog_pretrainer.pin_memory,
#                                                                                **config["model"], **config["data_loader"])
# dialog_pretrain_data_loader = dialog_pretrainer.create_data_loader(**dialog_data_loader_params)

## Load Model

In [9]:
# Load bert
transformer = Transformer(src_timesteps=config["model"]["src_timesteps"], tgt_timesteps=config["model"]["tgt_timesteps"], src_vocab_size=config["model"]["src_vocab_size"], tgt_vocab_size=config["model"]["tgt_vocab_size"], 
                          embedding_dict=config["model"]["embedding_dict"], src_pad_token_id=trfr_prep.src_spm_tokenizer.special_token_dict["pad"]["id"], tgt_pad_token_id=trfr_prep.tgt_spm_tokenizer.special_token_dict["pad"]["id"], 
                          d_model=config["model"]["d_model"], d_ff=config["model"]["d_ff"], num_heads=config["model"]["num_heads"], 
                          num_encoder_layers=config["model"]["num_encoder_layers"], num_decoder_layers=config["model"]["num_decoder_layers"], shared_embedding=config["model"]["shared_embedding"], 
                          dropout=config["model"]["dropout"], pwff_activation=config["model"]["pwff_activation"], linear_activation=config["model"]["linear_activation"], 
                          bias=config["model"]["bias"], layer_norm_epsilon=config["model"]["layer_norm_epsilon"], initialization=config["model"]["initialization"])
transformer = load_state_dict(object=transformer, path=model_state_dict_path, map_location=device)

# Load optimizer & criterions
optimizer = dialog_pretrainer.get_optimizer(model=transformer, **config["train"]["optimizer"])
optimizer = load_state_dict(object=optimizer, path=optimizer_state_dict_path, map_location=device)
criterions, criterion_weights = dialog_pretrainer.get_criterions(tgt_timesteps=config["model"]["tgt_timesteps"], tgt_vocab_size=config["model"]["tgt_vocab_size"], tgt_pad_token_id=trfr_prep.tgt_spm_tokenizer.special_token_dict["pad"]["id"], **config["train"]["criterion_weights"])

## Move Model, Optimizer & Criterions to Device

In [10]:
# Move the model, optimizer state and criterions onto `device`, in that order.
transformer, optimizer, criterions = (
    TransformerDialogPreTrainer.set_device(obj=component, device=device)
    for component in (transformer, optimizer, criterions)
)

Setting model device: cuda:0
Setting criterions device: cuda:0


## Training Test (manual loop)

In [11]:
# Load one raw shard for the manual training loop. `os.listdir` returns entries
# in arbitrary, filesystem-dependent order, so sort to pick the same file every run.
data_file_names = sorted(os.listdir(data_dir))
with open(data_dir + data_file_names[0], "r", encoding="utf-8") as fp:
    data = json.load(fp)

In [17]:
epoch = 5
verbose_per_batch = 50
# Mixed precision via torch.cuda.amp only applies on CUDA. GradScaler disables
# itself (with a warning) without CUDA, so on CPU fall back to full precision
# with no scaler. This is the same pair as the former manual `amp = False; scaler = None`.
amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler() if amp else None

# Put the model in training mode; the trailing `;` hides the module repr.
transformer.train();

In [None]:
# Trainer's built-in loop, kept for reference. The manual loop in the next cell is
# used instead, so this stays disabled; set `use_fit_api = True` to run it.
# The previous content of this cell was the pasted signature, which is a SyntaxError:
#   fit(model, train_data_loader, val_data_loader, criterions, criterion_weights,
#       optimizer, device, epoch=1, amp=False, save_per_epoch=1, save_per_batch=-1,
#       keep_last=True, verbose_per_epoch=-1, verbose_per_batch=-1)
use_fit_api = False
if use_fit_api:
    dialog_pretrainer.fit(
        model=transformer,
        train_data_loader=kaggle_conversation_data_loader,
        val_data_loader=None,  # NOTE(review): confirm fit() accepts no validation loader
        criterions=criterions,
        criterion_weights=criterion_weights,
        optimizer=optimizer,
        device=device,
        epoch=epoch,
        amp=amp,
        verbose_per_batch=verbose_per_batch,
    )

In [None]:
# Manual training loop over the shard loaded into `data`. Each batch is rebuilt by
# hand (parse_json -> preprocessor.encode -> tensors) and passed to
# `dialog_pretrainer.iteration`, so every step can be inspected.
train_loader = kaggle_conversation_data_loader
# Ceil division: the former `len(data)//batch_size + 1` produced an empty trailing
# batch whenever len(data) was an exact multiple of batch_size.
batch_iter_size = (len(data) + batch_size - 1) // batch_size

for _ in range(epoch):
    epoch_train_history = None
    epoch_result = TrainResult(criterions=criterions)
    batch_result = TrainResult(criterions=criterions)

    with torch.set_grad_enabled(True):
        for batch_idx in tqdm(range(batch_iter_size)):
            raw_rows = data[batch_size * batch_idx:batch_size * (batch_idx + 1)]

            # Split each raw row into its source / target parts.
            _src_inputs = []
            _tgt_inputs = []
            for row in raw_rows:
                src_input_row, tgt_input_row = train_loader.parse_json(row)
                _src_inputs.append(src_input_row)
                _tgt_inputs.append(tgt_input_row)

            # Encode with the loader's own timesteps / separator / approach settings.
            src_inputs, tgt_inputs, tgt_outputs = train_loader.preprocessor.encode(
                src_inputs=_src_inputs,
                tgt_inputs=_tgt_inputs,
                src_timesteps=train_loader.src_timesteps,
                tgt_timesteps=train_loader.tgt_timesteps,
                src_sep_tokens=train_loader.src_sep_tokens,
                approach=train_loader.approach,
            )
            batch = [
                {k: dialog_pretrainer.convert_to_tensor(data=v, device=device) for k, v in _batch.items()}
                for _batch in (src_inputs, tgt_inputs, tgt_outputs)
            ]

            # Skip the batch if any encoded dict is reported empty.
            if any(is_empty_row_in_dict(data=_batch) for _batch in batch):
                continue

            # One training step through the trainer (train=True).
            loss_dict, acc_dict = dialog_pretrainer.iteration(model=transformer, batch=batch,
                                                              criterions=criterions, criterion_weights=criterion_weights,
                                                              optimizer=optimizer, train=True, amp=amp, scaler=scaler)
            batch_result.update(loss_dict=loss_dict, acc_dict=acc_dict, iteration=1, lr=optimizer.param_groups[0]["lr"])

            # Every `verbose_per_batch` batches: print the window, fold it into the
            # epoch history/totals, then start a fresh window.
            if verbose_per_batch > 0 and batch_idx % verbose_per_batch == 0:
                batch_result.freeze()
                print(dialog_pretrainer.verbose_template.format(mode="\nBatch_train", device=device, idx=batch_idx, num_iters=batch_iter_size), batch_result)
                window_history = batch_result.to_train_history()
                epoch_train_history = window_history if epoch_train_history is None else epoch_train_history + window_history
                epoch_result.merge_with(train_result=batch_result)
                batch_result = TrainResult(criterions=criterions)

        # Fold the batches processed after the last report into the epoch totals.
        if batch_result.iteration > 0:
            epoch_result.merge_with(train_result=batch_result)

    dialog_pretrainer.epoch += 1
    epoch_result.freeze()

  1%|▊                                                                                 | 1/107 [00:01<02:13,  1.26s/it]


Batch_train (cuda:0) [ 0 /107]: total_loss: 3.470e-01, lm_loss: 6.935e-01, ul_loss: 3.668e-04,  | total_acc: 8.735e-01, lm_acc: 8.193e-01, ul_acc: 9.277e-01,  | train_time: 1.0s, lr:  0.0002000000


 48%|██████████████████████████████████████▌                                          | 51/107 [00:58<01:04,  1.15s/it]


Batch_train (cuda:0) [50 /107]: total_loss: 3.531e-01, lm_loss: 7.014e-01, ul_loss: 4.792e-03,  | total_acc: 8.917e-01, lm_acc: 8.406e-01, ul_acc: 9.427e-01,  | train_time: 56.0s, lr:  0.0002000000


 94%|███████████████████████████████████████████████████████████████████████████▌    | 101/107 [01:43<00:06,  1.14s/it]


Batch_train (cuda:0) [100/107]: total_loss: 4.262e-01, lm_loss: 8.044e-01, ul_loss: 4.790e-02,  | total_acc: 8.814e-01, lm_acc: 8.240e-01, ul_acc: 9.388e-01,  | train_time: 45.0s, lr:  0.0002000000


100%|████████████████████████████████████████████████████████████████████████████████| 107/107 [01:49<00:00,  1.03s/it]
  1%|▊                                                                                 | 1/107 [00:01<02:02,  1.15s/it]


Batch_train (cuda:0) [ 0 /107]: total_loss: 3.541e-01, lm_loss: 7.081e-01, ul_loss: 1.028e-05,  | total_acc: 8.735e-01, lm_acc: 8.072e-01, ul_acc: 9.398e-01,  | train_time: 1.0s, lr:  0.0002000000


 48%|██████████████████████████████████████▌                                          | 51/107 [00:58<01:04,  1.16s/it]


Batch_train (cuda:0) [50 /107]: total_loss: 3.506e-01, lm_loss: 6.815e-01, ul_loss: 1.976e-02,  | total_acc: 8.938e-01, lm_acc: 8.472e-01, ul_acc: 9.404e-01,  | train_time: 57.0s, lr:  0.0002000000


 94%|███████████████████████████████████████████████████████████████████████████▌    | 101/107 [01:43<00:06,  1.15s/it]


Batch_train (cuda:0) [100/107]: total_loss: 4.103e-01, lm_loss: 7.774e-01, ul_loss: 4.329e-02,  | total_acc: 8.874e-01, lm_acc: 8.369e-01, ul_acc: 9.378e-01,  | train_time: 45.0s, lr:  0.0002000000


100%|████████████████████████████████████████████████████████████████████████████████| 107/107 [01:49<00:00,  1.03s/it]
  1%|▊                                                                                 | 1/107 [00:01<02:02,  1.16s/it]


Batch_train (cuda:0) [ 0 /107]: total_loss: 3.778e-01, lm_loss: 7.531e-01, ul_loss: 2.428e-03,  | total_acc: 8.675e-01, lm_acc: 8.072e-01, ul_acc: 9.277e-01,  | train_time: 1.0s, lr:  0.0002000000


 48%|██████████████████████████████████████▌                                          | 51/107 [00:58<01:04,  1.16s/it]


Batch_train (cuda:0) [50 /107]: total_loss: 3.731e-01, lm_loss: 7.268e-01, ul_loss: 1.944e-02,  | total_acc: 8.917e-01, lm_acc: 8.413e-01, ul_acc: 9.420e-01,  | train_time: 57.0s, lr:  0.0002000000


 94%|███████████████████████████████████████████████████████████████████████████▌    | 101/107 [01:44<00:06,  1.15s/it]


Batch_train (cuda:0) [100/107]: total_loss: 4.493e-01, lm_loss: 8.545e-01, ul_loss: 4.417e-02,  | total_acc: 8.768e-01, lm_acc: 8.123e-01, ul_acc: 9.413e-01,  | train_time: 45.0s, lr:  0.0002000000


100%|████████████████████████████████████████████████████████████████████████████████| 107/107 [01:50<00:00,  1.03s/it]
  1%|▊                                                                                 | 1/107 [00:01<02:02,  1.16s/it]


Batch_train (cuda:0) [ 0 /107]: total_loss: 3.386e-01, lm_loss: 6.771e-01, ul_loss: 1.741e-04,  | total_acc: 8.795e-01, lm_acc: 8.313e-01, ul_acc: 9.277e-01,  | train_time: 1.0s, lr:  0.0002000000


 48%|██████████████████████████████████████▌                                          | 51/107 [00:58<01:04,  1.16s/it]


Batch_train (cuda:0) [50 /107]: total_loss: 3.491e-01, lm_loss: 6.887e-01, ul_loss: 9.480e-03,  | total_acc: 8.960e-01, lm_acc: 8.497e-01, ul_acc: 9.424e-01,  | train_time: 57.0s, lr:  0.0002000000


 94%|███████████████████████████████████████████████████████████████████████████▌    | 101/107 [01:44<00:06,  1.16s/it]


Batch_train (cuda:0) [100/107]: total_loss: 4.063e-01, lm_loss: 7.520e-01, ul_loss: 6.063e-02,  | total_acc: 8.863e-01, lm_acc: 8.337e-01, ul_acc: 9.390e-01,  | train_time: 45.0s, lr:  0.0002000000


100%|████████████████████████████████████████████████████████████████████████████████| 107/107 [01:50<00:00,  1.03s/it]
  1%|▊                                                                                 | 1/107 [00:01<02:03,  1.16s/it]


Batch_train (cuda:0) [ 0 /107]: total_loss: 3.333e-01, lm_loss: 6.666e-01, ul_loss: 4.893e-05,  | total_acc: 8.795e-01, lm_acc: 8.313e-01, ul_acc: 9.277e-01,  | train_time: 1.0s, lr:  0.0002000000


 48%|██████████████████████████████████████▌                                          | 51/107 [00:59<01:05,  1.18s/it]


Batch_train (cuda:0) [50 /107]: total_loss: 3.208e-01, lm_loss: 6.355e-01, ul_loss: 6.175e-03,  | total_acc: 8.995e-01, lm_acc: 8.597e-01, ul_acc: 9.392e-01,  | train_time: 57.0s, lr:  0.0002000000


 90%|████████████████████████████████████████████████████████████████████████▋        | 96/107 [01:40<00:12,  1.17s/it]

## Inference Test

In [None]:
def get_predictions_and_targets(data_loader, rows=None, model=None):
    """Yield (lm_predictions, lm_targets) numpy arrays one batch at a time.

    Rows are split into batches of `batch_size`, parsed and encoded with
    `data_loader`'s preprocessor settings, then fed to the model together with
    the encoded target inputs. Predictions are the argmax over the last axis of
    the model's "lm" output.

    Args:
        data_loader: provides `parse_json`, `preprocessor` and the encoding
            settings (`src_timesteps`, `tgt_timesteps`, `src_sep_tokens`, `approach`).
        rows: raw rows to evaluate; defaults to the notebook-global `data`.
        model: model to run; defaults to the notebook-global `transformer`.
    """
    rows = data if rows is None else rows
    model = transformer if model is None else model
    # Ceil division, so there is no empty trailing batch when len(rows) % batch_size == 0.
    num_batches = (len(rows) + batch_size - 1) // batch_size
    for batch_idx in range(num_batches):
        raw_rows = rows[batch_size * batch_idx:batch_size * (batch_idx + 1)]

        _src_inputs = []
        _tgt_inputs = []
        for row in raw_rows:
            src_input_row, tgt_input_row = data_loader.parse_json(row)
            _src_inputs.append(src_input_row)
            _tgt_inputs.append(tgt_input_row)

        src_inputs, tgt_inputs, tgt_outputs = data_loader.preprocessor.encode(
            src_inputs=_src_inputs,
            tgt_inputs=_tgt_inputs,
            src_timesteps=data_loader.src_timesteps,
            tgt_timesteps=data_loader.tgt_timesteps,
            src_sep_tokens=data_loader.src_sep_tokens,
            approach=data_loader.approach,
        )
        batch = [
            {k: dialog_pretrainer.convert_to_tensor(data=v, device=device) for k, v in _batch.items()}
            for _batch in (src_inputs, tgt_inputs, tgt_outputs)
        ]
        # Same guard as the training loop: skip batches with an empty encoded dict.
        if any(is_empty_row_in_dict(data=_batch) for _batch in batch):
            continue
        src_inputs, tgt_inputs, tgt_outputs = batch

        # Run in eval mode without building an autograd graph. Restore the caller's
        # train/eval mode before yielding so the training cells are unaffected.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                _predictions = model(src_inputs=src_inputs, tgt_inputs=tgt_inputs)
        finally:
            model.train(was_training)

        lm_predictions = dialog_pretrainer.convert_to_numpy(tensor=_predictions["lm"])
        lm_predictions = np.argmax(lm_predictions, axis=-1)
        lm_targets = dialog_pretrainer.convert_to_numpy(tensor=tgt_outputs["lm"])
        yield lm_predictions, lm_targets

eos_token_id = trfr_prep.tgt_spm_tokenizer.special_token_dict["eos"]["id"]
pad_token_id = trfr_prep.tgt_spm_tokenizer.special_token_dict["pad"]["id"]

In [None]:
# Reload the same shard as the training cell. Sorting makes the choice
# deterministic, since `os.listdir` order is filesystem-dependent.
with open(data_dir + sorted(os.listdir(data_dir))[0], "r", encoding="utf-8") as fp:
    data = json.load(fp)
gen = get_predictions_and_targets(data_loader=kaggle_conversation_data_loader)
# gen = get_predictions_and_targets(data_loader=dialog_pretrain_data_loader)

In [None]:
# Score one batch: per-row accuracy over non-pad target positions, averaged over rows.
lm_predictions, lm_targets = next(gen)
label_weights = (lm_targets != pad_token_id).astype(float)
correct = (lm_predictions == lm_targets).astype(float)
label_correct = correct * label_weights

# A row whose targets are all padding would divide by zero (nan); exclude such rows.
tokens_per_row = np.sum(label_weights, axis=-1)
has_tokens = tokens_per_row > 0
row_accuracy = np.sum(label_correct, axis=-1)[has_tokens] / tokens_per_row[has_tokens]
batch_accuracy = np.mean(row_accuracy) if row_accuracy.size else float("nan")
print("batch_accuracy:", np.round(batch_accuracy, 5), "\n")

# Decode predictions and targets (eos_token_id passed to tgt_decode) for side-by-side comparison.
p = trfr_prep.tgt_decode(lm_predictions.tolist(), eos_token_id=eos_token_id)
t = trfr_prep.tgt_decode(lm_targets.tolist(), eos_token_id=eos_token_id)
for _p, _t in zip(p, t):
    print("pred:", _p)
    print("targ:", _t)
    print()
    print()

## Configuration (BERT dialog-pretraining setup)

In [27]:
# Settings consumed by the `bert_dialog_prep` / `BertDialogPreTrainer` cells below.
# NOTE: running this cell overwrites `device`, `batch_size` and `nprocs` set earlier.
dataset_name, language, encoding = "KaggleConversation", "kor", "utf-8"

# Tokenizer and input layout
vocab_size = 15000
timesteps = 192
embedding_dict = dict(segment=2, turn=2)
# Layout: [["context", "sep"], ["candidate", "sep"]]
sep_tokens = [["cls", "sep"], [None, "sep"]]
approach = "ignore"
nprocs = 1

# Training parameters
batch_size = 8
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [33]:
# Inspect the worker count the trainer passes to its DataLoaders.
getattr(dialog_pretrainer, "num_workers")

0

False

## Load Preprocessor

In [13]:
# NOTE(review): this cell calls DialogPretrainPreprocessor with `language=` /
# `spm_model_path=` and reads `.spm_tokenizer`. The "Load Trainer & Preprocessor"
# cell uses `src_language=` / `src_spm_model_path=` and `.src_spm_tokenizer`, so
# confirm which API the installed package expects.
# NOTE: running this cell overwrites the global `pad_token_id` read by the
# inference-test accuracy cell.
spm_model_dir_template = "./data/{language}/spoken_pretrain_spm_v{vocab_size}"
spm_model_path = spm_model_dir_template.format(language=language, vocab_size=vocab_size)
bert_dialog_prep = DialogPretrainPreprocessor(
    language=language,
    spm_model_path=spm_model_path,
    embedding_dict=embedding_dict,
)
special_tokens = bert_dialog_prep.spm_tokenizer.special_token_dict
pad_token_id = special_tokens["pad"]["id"]

Imported konlpy.tag.Mecab successfully
loaded spm_model: './data/kor/spoken_pretrain_spm_v15000/'


## Load Dataset & DataLoader

In [35]:
# Build the dataset and DataLoader for the BERT dialog-pretraining setup from the
# path templates in config/file_path.json.
# NOTE(review): `BertDialogPreTrainer` is not imported in the import cell at the top
# of this notebook, so this cell raises NameError as written. Import it, or port the
# cell to TransformerDialogPreTrainer, whose get_data_loader_params takes a
# different set of arguments (see the "Load Dataset & DataLoader" cell).
# NOTE(review): reassigning `dialog_pretrainer` here replaces the trainer (and its
# `epoch` counter) used by the training cells above.
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
# `raw_dir` is computed but not read by any visible cell.
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)
multi_turn_data_dir = file_path[dataset_name]["feed_data"]["multi_turn"].format(root_dir=root_dir, language=language)
multi_turn_data_extension = "json"

dialog_pretrainer = BertDialogPreTrainer()
dialog_data_loader_params = BertDialogPreTrainer.get_data_loader_params(timesteps=timesteps, embedding_dict=embedding_dict, nprocs=nprocs, sep_tokens=sep_tokens, approach=approach)
dialog_pretrain_dataset = DatasetFromDir(data_dir=multi_turn_data_dir, batch_size=batch_size, encoding=encoding, extension=multi_turn_data_extension, device=device, nprocs=nprocs)
dialog_pretrain_data_loader = dialog_pretrainer.create_data_loader(dataset=dialog_pretrain_dataset, batch_size=batch_size, 
                                                                   num_workers=dialog_pretrainer.num_workers, pin_memory=dialog_pretrainer.pin_memory, preprocessor=bert_dialog_prep, device=device, **dialog_data_loader_params)

'temp_dir' has been set to ./20210710_130826/{mode}_{idx}/ to save model while training
worker_info: None


## Check History Files

In [7]:
# Inspect the pickled training history saved next to the checkpoint
# (`pickle` is already imported in the import cell).
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
history_path = model_dir + ModelFilenameConstants.HISTORY_FILENAME
with open(history_path, "rb") as history_fp:
    history = pickle.load(history_fp)
dict(history)

{'train': OrderedDict([('iteration', [44089]),
              ('lr', [0.0001]),
              ('train_time', [27746]),
              ('loss',
               {'mlm': [0.39459580843510916],
                'nsp': [0.2841686550899227],
                'total_loss': [0.3393822317625159]}),
              ('acc',
               {'mlm': [0.8877814241079436],
                'nsp': [0.6865681933667592],
                'total_acc': [0.7871748087373515]})]),
 'val': None}

In [None]:
def inference(self, sequences: "List[List[str]]", preprocessor: "Preprocessor", device: str, method: str = "greedy"):
    """Decode target token ids for a single source input (draft model method).

    Encodes the source with `preprocessor`, pads it to `self.src_timesteps`,
    builds a pad-filled target buffer of `self.tgt_timesteps`, and dispatches to
    the requested decoding helper without tracking gradients.

    Args:
        sequences: source input passed to `preprocessor.src_encode(sentence=...)`.
            NOTE(review): the original body read an undefined `src_sentence`.
            `sequences` is the only source argument, so it is used here; confirm
            that `src_encode` accepts this shape.
        preprocessor: provides src_encode / is_proper_length / pad_row /
            _raise_approach_error and the target tokenizer's special token ids.
        device: device for the model and inputs; None leaves both where they are.
        method: "greedy" or "beam_search" (validated by `self.assert_isin_methods`).

    Returns:
        The selected helper's result (a list of token ids for greedy decoding).

    Raises:
        ValueError: if `method` passes validation but has no decoding branch.
    """
    # Annotations are strings because `List` / `Preprocessor` are not imported here;
    # bare names would raise NameError when the function is defined.
    if device is not None:
        self.to(device, non_blocking=True)
    self.assert_isin_methods(method=method)

    src_input_row = preprocessor.src_encode(sentence=sequences, mask=False)
    if not preprocessor.is_proper_length(ids=src_input_row, upper_bound=self.src_timesteps):
        preprocessor._raise_approach_error(approach="stop")
    src_input_row = preprocessor.pad_row(ids=src_input_row, timesteps=self.src_timesteps, padding_value=self.src_pad_token_id)
    tgt_input_row = [self.tgt_pad_token_id] * self.tgt_timesteps

    # Batch of one. The explicit int64 avoids NumPy's platform-dependent default
    # integer type (int32 on Windows with NumPy < 2) for token ids.
    src_inputs = torch.from_numpy(np.array([src_input_row], dtype=np.int64))
    tgt_inputs = torch.from_numpy(np.array([tgt_input_row], dtype=np.int64))
    if device is not None:
        src_inputs = src_inputs.to(device)
        tgt_inputs = tgt_inputs.to(device)

    special_token_dict = preprocessor.tgt_spm_tokenizer.special_token_dict
    decode_kwargs = dict(
        src_inputs=src_inputs,
        tgt_inputs=tgt_inputs,
        tgt_bos_token_id=special_token_dict["bos"]["id"],
        tgt_eos_token_id=special_token_dict["eos"]["id"],
    )
    with torch.no_grad():
        if method == "greedy":
            return self._inference_greedy(**decode_kwargs)
        if method == "beam_search":
            return self._inference_beam_search(**decode_kwargs)
    raise ValueError(f"Unsupported inference method: {method!r}")

def _inference_greedy(self, src_inputs, tgt_inputs, tgt_bos_token_id, tgt_eos_token_id):
    output = []
    next_token_id = tgt_bos_token_id
    for timestep in range(0, self.tgt_timesteps):
        tgt_inputs[0][timestep] = next_token_id
        prediction_rows = self.forward(src_inputs=src_inputs, tgt_inputs=tgt_inputs)
        prediction_row = torch.argmax(prediction_rows, dim=-1)[0]
        next_token_id = prediction_row[timestep].tolist()
        output.append(next_token_id)
        if next_token_id == tgt_eos_token_id: break
    return output