# Finetuning IndoBART for Indonesian question generation (adapted from the ID -> SU example)

In [1]:
import os, sys
sys.path.append('../')
os.chdir('../')

import torch
import shutil
import random
import numpy as np
import pandas as pd
from torch import optim
from transformers import MBartForConditionalGeneration

# from indobenchmark import IndoNLGTokenizer
from modules.temporary_pad import IndoNLGTokenizerHalim as IndoNLGTokenizer
from utils.train_eval import train, evaluate
from utils.metrics import generation_metrics_fn
from utils.forward_fn import forward_generation
from utils.data_utils import MachineTranslationDataset, GenerationDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    PyTorch's default generator and the CUDA generator of the current device,
    in that order.

    Args:
        seed: integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
def count_param(module, trainable=False):
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable: when truthy, only parameters with ``requires_grad`` set
            are counted; otherwise every parameter is counted.

    Returns:
        The summed ``numel()`` of the selected parameters.
    """
    selected = module.parameters()
    if trainable:
        # Keep only the parameters that take part in gradient updates.
        selected = (param for param in selected if param.requires_grad)
    return sum(param.numel() for param in selected)
    
# Set random seed.
# Seeding here, before the model and data loaders are created, makes the
# stochastic parts of the run (e.g. the shuffled training loader) repeatable
# across "Restart & Run All" executions. Change the value to draw a different run.
set_seed(26092020)

# Load Model

In [3]:
# Model weights and tokenizer both come from the same hub checkpoint.
pretrained_name = 'indobenchmark/indobart'

bart_model = MBartForConditionalGeneration.from_pretrained(pretrained_name)
tokenizer = IndoNLGTokenizer.from_pretrained(pretrained_name)

# `model` is an alias of `bart_model` (same object); the rest of the notebook uses `model`.
model = bart_model
model

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'IndoNLGTokenizer'. 
The class this function is called from is 'IndoNLGTokenizerHalim'.


MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer

In [4]:
# Total parameter count of the loaded model (trainable=False, so every parameter is counted).
count_param(model)

131543040

# Prepare Dataset

In [5]:
# configs and args

# --- optimisation / training-loop settings (passed to optim.Adam and train()) ---
lr = 1e-4              # learning rate for Adam
gamma = 0.9            # forwarded to train(); presumably the LR decay factor -- confirm in utils.train_eval
lower = True           # datasets are built with lowercase=lower
step_size = 1          # forwarded to train(); presumably the LR scheduler step -- confirm in utils.train_eval
beam_size = 5          # forwarded to train() and evaluate()
max_norm = 10          # forwarded to train(); presumably the gradient-clipping norm -- confirm
early_stop = 5         # forwarded to train()

max_seq_len = 512      # used as max_token_length for datasets and max_seq_len for loaders/train/evaluate
grad_accumulate = 1    # forwarded to train() as grad_accum
no_special_token = False
swap_source_target = True
model_type = 'indo-bart'
valid_criterion = 'SacreBLEU'

# --- special token ids handed to MachineTranslationDataset ---
separator_id = 4
speaker_1_id = 5
# NOTE(review): identical to speaker_1_id -- confirm this is intended and not a typo.
speaker_2_id = 5

train_batch_size = 4
valid_batch_size = 4
test_batch_size = 4

# Source and target use the same language token.
source_lang = "[indonesian]"
target_lang = "[indonesian]"

src_lid = tokenizer.special_tokens_to_ids[source_lang]
tgt_lid = tokenizer.special_tokens_to_ids[target_lang]

# Decoding starts from the target-language token.
model.config.decoder_start_token_id = tgt_lid

# Make sure cuda is deterministic
torch.backends.cudnn.deterministic = True

# create directory (exist_ok=True already tolerates an existing directory)
model_dir = './save/qg/example_id_su'
os.makedirs(model_dir, exist_ok=True)

device = 'cuda0'
# set a specific cuda device; accepts 'cuda0', 'cuda:0' or plain 'cuda' (defaults to GPU 0)
if "cuda" in device:
    device_suffix = device[4:].lstrip(':')
    torch.cuda.set_device(int(device_suffix) if device_suffix else 0)
    device = "cuda"
    model = model.cuda()

# Build the optimizer only after the model is on its final device, as the
# torch.optim documentation recommends when moving a model with .cuda().
optimizer = optim.Adam(model.parameters(), lr=lr)

In [6]:
# Locations of the three data splits, relative to the current working directory.
train_dataset_path = './dataset/squad/vin_train_qg_bart-trainset.json'
valid_dataset_path = './dataset/squad/vin_train_qg_bart-valset.json'
test_dataset_path = './dataset/squad/vin_test_qg_bart.json'

# Every split is tokenized with exactly the same options.
dataset_options = dict(
    lowercase=lower,
    no_special_token=no_special_token,
    speaker_1_id=speaker_1_id,
    speaker_2_id=speaker_2_id,
    separator_id=separator_id,
    max_token_length=max_seq_len,
    swap_source_target=swap_source_target,
)
train_dataset = MachineTranslationDataset(train_dataset_path, tokenizer, **dataset_options)
valid_dataset = MachineTranslationDataset(valid_dataset_path, tokenizer, **dataset_options)
test_dataset = MachineTranslationDataset(test_dataset_path, tokenizer, **dataset_options)

# Loader options shared by all splits; only the dataset, batch size and shuffling differ.
loader_options = dict(
    model_type=model_type,
    tokenizer=tokenizer,
    max_seq_len=max_seq_len,
    src_lid_token_id=src_lid,
    tgt_lid_token_id=tgt_lid,
    num_workers=8,
)
train_loader = GenerationDataLoader(dataset=train_dataset, batch_size=train_batch_size,
                                    shuffle=True, **loader_options)
valid_loader = GenerationDataLoader(dataset=valid_dataset, batch_size=valid_batch_size,
                                    shuffle=False, **loader_options)
test_loader = GenerationDataLoader(dataset=test_dataset, batch_size=test_batch_size,
                                   shuffle=False, **loader_options)

# Test model to generate sequences

In [7]:
# Sanity check: run one masked Indonesian sentence through the model.
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = tokenizer.prepare_input_for_generation(
    inputs, return_tensors='pt',
    lang_token='[indonesian]', decoder_lang_token='[indonesian]',
)

bart_input.to(device)
# Inference only: skip autograd bookkeeping so no graph is kept alive on the GPU.
with torch.no_grad():
    bart_out = model(**bart_input)

# First line: the encoded input decoded back to text.
# Second line: the highest-scoring token at every position of the output logits.
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices.squeeze()))

<s> aku pergi ke toko obat membeli<mask></s>[indonesian]
<s> aku pergi ke toko obat membeli obat jeung[indonesian]


In [8]:
# Sanity check: run one masked Javanese sentence through the model.
inputs = ['aku menyang pasar karo <mask>']
bart_input = tokenizer.prepare_input_for_generation(
    inputs, return_tensors='pt',
    lang_token='[javanese]', decoder_lang_token='[javanese]',
)

bart_input.to(device)
# `model` is the same object as `bart_model`; it is used here for consistency
# with the other cells. no_grad avoids building an autograd graph for inference.
with torch.no_grad():
    bart_out = model(**bart_input)

# First line: the encoded input decoded back to text.
# Second line: the highest-scoring token at every position of the output logits.
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices.squeeze()))

<s> aku menyang pasar karo<mask></s>[javanese]
<s> aku menyang pasar karo tuku,[javanese]


In [9]:
# Sanity check: run one masked Sundanese sentence through the model.
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = tokenizer.prepare_input_for_generation(
    inputs, return_tensors='pt',
    lang_token='[sundanese]', decoder_lang_token='[sundanese]',
)

bart_input.to(device)
# `model` is the same object as `bart_model`; it is used here for consistency
# with the other cells. no_grad avoids building an autograd graph for inference.
with torch.no_grad():
    bart_out = model(**bart_input)

# First line: the encoded input decoded back to text.
# Second line: the highest-scoring token at every position of the output logits.
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices.squeeze()))

<s> kuring ka pasar senen meuli daging<mask></s>[sundanese]
<s> kuring ka pasar senen meuli daging sapi, kuring


# Test model to translate

In [13]:
# Score the current model on the test loader. The four returned values are
# unpacked into the loss, the metrics record, the hypotheses and the labels.
test_loss, test_metrics, test_hyp, test_label = evaluate(
    model,
    data_loader=test_loader,
    forward_fn=forward_generation,
    metrics_fn=generation_metrics_fn,
    model_type=model_type,
    tokenizer=tokenizer,
    beam_size=beam_size,
    max_seq_len=max_seq_len,
    is_test=True,
    device='cuda',
)

TESTING... : 100%|██████████| 741/741 [14:16<00:00,  1.16s/it]  


In [14]:
# Gather the metrics record and the hypothesis/label pairs from the evaluation above.
metrics_scores = [test_metrics]
result_dfs = [pd.DataFrame({
    'hyp': test_hyp,
    'label': test_label
})]

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)
# Summarise once and reuse for both the printout and the CSV.
metric_summary = metric_df.describe()

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_summary)

# Saved under a "pre_finetuning_" prefix: the results cell at the end of the
# notebook writes prediction_result.csv / evaluation_result.csv into the same
# directory and would otherwise overwrite these files.
result_df.to_csv(os.path.join(model_dir, "pre_finetuning_prediction_result.csv"))
metric_summary.to_csv(os.path.join(model_dir, "pre_finetuning_evaluation_result.csv"))

== Prediction Result ==
                                                 hyp  \
0                       negara apa normandia berada?   
1                            normandia di normandia?   
2                            negara mana asal norse?   
3                                   jadi naon norse?   
4   berapa pertama kali normandia mendapatkan ide...   

                                               label  
0   bangsa normandia (norman: nourmands; prancis:...  
1   bangsa normandia (norman: nourmands; prancis:...  
2   bangsa normandia (norman: nourmands; prancis:...  
3   bangsa normandia (norman: nourmands; prancis:...  
4   bangsa normandia (norman: nourmands; prancis:...  

== Model Performance ==
           BLEU  SacreBLEU    ROUGE1    ROUGE2    ROUGEL  ROUGELsum
count  1.000000   1.000000  1.000000  1.000000  1.000000   1.000000
mean   0.000014   0.000014  9.207036  3.583701  7.610349   7.608654
std         NaN        NaN       NaN       NaN       NaN        NaN
min    0.000

# Fine Tuning & Evaluation

In [10]:
# Train

n_epochs = 1

# Fine-tune on the training loader. Every setting below comes from the config
# cell except the number of epochs, the evaluation interval, the experiment id
# and the fp16 flag, which are fixed here.
train(
    model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    forward_fn=forward_generation,
    metrics_fn=generation_metrics_fn,
    valid_criterion=valid_criterion,
    tokenizer=tokenizer,
    n_epochs=n_epochs,
    evaluate_every=1,
    early_stop=early_stop,
    grad_accum=grad_accumulate,
    step_size=step_size,
    gamma=gamma,
    max_norm=max_norm,
    model_type=model_type,
    beam_size=beam_size,
    max_seq_len=max_seq_len,
    model_dir=model_dir,
    exp_id=0,
    fp16="",
    device=device,
)

(Epoch 1) TRAIN LOSS:3.3869 LR:0.00010000: 100%|██████████| 6000/6000 [25:28<00:00,  3.93it/s] 


(Epoch 1) TRAIN LOSS:3.3869 BLEU:14.64 SacreBLEU:15.54 ROUGE1:43.92 ROUGE2:11.97 ROUGEL:33.37 ROUGELsum:33.36 LR:0.00010000


VALID LOSS:4.8151: 100%|██████████| 1500/1500 [02:06<00:00, 11.87it/s]


(Epoch 1) VALID LOSS:4.8151 BLEU:5.70 SacreBLEU:6.00 ROUGE1:35.97 ROUGE2:4.93 ROUGEL:22.69 ROUGELsum:22.69


In [11]:
# Load best model
# Restore the checkpoint stored in model_dir; the load_state_dict result is
# left as the cell's last expression so the key-matching report is displayed.
best_checkpoint_path = model_dir + "/best_model_0.th"
best_state_dict = torch.load(best_checkpoint_path)
model.load_state_dict(best_state_dict)

<All keys matched successfully>

In [None]:
# Evaluate
# Score the reloaded model on the test loader. The four returned values are
# unpacked into the loss, the metrics record, the hypotheses and the labels.
test_loss, test_metrics, test_hyp, test_label = evaluate(
    model,
    data_loader=test_loader,
    forward_fn=forward_generation,
    metrics_fn=generation_metrics_fn,
    model_type=model_type,
    tokenizer=tokenizer,
    beam_size=beam_size,
    max_seq_len=max_seq_len,
    is_test=True,
    device='cuda',
)

TESTING... :   2%|▏         | 28/1482 [00:53<37:00,  1.53s/it]  

In [16]:
# Gather the metrics record and the hypothesis/label pairs from the evaluation above.
metrics_scores = [test_metrics]
result_dfs = [pd.DataFrame({'hyp': test_hyp, 'label': test_label})]

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

# Show the first predictions, then the summary statistics of the metrics.
print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_df.describe())

# Persist both tables next to the model checkpoints.
prediction_csv = model_dir + "/prediction_result.csv"
evaluation_csv = model_dir + "/evaluation_result.csv"
result_df.to_csv(prediction_csv)
metric_df.describe().to_csv(evaluation_csv)

== Prediction Result ==
                                                 hyp  \
0   ku sabab eta, simkuring teu sadar kana naon-n...   
1   maranehna, kabeh, dalaharna, nepi ka sarerea ...   
2   ku sabab eta, urang teh beunang manusa-manusa...   
3   jelema-jelema anu suci, kabeh oge, kabeh suci...   
4   geus kitu, petrus, yahya, jeung yahya dipiwar...   

                                               label  
0   da teu terang naon-naon, tur can tangtu simku...  
1         terus kabeh dalalahar nepi ka sareubeuhna.  
2   sabab urang mah darma dadamelan allah, anu ge...  
3   pikeun anu hatena suci mah, sagala ge suci. s...  
4   sanggeus leupas, petrus jeung yahya nepangan ...  

== Model Performance ==
            BLEU  SacreBLEU     ROUGE1     ROUGE2     ROUGEL  ROUGELsum
count   1.000000   1.000000   1.000000   1.000000   1.000000   1.000000
mean   11.323039  11.271977  35.564586  12.995558  30.761829  30.763307
std          NaN        NaN        NaN        NaN        NaN        