In [1]:
from bert import BertForMaskedLM
from text.symbols import symbols, MASK_TOKEN_ID
from text import text_to_sequence
import utils
import commons
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the hyperparameters from the JSON config file; later cells take the
# model architecture settings from `hps.model`.
CONFIG_PATH = "configs/id_bert_base.json"
hps = utils.get_hparams_from_file(CONFIG_PATH)

In [3]:
# Build the masked-LM network. The vocabulary size is the number of entries
# in `symbols`; every other architecture setting is read from the `model`
# section of the loaded hyperparameters.
model_hps = hps.model
model = BertForMaskedLM(
    n_vocab=len(symbols),
    # note: the constructor's `out_channels` is fed from `inter_channels`
    out_channels=model_hps.inter_channels,
    hidden_channels=model_hps.hidden_channels,
    filter_channels=model_hps.filter_channels,
    n_heads=model_hps.n_heads,
    n_layers=model_hps.n_layers,
    kernel_size=model_hps.kernel_size,
    p_dropout=model_hps.p_dropout,
)

In [4]:
# Restore the trained weights into the freshly built model.
# `map_location="cpu"` makes the checkpoint loadable on machines without a
# GPU even if it was saved from a CUDA device; the model itself is still on
# the CPU at this point and is moved to the target device in a later cell.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
CHECKPOINT_PATH = 'logs/id_bert_base/pytorch_model.bin'
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Last expression of the cell: displays which keys matched.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
# Sentence to run through the model.
text = "ia memblokir seseorang"

# Convert the sentence to a sequence of symbol ids using the
# "indonesian_cleaners" pipeline, then intersperse id 0 through the sequence.
text_norm = text_to_sequence(text, ["indonesian_cleaners"])
input_ids = commons.intersperse(text_norm, 0)

# Symbol for every id, kept alongside `input_ids` so that positions can be
# selected by phoneme in the masking step.
input_tokens = list(map(symbols.__getitem__, input_ids))

In [10]:
# Phonemes to hide from the model (the two e's by default); edit this list
# to have the model predict any other phoneme instead.
MASK_PHONEMES = ['ə', 'e']

# Replace the id at every selected phoneme with the mask token and keep all
# other ids unchanged. `input_tokens` is left untouched, so re-running this
# cell yields the same result.
input_ids = [
    MASK_TOKEN_ID if token in MASK_PHONEMES else token_id
    for token_id, token in zip(input_ids, input_tokens)
]

# Positions that now hold the mask token; used later to splice the model's
# predictions back into the sequence.
masked_idxs = [
    position
    for position, token_id in enumerate(input_ids)
    if token_id == MASK_TOKEN_ID
]

In [14]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Batch of one: the sequence length and the (masked) ids, created directly on
# the target device as int64 tensors.
text_lengths = torch.tensor([len(input_ids)], dtype=torch.long, device=device)
text_norm = torch.tensor([input_ids], dtype=torch.long, device=device)

# Move the model's parameters to the same device as its inputs.
model = model.to(device)

In [25]:
# Inference only: switch to eval mode so dropout (the model is built with
# `p_dropout`) is disabled and predictions are deterministic, and skip
# autograd bookkeeping since no backward pass is needed.
model.eval()
with torch.no_grad():
    loss, logits, hidden = model(text_norm, text_lengths)  # logits: [B, T, V]

# Drop the batch dimension and take the highest-scoring symbol id at every
# position.
preds = torch.argmax(logits.squeeze(0), dim=1).cpu().numpy()  # preds: [T]

In [30]:
# unmask phonemes
predicted_input_ids = [preds[idx] if idx in masked_idxs else i for idx, i in enumerate(input_ids)]
recovered_sequence = "".join([symbols[i] for i in predicted_input_ids])
recovered_sequence

'_i_a_ _m_ə_m_b_l_o_k_i_r_ _s_ə_s_ə_o_r_a_ŋ_'