In [1]:
import os, sys
sys.path.append('../')
import torch
from transformers import GPT2LMHeadModel, MBartForConditionalGeneration
from src.indobenchmark import IndoNLGTokenizer
from torch.utils.data import DataLoader

In [2]:
def count_param(module, trainable=False):
    """Count the scalar parameters held by ``module``.

    Args:
        module: object exposing ``.parameters()`` (e.g. a ``torch.nn.Module``).
        trainable: if truthy, only parameters with ``requires_grad`` set are counted;
            otherwise every parameter is counted.

    Returns:
        Sum of ``numel()`` over the selected parameter tensors (0 if there are none).
    """
    selected = (
        param for param in module.parameters()
        # A parameter is kept when we count everything, or when it is trainable.
        if param.requires_grad or not trainable
    )
    return sum(param.numel() for param in selected)

# Initialize Models and Tokenizers

In [None]:
%%time
gpt_model = GPT2LMHeadModel.from_pretrained('indobenchmark/indogpt')
gpt_tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indogpt')

bart_model = MBartForConditionalGeneration.from_pretrained('indobenchmark/indobart-v2')
bart_tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indobart-v2')

# Test GPT Model

In [None]:
# Continue a short Indonesian prompt with IndoGPT and show the first generated sequence.
prompt = 'aku adalah anak'
gpt_input = gpt_tokenizer.prepare_input_for_generation(
    prompt,
    model_type='indogpt',
    return_tensors='pt',
)
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

In [None]:
# Same generation demo with a prompt that ends in a trailing space.
prompt = 'aku suka sekali '
gpt_input = gpt_tokenizer.prepare_input_for_generation(
    prompt,
    model_type='indogpt',
    return_tensors='pt',
)
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

In [None]:
# Same generation demo with a greeting-style prompt (trailing space kept on purpose).
prompt = 'hai, bagaimana '
gpt_input = gpt_tokenizer.prepare_input_for_generation(
    prompt,
    model_type='indogpt',
    return_tensors='pt',
)
gpt_out = gpt_model.generate(**gpt_input)
gpt_tokenizer.decode(gpt_out[0])

# Test BART Model

In [None]:
# Run one IndoBART forward pass on an Indonesian sentence containing <mask>, then print
# the encoder input next to the highest-scoring token at every output position.
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                                         lang_token='[indonesian]',
                                                         decoder_lang_token='[indonesian]')

# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = bart_model(**bart_input)

print(bart_tokenizer.decode(bart_input['input_ids'][0]))
# Argmax over the last (vocabulary) dim, then take the first example, mirroring the
# [0] indexing used for the input above. This stays 1-D for any batch size,
# whereas .squeeze() would leave a 2-D tensor when the batch has more than one example.
print(bart_tokenizer.decode(bart_out.logits.argmax(dim=-1)[0]))

In [None]:
# Same IndoBART forward-pass check for a Javanese sentence containing <mask>.
inputs = ['aku menyang pasar <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                                         lang_token='[javanese]',
                                                         decoder_lang_token='[javanese]')

# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = bart_model(**bart_input)

print(bart_tokenizer.decode(bart_input['input_ids'][0]))
# Per-position argmax over the vocabulary for the first example (consistent with the
# [0] indexing of the input above and valid for any batch size).
print(bart_tokenizer.decode(bart_out.logits.argmax(dim=-1)[0]))

In [None]:
# Same IndoBART forward-pass check for a Sundanese sentence containing <mask>.
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = bart_tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                                         lang_token='[sundanese]',
                                                         decoder_lang_token='[sundanese]')

# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = bart_model(**bart_input)

print(bart_tokenizer.decode(bart_input['input_ids'][0]))
# Per-position argmax over the vocabulary for the first example (consistent with the
# [0] indexing of the input above and valid for any batch size).
print(bart_tokenizer.decode(bart_out.logits.argmax(dim=-1)[0]))

# Batch Loading with Decoder Tokens

In [None]:
# Encode each (encoder text, decoder text) pair separately, without padding.
# The encoder side is tagged '[sundanese]' and the decoder side '[javanese]'.
# Padding to a common length is deferred to the DataLoader's collate function.
encoder_texts = ['aku adalah anak gembala', 'balonku ada lima', 'so I say']
decoder_texts = ['selalu riang serta gembira', 'see you once again my love', 'pokemon master']

data = [
    bart_tokenizer.prepare_input_for_generation(
        enc_text,
        decoder_inputs=dec_text,
        model_type='indobart',
        return_tensors='pt',
        lang_token='[sundanese]',
        decoder_lang_token='[javanese]',
        padding=False,
    )
    for enc_text, dec_text in zip(encoder_texts, decoder_texts)
]
print(data)

In [None]:
# Collate the per-example encodings in `data` into one padded batch and show it.
# Padding must use bart_tokenizer, the tokenizer that produced `data`.
# NOTE(review): each item in `data` was built with return_tensors='pt'. If those tensors
# carry a leading batch dim of 1, `pad` may not pad them as intended. Confirm against
# IndoNLGTokenizer.pad.
data_loader = DataLoader(
    data,
    batch_size=3,
    collate_fn=lambda features: bart_tokenizer.pad(features, padding='longest'),
)
# Only the first batch is needed for inspection.
batch = next(iter(data_loader))
print(batch)