In [1]:
import os, sys
sys.path.append('../')
import torch
from transformers import GPT2LMHeadModel, MBartForConditionalGeneration
from src.indobenchmark import IndoNLGTokenizer
from torch.utils.data import DataLoader

In [2]:
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())

# Init Model

In [3]:
%%time
gpt_model = GPT2LMHeadModel.from_pretrained('indobenchmark/indogpt')
gpt_tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indogpt')

bart_model = MBartForConditionalGeneration.from_pretrained('indobenchmark/indobart-v2')
bart_tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indobart-v2')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


CPU times: user 7.76 s, sys: 2.03 s, total: 9.79 s
Wall time: 15.7 s


# Test GPT Model

In [4]:
gpt_input = gpt_tokenizer.prepare_input_for_generation('aku adalah anak', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'<s> aku adalah anak pertama dari tiga bersaudara.</s> aku lahir di kota kecil yang sama dengan ayahku.'

In [5]:
gpt_input = gpt_tokenizer.prepare_input_for_generation('aku suka sekali ', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'<s> aku suka sekali dengan warna-warna yang cerah dan cerah.</s> itu yang membuat aku suka dengan'

In [6]:
gpt_input = gpt_tokenizer.prepare_input_for_generation('hai, bagaimana ', model_type='indogpt', return_tensors='pt')
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'<s> hai, bagaimana kabar kalian? semoga sehat selalu ya. kali ini saya akan membahas tentang cara membuat'

# Test BART Model

In [7]:
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[indonesian]', decoder_lang_token='[indonesian]')

bart_out = bart_model(**bart_input)
print(bart_tokenizer.decode(bart_input['input_ids'][0]))
print(bart_tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> aku pergi ke toko obat membeli<mask></s>[indonesian]
<s> aku pergi ke toko obat membeli obat.[indonesian]


In [8]:
inputs = ['aku menyang pasar <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[javanese]', decoder_lang_token='[javanese]')

bart_out = bart_model(**bart_input)
print(bart_tokenizer.decode(bart_input['input_ids'][0]))
print(bart_tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> aku menyang pasar<mask></s>[javanese]
<s> aku menyang pasar kembang,[javanese]


In [9]:
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[sundanese]', decoder_lang_token='[sundanese]')

bart_out = bart_model(**bart_input)
print(bart_tokenizer.decode(bart_input['input_ids'][0]))
print(bart_tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> kuring ka pasar senen meuli daging<mask></s>[sundanese]
<s> kuring ka pasar senen meuli daging sapi,[sundanese]


# Batch Loading with Decoder Tokens

In [10]:
data = []
for enc, dec in zip(
    ['aku adalah anak gembala', 'balonku ada lima', 'so I say'], 
    ['selalu riang serta gembira', 'see you once again my love', 'pokemon master']
):
    data.append(bart_tokenizer.prepare_input_for_generation(
        enc, decoder_inputs=dec, model_type='indobart', return_tensors='pt',
        lang_token='[sundanese]', decoder_lang_token='[javanese]', padding=False
    ))
print(data)

[{'input_ids': tensor([    0,   528,   450,   646, 21985,     2, 40001]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]), 'decoder_input_ids': tensor([40000,     0,  1118, 26083,   825,  9131,     2]), 'decoder_attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([    0,  1118, 26083,   825,  9131,     2, 40000])}, {'input_ids': tensor([    0, 13453,   620,   387,  2402,     2, 40001]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]), 'decoder_input_ids': tensor([40000,     0, 11934,  4711, 36265, 20667,  4552,  7491,     2]), 'decoder_attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([    0, 11934,  4711, 36265, 20667,  4552,  7491,     2, 40000])}, {'input_ids': tensor([    0,   742,   523,  3097,     2, 40001]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1]), 'decoder_input_ids': tensor([40000,     0, 16544,  5888,     2]), 'decoder_attention_mask': tensor([1, 1, 1, 1, 1]), 'labels': tensor([    0, 16544,  5888,     2, 40000])}]


In [12]:
for batch in DataLoader(data, batch_size=3, collate_fn=lambda t: bart_tokenizer.pad(t, padding='longest')):
    print(batch)
    break

{'input_ids': tensor([[    0,   528,   450,   646, 21985,     2, 40001],
        [    0, 13453,   620,   387,  2402,     2, 40001],
        [    0,   742,   523,  3097,     2, 40001,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0]]), 'decoder_input_ids': tensor([[40000,     0,  1118, 26083,   825,  9131,     2,     1,     1],
        [40000,     0, 11934,  4711, 36265, 20667,  4552,  7491,     2],
        [40000,     0, 16544,  5888,     2,     1,     1,     1,     1]]), 'decoder_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0]]), 'labels': tensor([[    0,  1118, 26083,   825,  9131,     2, 40000,  -100,  -100],
        [    0, 11934,  4711, 36265, 20667,  4552,  7491,     2, 40000],
        [    0, 16544,  5888,     2, 40000,  -100,  -100,  -100,  -100]])}
