Try to load from: IndoNLG_finals_mBart_model_v2_checkpoint_105_640000.pt

In [1]:
import os, sys
sys.path.append('../')
import torch
from transformers import MBartModel, BartForConditionalGeneration, GPT2LMHeadModel, MBartForConditionalGeneration, BartConfig

In [11]:
def _check_generation_args(tokenizer, text, lang_token, arg_name):
    """Validate one side (encoder or decoder) of a generation request.

    Raises:
        ValueError: if `lang_token` is not a key of
            `tokenizer.special_tokens_to_ids`, or if `text` is neither a
            `str` nor a non-empty list whose first element is a `str`.
    """
    if lang_token not in tokenizer.special_tokens_to_ids:
        raise ValueError(f"Unknown {arg_name} `{lang_token}`, {arg_name} must be either `[javanese]`, `[sundanese]`, or `[indonesian]`")
    if isinstance(text, list):
        # Only the first element is inspected, as a cheap sanity check.
        if len(text) == 0 or not isinstance(text[0], str):
            raise ValueError(IndoNLGTokenizer.input_error_message)
    elif not isinstance(text, str):
        raise ValueError(IndoNLGTokenizer.input_error_message)


def _wrap_input_ids(batch, is_single, prefix, suffix):
    """Drop the first and last id of every sequence, then add `prefix`/`suffix`.

    `batch['input_ids']` is rewritten in place. When `is_single` is true it is
    treated as one flat list of ids, otherwise as a list of such lists.
    """
    def wrap(ids):
        # The first/last ids are presumably the bos/eos the tokenizer adds
        # itself; they are removed so the explicit layout is not doubled.
        return prefix + ids[1:-1] + suffix

    if is_single:
        batch['input_ids'] = wrap(batch['input_ids'])
    else:
        batch['input_ids'] = [wrap(ids) for ids in batch['input_ids']]


def prepare_input_for_generation(self, inputs, lang_token = '[indonesia]', decoder_lang_token = '[indonesia]', decoder_inputs=None, return_tensors='pt'):
    """Build a padded encoder (and optionally decoder) batch for generation.

    Token layout produced (earlier experiments without `<mask>` or without the
    leading bos were dropped in favour of this one):
      * encoder: ``bos + tokens + [mask, eos, lang_id]``
      * decoder: ``[decoder_lang_id, bos] + tokens + [mask, eos]``

    Args:
        self: the tokenizer instance; it must be callable and expose
            `special_tokens_to_ids`, `bos_token_id`, `eos_token_id`,
            `mask_token_id` and `pad`.
        inputs: a `str` or a non-empty list of `str` for the encoder.
        lang_token: key of `self.special_tokens_to_ids` appended to every
            encoder sequence.
        decoder_lang_token: key of `self.special_tokens_to_ids` prepended to
            every decoder sequence.
        decoder_inputs: optional `str` / list of `str` for the decoder. When
            `None`, only the encoder batch is returned.
        return_tensors: forwarded unchanged to `self.pad`.

    Returns:
        The result of `self.pad` on the encoder batch. When `decoder_inputs`
        is given, it additionally carries `decoder_input_ids` and
        `decoder_attention_mask` taken from the padded decoder batch.

    Raises:
        ValueError: on an unknown language token or an unsupported input type.

    NOTE(review): the defaults are spelled `[indonesia]` while the error
    message lists `[indonesian]`. If the tokenizer only knows the latter,
    callers have to pass both language tokens explicitly -- confirm and align.
    """
    # Process encoder input
    _check_generation_args(self, inputs, lang_token, 'lang_token')
    lang_id = self.special_tokens_to_ids[lang_token]
    input_batch = self(inputs, return_attention_mask=False)
    _wrap_input_ids(
        input_batch,
        isinstance(inputs, str),
        [self.bos_token_id],
        [self.mask_token_id, self.eos_token_id, lang_id],
    )

    if decoder_inputs is None:
        # Return encoder input
        return self.pad(input_batch, return_tensors=return_tensors)

    # Process decoder input
    _check_generation_args(self, decoder_inputs, decoder_lang_token, 'decoder_lang_token')
    decoder_lang_id = self.special_tokens_to_ids[decoder_lang_token]
    decoder_input_batch = self(decoder_inputs, return_attention_mask=False)
    # The decoder prefix carries the *decoder* language id, so a target
    # language different from the source language is honoured.
    _wrap_input_ids(
        decoder_input_batch,
        isinstance(decoder_inputs, str),
        [decoder_lang_id, self.bos_token_id],
        [self.mask_token_id, self.eos_token_id],
    )

    # Padding
    input_batch = self.pad(input_batch, return_tensors=return_tensors)
    decoder_input_batch = self.pad(decoder_input_batch, return_tensors=return_tensors)

    # Store into a single dict
    input_batch['decoder_input_ids'] = decoder_input_batch['input_ids']
    input_batch['decoder_attention_mask'] = decoder_input_batch['attention_mask']

    return input_batch

In [12]:
# Locations of the pretrained checkpoint and of the tokenizer vocabulary file.
model_checkpoint = '/home/samuel/indonlg/checkpoints/IndoNLG_finals_mBart_model_v2_checkpoint_105_640000.pt'
vocab_path = 'IndoNLG_finals_vocab_model_indo4b_plus_spm_bpe_9995_wolangid_bos_pad_eos_unk.model'

# source_lang = "id_ID"
# target_lang = "su_SU"

# Start from the bart-base hyper-parameters and override only the vocabulary size.
config = BartConfig.from_pretrained('facebook/bart-base')
config.vocab_size = 40004
model = MBartForConditionalGeneration(config=config)

bart = MBartModel(config=config)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' keeps every checkpoint tensor on the CPU, where the freshly
# built modules live. Without it a checkpoint saved from a GPU fails to load on a
# CPU-only machine, or leaves lm_head.weight.data (assigned directly below) on a
# different device from the rest of the model.
checkpoint = torch.load(model_checkpoint, map_location='cpu')['model']
# strict=False: keys that do not match the MBartModel layout are ignored
# (presumably the checkpoint uses a different naming scheme for some weights).
bart.load_state_dict(checkpoint, strict=False)
# Point the shared embedding at the encoder embedding weights.
bart.shared.weight = bart.encoder.embed_tokens.weight
model.model = bart
# The output projection is stored under its own key in the checkpoint; copy it into the LM head.
model.lm_head.weight.data = checkpoint['decoder.output_projection.weight']

bart_model = model

In [17]:
from tokenization_indonlg import IndoNLGTokenizer

# Build the tokenizer from the vocabulary file at `vocab_path`.
tokenizer = IndoNLGTokenizer(vocab_file=vocab_path)

# Alternative prompts, tagged `[indonesian]` and `[javanese]`, kept for manual experiments.
# inputs = ['aku pergi ke toko obat membeli']
# bart_input = prepare_input_for_generation(tokenizer, inputs, return_tensors='pt',
#                                          lang_token = '[indonesian]', decoder_lang_token='[indonesian]')

# inputs = ['aku menyang pasar karo']
# bart_input = prepare_input_for_generation(tokenizer, inputs, return_tensors='pt',
#                                          lang_token = '[javanese]', decoder_lang_token='[javanese]')

# Active example: a prompt tagged `[sundanese]`; no decoder_inputs are passed,
# so only the encoder batch is produced.
inputs = ['kuring ka pasar senen meuli daging']
bart_input = prepare_input_for_generation(tokenizer, inputs, return_tensors='pt',
                                         lang_token = '[sundanese]', decoder_lang_token='[sundanese]')

# Bare expression so the notebook displays the prepared batch.
bart_input

{'input_ids': tensor([[    0,  4836,   652,  1726, 23248, 23716,  3103, 40003,     2, 40001]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [18]:
# Put `bart` into evaluation mode before the forward pass.
bart.eval()
# NOTE(review): the forward pass runs with autograd enabled; wrapping it in
# torch.no_grad() would avoid tracking gradients for this inference-only check.
bart_out = bart_model(**bart_input)
# Print the decoded encoder input, then the decoded top-1 token at each position of the logits.
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> kuring ka pasar senen meuli daging<mask></s>[sundanese]
<s> kuring ka pasar senen meuli daging sapi, kuring


In [8]:
# Decode the first decoder input sequence of the batch.
# NOTE(review): `decoder_input_ids` only exists when the batch was prepared with
# `decoder_inputs`; the encoder-only batch built above does not contain it.
tokenizer.decode(bart_input['decoder_input_ids'][0])

'[sundanese]<s> aku pergi ke</s>'

In [None]:
# from src.indobenchmark import IndoNLGTokenizer

# tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indobart')
# bart_input = tokenizer.prepare_input_for_generation(['aku adalah <mask>'], model_type='indobart', return_tensors='pt')
# bart_input

In [None]:
# Prepare a masked prompt through the tokenizer's own `prepare_input_for_generation`
# method, which is called here with a `model_type` argument.
# NOTE(review): presumably requires the tokenizer class that ships this method -- confirm.
bart_input = tokenizer.prepare_input_for_generation(['aku adalah <mask>'], model_type='indobart', return_tensors='pt')

In [None]:
# Prepare a two-item batch with identical encoder and decoder texts, both tagged `[indonesian]`.
bart_input = tokenizer.prepare_input_for_generation(['abdi teh ayeuna','abdi teh ayeuna'], lang_token='[indonesian]',
    decoder_inputs=['abdi teh ayeuna','abdi teh ayeuna'], decoder_lang_token='[indonesian]', model_type='indobart', return_tensors='pt')
bart_out = bart_model(**bart_input)
# Display, as a tuple, the decoded top-1 tokens for batch item 0 and batch item 1.
tokenizer.decode(bart_out.logits.topk(1).indices[0,:,:].squeeze()), tokenizer.decode(bart_out.logits.topk(1).indices[1,:,:].squeeze())

In [None]:
def count_param(module, trainable=False):
    """Return the total number of elements across `module.parameters()`.

    When `trainable` is truthy, only parameters whose `requires_grad` is
    truthy are counted; otherwise every parameter is counted.
    """
    params = module.parameters()
    if trainable:
        # Restrict the count to parameters that take part in training.
        params = (param for param in params if param.requires_grad)
    return sum(param.numel() for param in params)

# Init Model

In [None]:
# Load the BART model and the tokenizer from the `indobenchmark/indobart` identifier;
# this rebinds `bart_model` and `tokenizer`.
bart_model = BartForConditionalGeneration.from_pretrained('indobenchmark/indobart')
# NOTE(review): the GPT model load is commented out, so `gpt_model` is not defined by this cell.
# gpt_model = GPT2LMHeadModel.from_pretrained('indobenchmark/indogpt')
tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indobart')

# Test GPT Model

In [None]:
# Prepare a prompt with model_type='indogpt', generate a continuation and display the decoded result.
# NOTE(review): `gpt_model` is only assigned in a commented-out line of this notebook,
# so it must be loaded before this cell can run.
gpt_input = tokenizer.prepare_input_for_generation('aku adalah anak', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
tokenizer.decode(gpt_out[0])

In [None]:
# Same generation check with a different prompt; the decoded first sequence is displayed.
# NOTE(review): `gpt_model` is only assigned in a commented-out line of this notebook,
# so it must be loaded before this cell can run.
gpt_input = tokenizer.prepare_input_for_generation('aku suka sekali makan', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
tokenizer.decode(gpt_out[0])

In [None]:
# Same generation check with a third prompt; the decoded first sequence is displayed.
# NOTE(review): `gpt_model` is only assigned in a commented-out line of this notebook,
# so it must be loaded before this cell can run.
gpt_input = tokenizer.prepare_input_for_generation('hai, bagaimana kabar', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
tokenizer.decode(gpt_out[0])

# Test BART Model

In [None]:
# Prepare a prompt containing `<mask>` with model_type='indobart' and run one forward pass.
bart_input = tokenizer.prepare_input_for_generation(['aku adalah <mask>'], model_type='indobart', return_tensors='pt')
bart_out = bart_model(**bart_input)
# Display the decoded top-1 token at each position of the logits.
tokenizer.decode(bart_out.logits.topk(1).indices[:,:,:].squeeze())

In [None]:
# Prepare a two-item batch with identical encoder and decoder texts, both tagged `[indonesian]`.
bart_input = tokenizer.prepare_input_for_generation(['abdi teh ayeuna','abdi teh ayeuna'], lang_token='[indonesian]',
    decoder_inputs=['abdi teh ayeuna','abdi teh ayeuna'], decoder_lang_token='[indonesian]', model_type='indobart', return_tensors='pt')
bart_out = bart_model(**bart_input)
# Display, as a tuple, the decoded top-1 tokens for batch item 0 and batch item 1.
tokenizer.decode(bart_out.logits.topk(1).indices[0,:,:].squeeze()), tokenizer.decode(bart_out.logits.topk(1).indices[1,:,:].squeeze())