-- Checkpoint model

In [41]:
import os
import torch
import yaml
import pickle
from transformers import AlbertConfig, AlbertModel, TransfoXLTokenizer
from model import MultiTaskModel
from utils import length_to_mask
from text_utils import TextCleaner
from phonemize import phonemize

# Load the training/model configuration. The file is opened in a `with` block
# so the handle is closed right after parsing (the bare open() leaked it).
config_path = "Configs/config.yml"
with open(config_path, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Word-token maps produced during preprocessing; the path comes from the config.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only point config['dataset_params']['token_maps'] at a file you trust.
with open(config['dataset_params']['token_maps'], 'rb') as token_maps_file:
    token_maps = pickle.load(token_maps_file)

# Word tokenizer named in the config.
# NOTE(review): the deprecated TransfoXLTokenizer unpickles its vocab, and this
# env var appears to be the opt-in it checks before doing so (see the
# deprecation warning in the cell output) — confirm against the installed
# transformers version. It is set before from_pretrained() runs.
os.environ['TRUST_REMOTE_CODE'] = 'True'
tokenizer_name = config['dataset_params']['tokenizer']
tokenizer = TransfoXLTokenizer.from_pretrained(tokenizer_name)

# Set up the espeak phonemizer backend for Indonesian ('id'), keeping
# punctuation and stress marks in its output.
from phonemizer import phonemize as phonemizer_phonemize  # kept in case other cells refer to this name
from phonemizer.backend import EspeakBackend

# Only the backend instance is needed: it is passed explicitly to the
# project's phonemize(text, global_phonemizer, tokenizer) helper. Previously it
# was also assigned to `phonemizer_phonemize.backend`, which monkeypatched an
# attribute onto the library function that nothing in this file reads.
global_phonemizer = EspeakBackend(language='id', preserve_punctuation=True, with_stress=True)

# Build a freshly initialised ALBERT encoder from the config's model_params and
# wrap it in the project's MultiTaskModel (defined in model.py, not shown here).
albert_base_configuration = AlbertConfig(**config['model_params'])
bert = MultiTaskModel(
    AlbertModel(albert_base_configuration),
    # Word vocabulary size: the largest 'token' id found in token_maps, plus one.
    num_vocab=max(entry['token'] for entry in token_maps.values()) + 1,
    # Phoneme-symbol vocabulary size and hidden width, straight from the config.
    num_tokens=config['model_params']['vocab_size'],
    hidden_size=config['model_params']['hidden_size'],
)

# Load pretrained weights from a training checkpoint into `bert`.
checkpoint_path = "/workspace/src/PL-BERT-ID/step_1000000.t7"  # replace with the path to your checkpoint
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint['net']
from collections import OrderedDict
# Drop a leading 'module.' from each key so the names match the bare model's
# parameters (presumably the prefix was added because the model was wrapped,
# e.g. by DataParallel/DDP, when the checkpoint was saved).
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
# strict=False skips keys that don't line up instead of raising. Report them,
# otherwise a mismatched checkpoint silently leaves those parameters at their
# freshly constructed values.
load_result = bert.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Checkpoint is missing {len(load_result.missing_keys)} model keys: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Checkpoint has {len(load_result.unexpected_keys)} unexpected keys: {load_result.unexpected_keys}")
# Switch to inference mode (e.g. disables dropout).
bert.eval()

# Encoder: TextCleaner turns a phoneme string into a list of symbol ids.
text_cleaner = TextCleaner()

# Decoder: id -> symbol lookup used to turn predicted ids back into phonemes.
# NOTE(review): this is only the inverse of text_cleaner if TextCleaner assigns
# each symbol its position in `symbols` as its id — confirm in text_utils.
from text_utils import symbols
phoneme_dict = dict(enumerate(symbols))

# Direct text -> phoneme conversion (phonemizer only, no model involved).
def text_to_phoneme(text):
    """Convert ``text`` to a phoneme string using the configured phonemizer.

    Calls the project's ``phonemize`` helper with the module-level
    ``global_phonemizer`` and ``tokenizer`` and joins the entries of the
    returned ``'phonemes'`` list (presumably one per word) with single spaces.

    Args:
        text: Raw input text.

    Returns:
        str: The phoneme entries separated by ``' '``.
    """
    return ' '.join(phonemize(text, global_phonemizer, tokenizer)['phonemes'])

# Phoneme prediction with selected positions masked out.
def infer_with_masking(text, mask_positions=None):
    """Phonemize ``text``, mask selected positions, and let ``bert`` fill them in.

    The text is phonemized, the phoneme entries are joined with ``' '``
    separators, and the result is encoded to symbol ids with ``text_cleaner``.
    Each id at a position listed in ``mask_positions`` is replaced by the id of
    the mask symbol ``"M"``. The sequence is run through ``bert``, and the
    argmax token prediction is written back at exactly those positions.

    Args:
        text: Raw input text.
        mask_positions: Optional iterable of integer indices into the encoded
            id sequence (which includes the space separators). Negative indices
            count from the end, as with Python sequences. Indices outside the
            sequence are ignored. The iterable is read only once, so
            generators are fine.

    Returns:
        str: Without masking (``None`` or empty), the input phoneme string,
        returned without running the model. With masking, the decoded id
        sequence with the model's predictions at the masked positions. Ids
        missing from ``phoneme_dict`` decode to ``'?'``.
    """
    phoneme_data = phonemize(text, global_phonemizer, tokenizer)
    # Same as appending ' ' after every entry and stripping the result.
    phoneme_str = ' '.join(phoneme_data['phonemes']).strip()

    # Nothing to predict: skip the (previously wasted) forward pass.
    if not mask_positions:
        return phoneme_str

    phoneme_ids = text_cleaner(phoneme_str)
    n_ids = len(phoneme_ids)

    # Read mask_positions once and normalise to valid non-negative indices, so
    # masking and write-back agree and an index below -n_ids cannot raise.
    # (When n_ids == 0 the range test is always false, so `%` never sees 0.)
    positions = {pos % n_ids for pos in mask_positions if -n_ids <= pos < n_ids}

    def _decode(ids):
        # Map ids back to symbols; unknown ids become '?'.
        return ''.join(phoneme_dict.get(pid, '?') for pid in ids)

    # No valid position to mask (this also covers an empty encoding): the
    # model's predictions would not be used, so just decode the input ids.
    if not positions:
        return _decode(phoneme_ids)

    mask_token_id = text_cleaner("M")[0]  # "M" is the mask symbol
    masked_ids = list(phoneme_ids)
    for pos in positions:
        masked_ids[pos] = mask_token_id

    # Run on whatever device the model currently lives on.
    device = next(bert.parameters()).device
    input_ids = torch.tensor([masked_ids], dtype=torch.long, device=device)
    # length_to_mask is True at padded positions, so invert it for the
    # attention mask (1 = attend).
    text_mask = length_to_mask(torch.tensor([n_ids])).bool()

    with torch.no_grad():
        tokens_pred, _ = bert(input_ids, attention_mask=(~text_mask).int().to(device))

    # Predicted symbol id per position for the single batch item, as Python ints.
    predicted_ids = torch.argmax(tokens_pred[0], dim=-1).tolist()

    result_ids = list(phoneme_ids)
    for pos in positions:
        result_ids[pos] = predicted_ids[pos]
    return _decode(result_ids)

# Usage example 1: direct text -> phoneme conversion (no model).
input_text = "saya belajar"
phoneme_output = text_to_phoneme(input_text)
print(f"Input: {input_text}")
print(f"Phoneme: {phoneme_output}")

# Usage example 2: model inference with masking.
# NOTE(review): this input is English text, while the espeak backend set up
# above uses language='id' (Indonesian) — confirm this is intended.
input_text2 = "learning is fun"
# Mask a few positions for testing; indices refer to the encoded id sequence,
# which includes the space separators between words.
masked_output = infer_with_masking(input_text2, mask_positions=[2, 5, 8])
print(f"\nInput: {input_text2}")
print(f"With masking: {masked_output}")

# Example without masking: returns the input phoneme string.
normal_output = infer_with_masking(input_text2)
print(f"Without masking: {normal_output}")

`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. See more details on this model's documentation page: `https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`.


177
Input: saya belajar
Phoneme: sˈaja bəlˈadʒar

Input: learning is fun
With masking: ləɣarɣiŋɣˈis fˈun
Without masking: ləˈarniŋ ˈis fˈun
