In [2]:
import os

# Restrict this process to the GPU with index 2. This is set in the very first
# cell, before torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [3]:
import yaml

config_path = "Configs/config.yml"

# Parse the YAML config inside a context manager so the file handle is closed
# as soon as parsing finishes, instead of being left open until garbage
# collection (the previous `yaml.safe_load(open(...))` never closed it).
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [4]:
from phonemize import phonemize
import phonemizer
import torch

# Espeak-based phonemizer for language code 'ms', configured to keep
# punctuation and to emit stress marks.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='ms',
    preserve_punctuation=True,
    with_stress=True,
)

`openai-whisper` is not available, native whisper processor is not available, will use huggingface processor instead.


In [5]:
from transformers import AutoTokenizer
# Load the tokenizer whose name/path is given under dataset_params.tokenizer
# in the YAML config.
tokenizer = AutoTokenizer.from_pretrained(
    config['dataset_params']['tokenizer'],
)

In [6]:
from datasets import load_dataset

# Load the "train" split of the mesolitica/PL-BERT-MS dataset.
dataset = load_dataset(
    "mesolitica/PL-BERT-MS",
    split="train",
)

In [7]:
from dataloader import build_dataloader

# Wrap the dataset in the project's dataloader, 10 samples per batch, using
# the dataset section of the config.
train_loader = build_dataloader(
    dataset,
    batch_size=10,
    num_workers=0,
    dataset_config=config['dataset_params'],
)

177


In [8]:
# Pull one batch for inspection. `next(iter(...))` fetches the first batch
# directly; the former `enumerate` wrapper only produced an index that was
# thrown away into `_`, which also overwrote IPython's "last output" variable.
words, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))

In [12]:
# !wget https://huggingface.co/mesolitica/PL-BERT-MS/resolve/main/step_25000.t7

In [13]:
from transformers import AlbertConfig, AlbertModel
from model import MultiTaskModel
import pickle

# Read the pickled token map from the path given under dataset_params.token_maps.
# NOTE(review): unpickling executes arbitrary code from the file -- only point
# this at a token_maps file from a trusted source.
with open(config['dataset_params']['token_maps'], mode='rb') as handle:
    token_maps = pickle.load(handle)

In [14]:
# ALBERT hyper-parameters come straight from the `model_params` config section.
albert_base_configuration = AlbertConfig(**config['model_params'])

# Wrap a freshly initialised ALBERT encoder in the multi-task model.
# num_vocab is one more than the largest 'token' value found in token_maps;
# num_tokens and hidden_size are read from the same `model_params` section.
bert = MultiTaskModel(
    AlbertModel(albert_base_configuration),
    num_vocab=max(entry['token'] for entry in token_maps.values()) + 1,
    num_tokens=config['model_params']['vocab_size'],
    hidden_size=config['model_params']['hidden_size'],
)

In [26]:
# Move the model's parameters and buffers onto the CUDA device.
bert = bert.to(device='cuda')

In [15]:
# Deserialize the checkpoint onto the CPU and pick out the weights stored
# under its 'net' key.
checkpoint = torch.load(
    'step_25000.t7',
    map_location='cpu',
)
state_dict = checkpoint['net']

In [16]:
from collections import OrderedDict

# Checkpoint keys are expected to carry a 7-character 'module.' prefix that
# the bare model does not use. Strip it only when it is actually present:
# the former unconditional `k[7:]` would mangle any key without the prefix,
# and with strict=False such a mismatch is not raised as an error.
prefix = 'module.'
new_state_dict = OrderedDict(
    (key[len(prefix):] if key.startswith(prefix) else key, value)
    for key, value in state_dict.items()
)

# strict=False tolerates missing/unexpected keys; the returned report (shown
# as the cell output) lists any that did not match, so check it.
bert.load_state_dict(new_state_dict, strict=False)

<All keys matched successfully>

In [27]:
from utils import length_to_mask
# Device that the input tensors are moved to before the forward pass.
device = "cuda"

In [30]:
# Turn the batch's sequence lengths into a mask, then move both the mask and
# the phoneme batch onto `device`.
text_mask = length_to_mask(torch.Tensor(input_lengths))
text_mask = text_mask.to(device)
phonemes = phonemes.to(device)

In [31]:
# Inference only: put the model in eval mode so train-time layers such as
# dropout are disabled, and skip autograd bookkeeping so no graph is kept
# alive for the outputs.
bert.eval()
with torch.no_grad():
    # The mask is inverted before being passed as the attention mask.
    tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int())

In [33]:
# Display the shape of the first model output as a sanity check.
tokens_pred.shape

torch.Size([10, 256, 178])

In [42]:
# For the first item of the batch, take the highest-scoring index along the
# last axis at each position and look it up in token_maps. `.tolist()` yields
# plain Python ints, so no per-element int() cast is needed.
[token_maps[word_id] for word_id in words_pred.argmax(dim=-1)[0].tolist()]

[{'word': 'maghribi', 'token': 2935},
 {'word': 'maghribi', 'token': 2935},
 {'word': 'maghribi', 'token': 2935},
 {'word': 'bergaji', 'token': 23930},
 {'word': 'pentadbiran', 'token': 295},
 {'word': 'garaj', 'token': 19293},
 {'word': '[SEP]', 'token': 2},
 {'word': 'dan', 'token': 9},
 {'word': 'dan', 'token': 9},
 {'word': 'dan', 'token': 9},
 {'word': '[SEP]', 'token': 2},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': '[SEP]', 'token': 2},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': '[SEP]', 'token': 2},
 {'word': 'mary', 'token': 2308},
 {'word': 'mary', 'token': 2308},
 {'word': 'mary', 'token': 2308},
 {'w

In [40]:
# Look up the first row of `words` from the batch in token_maps, for
# comparison with the model's predictions. `.tolist()` yields plain Python
# ints, so no per-element int() cast is needed.
[token_maps[word_id] for word_id in words[0].tolist()]

[{'word': 'jigar', 'token': 94465},
 {'word': 'jigar', 'token': 94465},
 {'word': 'jigar', 'token': 94465},
 {'word': 'jigar', 'token': 94465},
 {'word': 'jigar', 'token': 94465},
 {'word': 'jigar', 'token': 94465},
 {'word': '[SEP]', 'token': 2},
 {'word': 'dan', 'token': 9},
 {'word': 'dan', 'token': 9},
 {'word': 'dan', 'token': 9},
 {'word': '[SEP]', 'token': 2},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': 'ditulis', 'token': 756},
 {'word': '[SEP]', 'token': 2},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': 'oleh', 'token': 34},
 {'word': '[SEP]', 'token': 2},
 {'word': 'mayur', 'token': 58555},
 {'word': 'mayur', 'token': 58555},
 {'word': 'mayur', 'token': 58555},
 {'word': 