In [1]:
import torch
from transformers import AlbertConfig, AlbertModel
from model import MultiTaskModel
import yaml
import pickle

# Read the YAML run configuration.
config_path = "Configs/config.yml"
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

# Unpickle the token map whose path is given by the config's
# dataset_params.token_maps entry.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only point this entry at a file you trust.
token_maps_path = config['dataset_params']['token_maps']
with open(token_maps_path, 'rb') as token_maps_file:
    token_maps = pickle.load(token_maps_file)

# Build the ALBERT encoder from the `model_params` section of the config and
# wrap it in the project's MultiTaskModel.
model_params = config['model_params']
albert_base_configuration = AlbertConfig(**model_params)
bert = AlbertModel(albert_base_configuration)

# num_vocab is one more than the largest 'token' value stored in token_maps.
largest_token = max(entry['token'] for entry in token_maps.values())
model = MultiTaskModel(
    bert,
    num_vocab=largest_token + 1,
    num_tokens=model_params['vocab_size'],
    hidden_size=model_params['hidden_size'],
)

# Load checkpoint
# The weights live under the 'net' key of the checkpoint; everything is mapped
# to CPU so no GPU is needed for loading.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
import warnings
from collections import OrderedDict

ckpt_path = "/workspace/src/PL-BERT-ID/Checkpoint/step_12.t7"
checkpoint = torch.load(ckpt_path, map_location='cpu')
state_dict = checkpoint['net']

# Drop a leading "module." from parameter names so they line up with the
# unwrapped model's own parameter names.
module_prefix = "module."
new_state_dict = OrderedDict(
    (k[len(module_prefix):] if k.startswith(module_prefix) else k, v)
    for k, v in state_dict.items()
)

# strict=False tolerates key mismatches, but a mismatch means some weights
# keep their random initialisation. Surface it instead of discarding the
# result, so a badly matched checkpoint is not mistaken for a trained model.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint %s did not match the model exactly. "
        "Missing keys (left at initial values): %s. "
        "Unexpected keys (ignored): %s."
        % (ckpt_path, load_result.missing_keys, load_result.unexpected_keys)
    )
model.eval()

# Smoke test: push one batch of random ids through the model and print the
# shapes of both outputs.
# Replace this with your actual preprocessing and input
import numpy as np
from utils import length_to_mask

batch_shape = (1, 128)
vocab_size = config['model_params']['vocab_size']
random_ids = np.random.randint(0, vocab_size, batch_shape)
phonemes = torch.tensor(random_ids, dtype=torch.long)

# The whole sequence counts as real input here (length 128).
input_lengths = [128]
lengths = torch.tensor(input_lengths)
text_mask = length_to_mask(lengths).to(phonemes.device)

with torch.no_grad():
    # The mask returned by length_to_mask is inverted and cast to int before
    # being handed to the model as its attention mask.
    attention_mask = (~text_mask).int()
    tokens_pred, words_pred = model(phonemes, attention_mask=attention_mask)
print("tokens_pred shape:", tokens_pred.shape)
print("words_pred shape:", words_pred.shape)

  from .autonotebook import tqdm as notebook_tqdm


tokens_pred shape: torch.Size([1, 128, 178])
words_pred shape: torch.Size([1, 128, 53314])


In [2]:
# Run the model on a hand-written, space-separated phoneme string.
import warnings
from utils import length_to_mask

phoneme_str = "s a y a"
phoneme_list = phoneme_str.split()  # ['s', 'a', 'y', 'a']

# NOTE(review): the 'token' values in token_maps are what the first cell uses
# to size MultiTaskModel's num_vocab, whereas the encoder's vocab_size comes
# from config['model_params']. Confirm that token_maps is really keyed by
# phoneme symbols and that its ids are in the id space the encoder expects.
unknown_phonemes = [p for p in phoneme_list if p not in token_maps]
if unknown_phonemes:
    # Unknown symbols fall back to id 0, which is also the padding id below,
    # so make the fallback visible instead of silent.
    warnings.warn(
        "Phonemes not found in token_maps, mapped to 0: %s" % unknown_phonemes
    )
phoneme_ids = [token_maps[p]['token'] if p in token_maps else 0 for p in phoneme_list]  # 0 = unk

# Truncate to max_len first and take the length from the truncated list, so
# the reported length can never exceed the width of the padded tensor.
max_len = 128
phoneme_ids = phoneme_ids[:max_len]
input_lengths = [len(phoneme_ids)]

# Pad to a fixed length (e.g. 128) with id 0.
phoneme_ids = phoneme_ids + [0] * (max_len - len(phoneme_ids))
phonemes = torch.tensor([phoneme_ids], dtype=torch.long)

text_mask = length_to_mask(torch.tensor(input_lengths)).to(phonemes.device)

with torch.no_grad():
    # The mask returned by length_to_mask is inverted and cast to int before
    # being handed to the model as its attention mask.
    tokens_pred, words_pred = model(phonemes, attention_mask=(~text_mask).int())
print("tokens_pred shape:", tokens_pred.shape)
print("words_pred shape:", words_pred.shape)

tokens_pred shape: torch.Size([1, 128, 178])
words_pred shape: torch.Size([1, 128, 53314])
