# Notebook for preprocessing Wikipedia (Indonesia) dataset

### Initializing phonemizer and tokenizer

In [None]:
import yaml

config_path = "Configs/config.yml"  # you can change it to anything else

# Load the YAML configuration. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(config_path, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
from phonemize import phonemize

In [None]:
# filepath: /workspace/PL-BERT/preprocess.ipynb
# Add your custom phonemizer code here
import subprocess
import re
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import pandas as pd
from lingua import Language, LanguageDetectorBuilder
import warnings
from tqdm import tqdm 

# Suppress warnings with this message: language detection below is run on one
# word at a time, so the warning would otherwise be emitted for every word.
warnings.filterwarnings("ignore", message="Trying to detect language from a single word.")

# Restrict detection to the two candidate languages and build the shared detector.
languages = [Language.ENGLISH, Language.INDONESIAN]
detector = LanguageDetectorBuilder.from_languages(*languages).build()

@lru_cache(maxsize=100_000)
def detect_lang(word: str) -> str:
    """Classify ``word`` as English or Indonesian.

    Args:
        word: The single word to classify.

    Returns:
        ``"en"`` when the module-level ``detector`` reports English;
        ``"id"`` otherwise, including when the detector returns ``None``.

    Results are memoised for up to 100,000 distinct words.
    """
    detected = detector.detect_language_of(word)
    is_english = detected is not None and detected == Language.ENGLISH
    return "en" if is_english else "id"

@lru_cache(maxsize=100_000)
def phonemize_word(word: str, ipa: bool, keep_stress: bool, sep: str) -> str:
    """Phonemize a single word by invoking the ``espeak-ng`` command-line tool.

    The espeak-ng voice (``en-us`` or ``id``) is chosen from ``detect_lang(word)``.
    Results are memoised for up to 100,000 distinct argument combinations.

    Args:
        word: The word to phonemize.
        ipa: If true, request IPA output (``--ipa``); otherwise use ``-x``.
        keep_stress: If false, the stress marks ``ˈ`` and ``ˌ`` are stripped
            from the output.
        sep: Value passed to espeak-ng through ``--sep=``.

    Returns:
        The phoneme string printed by espeak-ng, or ``word`` unchanged when
        espeak-ng cannot be run, times out, or exits with a non-zero status.
    """
    voice = "en-us" if detect_lang(word) == "en" else "id"
    cmd = [
        "espeak-ng",
        "-v", voice,
        "--ipa" if ipa else "-x",
        "-q",
        f"--sep={sep}",
        # End-of-options marker: without it a word that starts with "-"
        # would be parsed by espeak-ng as a command-line option.
        "--",
        word,
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            # Never let the child read from the notebook's stdin.
            stdin=subprocess.DEVNULL,
            # Bound the call so a stuck espeak-ng process cannot block forever.
            timeout=30,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # Best-effort fallback: missing binary, un-passable argument
        # (e.g. embedded NUL byte) or timeout -> keep the raw word.
        return word
    if result.returncode != 0:
        # espeak-ng reported a failure; fall back the same way instead of
        # silently returning an empty string.
        return word

    phonemes = result.stdout.decode("utf-8", errors="ignore").strip()
    phonemes = phonemes.replace("\ufeff", "")
    if not keep_stress:
        phonemes = re.sub(r"[ˈˌ]", "", phonemes)
    return phonemes

class EnIndPhonemizer:
    """Word-by-word phonemizer for mixed English/Indonesian text.

    Each whitespace-separated word is handed to the module-level
    ``phonemize_word`` together with the options stored on the instance.
    """

    def __init__(self, ipa=True, keep_stress=False, sep="", max_workers=None):
        """Store the phonemization options.

        Args:
            ipa: Forwarded to ``phonemize_word`` as its ``ipa`` argument.
            keep_stress: Forwarded as ``keep_stress``.
            sep: Forwarded as ``sep``.
            max_workers: Thread count for ``process_in_parallel``; any falsy
                value (``None``, ``0``) falls back to 4.
        """
        self.max_workers = max_workers if max_workers else 4
        self.sep = sep
        self.keep_stress = keep_stress
        self.ipa = ipa

    def phonemize(self, text: str) -> str:
        """Return the phonemized words of ``text`` joined by single spaces.

        Falsy input (empty string, ``None``) yields an empty string.
        """
        if not text:
            return ""
        return " ".join(
            phonemize_word(token, self.ipa, self.keep_stress, self.sep)
            for token in text.strip().split()
        )

    def process_in_parallel(self, texts: list[str]) -> list[str]:
        """Phonemize every entry of ``texts`` on a thread pool.

        Results are returned in input order; a tqdm progress bar is shown
        while they are collected.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            pending = pool.map(self.phonemize, texts)
            progress = tqdm(
                pending,
                total=len(texts),
                desc="Phonemizing Sentences",
            )
            return list(progress)

In [None]:
import phonemizer  # NOTE(review): not referenced in the cells shown here — confirm it is still needed
# Shared phonemizer instance: IPA output, stress marks kept, "|" used as the espeak-ng separator value.
global_phonemizer = EnIndPhonemizer(ipa=True, keep_stress=True, sep="|")

In [None]:
from transformers import TransfoXLTokenizer
# The tokenizer name/path is read from the config; any other tokenizer can be substituted here.
tokenizer = TransfoXLTokenizer.from_pretrained(config['dataset_params']['tokenizer'])

### Process dataset

In [None]:
from datasets import load_dataset
# Load the 'train' split of the "wikipedia" dataset, config "20231101.id"
# (Indonesian Wikipedia; presumably the 2023-11-01 dump — confirm).
dataset = load_dataset("wikipedia", "20231101.id")['train']

In [None]:
# Root directory for multiprocess processing: each shard is saved under it as "shard_<i>".
root_directory = "./wiki_phoneme"

In [None]:
import os
num_shards = 50000


def process_shard(i):
    """Phonemize and tokenize shard ``i`` of ``dataset`` and save it to disk.

    The shard is written to ``<root_directory>/shard_<i>``. If that directory
    already exists the shard is considered done and nothing is recomputed.

    The result is first saved to a ``shard_<i>.tmp`` staging directory and only
    renamed to its final name once the save has finished. A worker that is
    killed mid-save (e.g. by the pool timeout) therefore never leaves behind a
    partial ``shard_<i>`` directory that later runs would wrongly skip.

    Args:
        i: Index of the shard, in ``range(num_shards)``.
    """
    import shutil  # local import keeps this cell self-contained

    directory = os.path.join(root_directory, "shard_" + str(i))
    if os.path.exists(directory):
        print("Shard %d already exists!" % i)
        return
    print('Processing shard %d ...' % i)

    staging_directory = directory + ".tmp"
    # A leftover staging directory means an earlier attempt was interrupted.
    shutil.rmtree(staging_directory, ignore_errors=True)

    shard = dataset.shard(num_shards=num_shards, index=i)
    processed_dataset = shard.map(
        lambda t: {
            'phonemes': global_phonemizer.phonemize(t['text']),
            'tokens': tokenizer.encode(t['text']),  # token ids of the raw text
        },
        remove_columns=['text'],
    )

    os.makedirs(staging_directory, exist_ok=True)
    processed_dataset.save_to_disk(staging_directory)
    # Publish the finished shard under its final name in a single rename.
    os.replace(staging_directory, directory)

In [None]:
from pebble import ProcessPool
from concurrent.futures import TimeoutError

#### Note: You will need to run the following cell multiple times to process all shards, because some shards will fail (for example, by timing out). Depending on how quickly each shard is processed, you may need to increase the timeout so that more shards finish before their worker is killed.


In [None]:
max_workers = 32  # change this to the number of CPU cores your machine has

# Run every shard through the process pool with a per-task timeout of 60 s.
# The result iterator is consumed so that timed-out or crashed shards are
# counted and reported instead of being dropped silently.
failed_shards = 0
with ProcessPool(max_workers=max_workers) as pool:
    future = pool.map(process_shard, range(num_shards), timeout=60)
    results = future.result()
    while True:
        try:
            next(results)
        except StopIteration:
            break
        except TimeoutError:
            # The shard exceeded the timeout; re-running the cell retries it.
            failed_shards += 1
        except Exception as error:
            failed_shards += 1
            print("Shard failed: %r" % (error,))

print("%d of %d shards failed; re-run this cell to retry them." % (failed_shards, num_shards))

### Collect all shards to form the processed dataset

In [None]:
from datasets import load_from_disk, concatenate_datasets

# Every sub-directory of the root is a candidate shard; sorting makes the
# concatenation order reproducible (os.listdir order is arbitrary).
output = sorted(
    entry
    for entry in os.listdir(root_directory)
    if os.path.isdir(os.path.join(root_directory, entry))
)
datasets = []
for o in output:
    directory = os.path.join(root_directory, o)
    try:
        shard = load_from_disk(directory)
    except Exception as error:
        # Incomplete or unreadable shard directories are still skipped, but the
        # reason is reported; KeyboardInterrupt is no longer swallowed.
        print("Skipping %s: %s" % (o, error))
        continue
    datasets.append(shard)
    print("%s loaded" % o)

In [None]:
# Merge all loaded shards into one dataset (this rebinds ``dataset``, which
# previously held the raw Wikipedia split) and save it to the configured folder.
dataset = concatenate_datasets(datasets)
dataset.save_to_disk(config['data_folder'])
print('Dataset saved to %s' % config['data_folder'])

In [None]:
# Display the merged dataset (notebook repr) to check its size
dataset

### Remove unnecessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains many tokens that are not used in our dataset, so we need to remove them. We also want to predict words in lowercase, because case does not matter much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch.

In [None]:
from simple_loader import FilePathDataset, build_dataloader

# Wrap the merged dataset and build a batch loader over it
# (both helpers come from the project's simple_loader module).
file_data = FilePathDataset(dataset)
loader = build_dataloader(file_data, num_workers=32, batch_size=128)

In [None]:
# Word-separator entry from the config; it is always kept when collecting the tokens in use.
special_token = config['dataset_params']['word_separator']

In [None]:
# get all unique tokens in the entire dataset

from tqdm import tqdm

# Accumulate directly into a set: the previous version rebuilt the whole list
# through set() after every batch, which is quadratic in the number of batches.
unique_tokens = {special_token}
for batch in tqdm(loader):
    unique_tokens.update(batch)
unique_index = list(unique_tokens)

In [None]:
# get each token's lower case

# For every token in use, record the id of its lowercase form; a token whose
# decoded text is already lowercase keeps its own id.
lower_tokens = []
for t in tqdm(unique_index):
    word = tokenizer.decode([t])
    lowered = word.lower()
    lower_tokens.append(t if lowered == word else tokenizer.encode([lowered])[0])

In [None]:
# De-duplicate the lowercase ids (differently-cased tokens may map to the same id).
lower_tokens = (list(set(lower_tokens)))

In [None]:
# redo the mapping for lower number of tokens

# Position of each id in ``lower_tokens``, computed once so every lookup is
# O(1) instead of a linear ``list.index`` scan per token. ``setdefault`` keeps
# the first position, matching ``list.index`` semantics.
lower_positions = {}
for position, lower_id in enumerate(lower_tokens):
    lower_positions.setdefault(lower_id, position)

# Map each original token id to its lowercase word and its index in the
# reduced vocabulary. An id missing from ``lower_tokens`` raises KeyError.
token_maps = {}
for t in tqdm(unique_index):
    word = tokenizer.decode([t]).lower()
    new_t = tokenizer.encode([word])[0]
    token_maps[t] = {'word': word, 'token': lower_positions[new_t]}

In [None]:
import pickle
# Pickle the token map (original token id -> {'word', 'token'}) to the path given in the config.
with open(config['dataset_params']['token_maps'], 'wb') as handle:
    pickle.dump(token_maps, handle)
print('Token mapper saved to %s' % config['dataset_params']['token_maps'])

### Test the dataset with dataloader


In [None]:
# NOTE: this rebinds ``build_dataloader``, replacing the one imported from simple_loader earlier.
from dataloader import build_dataloader

# Build a loader over the processed dataset using the dataset parameters from the config.
train_loader = build_dataloader(dataset, batch_size=32, num_workers=0, dataset_config=config['dataset_params'])

In [None]:
# Smoke test: fetch the first batch (the enumerate index is discarded) and unpack its five fields.
_, (words, labels, phonemes, input_lengths, masked_indices) = next(enumerate(train_loader))