# Notebook for preprocessing Wikipedia (Czech) dataset

### Initializing phonemizer and tokenizer

In [None]:
import os
import yaml
import sys
from text_utils import TextCleaner, load_symbol_dict
from datasets import load_dataset
from tpp_ttstool import TppTtstool
from phonemize import phonemize
from pebble import ProcessPool
from concurrent.futures import TimeoutError

# # Set path to compatible transformers 4.33.3 library
# sys.path.insert(0, '/storage/plzen4-ntis/home/jmatouse/.local/transformers-4.33.3/lib/python3.10/site-packages')

In [None]:
# Number of CPU cores available to this job. Under PBS the scheduler exports
# PBS_NUM_PPN; outside PBS (or when the variable is empty) fall back to the
# machine's core count so the notebook still runs instead of raising KeyError.
N_CPUS = int(os.environ.get("PBS_NUM_PPN") or os.cpu_count() or 1)
print(f"> Number of CPUs: {N_CPUS}")

In [None]:
# Paths and settings used by the cells below.
CONFIG_PATH = 'configs/config.yml'  # path to the YAML configuration file
LANG = 'cs'  # language code
DATASET = '../BERT_cs/WIKI_C4Cleaned10.sentences.norm.txt'  # text file handed to load_dataset('text', ...)
SYMBOL_PATH = 'symbol_dict.csv'  # symbol table handed to load_symbol_dict()
ROOT_DIR = "./wiki_phoneme" # set up root directory for multiprocessor processing
NUM_SHARDS = 1  # number of shard indices mapped over by the process pool
MAX_WORKERS = N_CPUS # change this to the number of CPU cores your machine has
TTSTOOL_BIN = "tts_tool/tts_tool"  # passed to TppTtstool as tts_tool_bin
TTSTOOL_DATA = "tts_tool/data/frontend_ph-redu_pauses.json"  # passed to TppTtstool as tts_tool_data
PUNCTUATION = ".,;:-?!…" # !!! TODO: define

In [None]:
def process_shard(i):
    """Phonemize shard ``i`` of the global ``dataset`` and save it to disk.

    The result is written to ``{ROOT_DIR}/shard_{i}``. If that path already
    exists the shard is reported as done and skipped, so the driving cell can
    be re-run to fill in only the missing shards.

    Reads the module-level globals ``dataset``, ``phonemizer`` and
    ``tokenizer``.

    Args:
        i: Index of the shard to process (one of ``range(NUM_SHARDS)``).

    Returns:
        None.
    """
    directory = f'{ROOT_DIR}/shard_{i}'
    if os.path.exists(directory):
        print(f'Shard {i} already exists!')
        return
    print(f'Processing shard {i} ...')
    # Use the NUM_SHARDS constant: a lowercase `num_shards` is never defined
    # anywhere, so the previous lookup raised NameError in every worker.
    shard = dataset.shard(num_shards=NUM_SHARDS, index=i)
    # NOTE(review): the serial sanity-check loop calls phonemize() with
    # PUNCTUATION as a fourth argument; confirm whether it is needed here too.
    processed_dataset = shard.map(
        lambda t: phonemize(t['text'], phonemizer, tokenizer),
        remove_columns=['text'],
    )
    # The directory is created only after mapping succeeded, so a failed shard
    # leaves nothing behind and is retried on the next run.
    os.makedirs(directory, exist_ok=True)
    processed_dataset.save_to_disk(directory)

In [None]:
# Build the TPP phonemizer, pointing it at the tts_tool binary and its data file
phonemizer = TppTtstool(
    'cz',
    tts_tool_bin=TTSTOOL_BIN,
    tts_tool_data=TTSTOOL_DATA,
    punct=PUNCTUATION,
)

In [None]:
# Load the YAML configuration. The path comes from the CONFIG_PATH constant so
# it is defined in one place only.
config_path = CONFIG_PATH # you can change it to anything else
# Open via a context manager so the file handle is closed after parsing, and
# read it as UTF-8 regardless of the platform's default locale encoding.
with open(config_path, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Text cleaner built from the symbol table at SYMBOL_PATH, with "_" as pad symbol
text_cleaner = TextCleaner(load_symbol_dict(SYMBOL_PATH), pad="_")

In [None]:
# Load the tokenizer named by config['dataset_params']['tokenizer']
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(config['dataset_params']['tokenizer']) # you can use any other tokenizers if you want to

### Process dataset

In [None]:
# Read the DATASET text file with the 'text' loader and keep its 'train' split
dataset = load_dataset('text', data_files=DATASET)['train']

In [None]:
# dataset = dataset.select(range(100))

In [None]:
# Serial pass over the whole dataset in this process: phonemize every example.
# The return value is discarded, so nothing is stored or saved by this cell.
# NOTE(review): this call passes PUNCTUATION as a fourth argument, while
# process_shard() calls phonemize() with three — confirm which one is intended.
for ex in dataset:
    phonemize(ex['text'], phonemizer, tokenizer, PUNCTUATION)

In [None]:
# Alias of ROOT_DIR (same value as before); derived from the constant so the
# two can no longer drift apart.
root_directory = ROOT_DIR # set up root directory for multiprocessor processing

#### Note: You may need to run the following cell multiple times to process all shards, because some of them may fail. Depending on how quickly each shard is processed, you may need to increase the timeout so that more shards finish before being killed.


In [None]:
# Process all shards in parallel and report per-shard failures.
# pool.map() returns a future; worker exceptions only surface when its result
# iterator is consumed, so without the loop below every failure was silent.
with ProcessPool(max_workers=MAX_WORKERS) as pool:
    future = pool.map(process_shard, range(NUM_SHARDS), timeout=None)
    results = future.result()
    # Assumes results are yielded in submission order, one per shard index.
    shard_index = 0
    while True:
        try:
            next(results)
        except StopIteration:
            break
        except TimeoutError:
            print(f'Shard {shard_index} timed out')
        except Exception as error:
            # Best effort: keep going so the remaining shards are still
            # processed; failed shards can be retried by re-running the cell.
            print(f'Shard {shard_index} failed: {error!r}')
        shard_index += 1

### Collect all shards to form the processed dataset

In [None]:
# load_from_disk was never imported, so the call below raised NameError, which
# the former bare `except:` swallowed — leaving `datasets` empty every time.
from datasets import load_from_disk

# Load every shard directory found directly under ROOT_DIR. Sorting makes the
# load order (and thus the order of the concatenated dataset) reproducible.
output = sorted(d for d in os.listdir(ROOT_DIR) if os.path.isdir(os.path.join(ROOT_DIR, d)))
datasets = []
for o in output:
    directory = f'{ROOT_DIR}/{o}'
    try:
        shard = load_from_disk(directory)
    except Exception as error:
        # Best effort: skip directories that cannot be loaded, but say why.
        print(f'{o} skipped: {error!r}')
        continue
    datasets.append(shard)
    print(f'{o} loaded')

In [None]:
# concatenate_datasets was never imported in this notebook (NameError).
from datasets import concatenate_datasets

# Merge all loaded shards into one dataset and store it under config['data_folder']
dataset = concatenate_datasets(datasets)
dataset.save_to_disk(config['data_folder'])
print('Dataset saved to %s' % config['data_folder'])

In [None]:
# check the dataset size (the notebook displays the dataset summary)
dataset

In [None]:
# Display the first example of the concatenated dataset
dataset[0]

### Remove unnecessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains many tokens that are not used in our dataset, so we need to remove them. We also want to predict words in lower case, because case does not matter much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch.

In [None]:
from simple_loader import FilePathDataset, build_dataloader

# Wrap the processed dataset and build a loader over it for the token scan below
file_data = FilePathDataset(dataset)
# Alternative setting with more workers and a larger batch:
# loader = build_dataloader(file_data, num_workers=32, batch_size=128)
loader = build_dataloader(file_data, num_workers=1, batch_size=4)

In [None]:
# Word-separator entry from the config; it seeds the unique-token list below
special_token = config['dataset_params']['word_separator']

In [None]:
# get all unique tokens in the entire dataset

from tqdm import tqdm

# Accumulate into a set instead of rebuilding list(set(...)) after every batch;
# the result is the same collection of unique tokens, seeded with special_token.
unique_tokens = {special_token}
for batch in tqdm(loader):
    unique_tokens.update(batch)
unique_index = list(unique_tokens)

In [None]:
# Map every unique token to the token of its lower-cased text

lower_tokens = []
for token_id in tqdm(unique_index):
    decoded = tokenizer.decode([token_id])
    lowered = decoded.lower()
    # Re-encode only when lower-casing actually changed the decoded text;
    # otherwise the original token is kept as is.
    if lowered != decoded:
        token_id = tokenizer.encode([lowered])[0]
    lower_tokens.append(token_id)

In [None]:
# Drop duplicates; the position of a token in this list becomes its new id
lower_tokens = list(set(lower_tokens))

In [None]:
# redo the mapping for lower number of tokens

# Position of each token's first occurrence in lower_tokens. Built once so the
# loop below does a dict lookup instead of a linear lower_tokens.index() scan
# per token (which made the whole cell quadratic in the vocabulary size).
lower_token_positions = {}
for position, lower_token in enumerate(lower_tokens):
    lower_token_positions.setdefault(lower_token, position)

# token_maps[original_token] = {'word': lower-cased text, 'token': new id}
token_maps = {}
for t in tqdm(unique_index):
    word = tokenizer.decode([t]).lower()
    new_t = tokenizer.encode([word])[0]
    # Raises KeyError if the re-encoded token is missing from lower_tokens
    token_maps[t] = {'word': word, 'token': lower_token_positions[new_t]}

In [None]:
# Number of entries in the token map (one per unique token)
len(token_maps)

In [None]:
# Display the full token map for inspection
token_maps

In [None]:
import pickle

# Pickle the token map to the path configured at dataset_params.token_maps
token_maps_path = config['dataset_params']['token_maps']
with open(token_maps_path, 'wb') as out_file:
    pickle.dump(token_maps, out_file)
print('Token mapper saved to %s' % token_maps_path)

### Test the dataset with dataloader


In [None]:
# NOTE: this rebinds build_dataloader, replacing the function of the same name
# imported from simple_loader earlier in the notebook.
from dataloader import build_dataloader

# Build a small single-process loader over the final dataset, configured from
# the dataset_params section of the config
train_loader = build_dataloader(dataset, batch_size=4, num_workers=0, dataset_config=config['dataset_params'])

In [None]:
# Fetch a single batch to check that the dataloader works end to end.
# next(iter(...)) replaces next(enumerate(...)): the index was only discarded
# into `_`, and assigning `_` at notebook top level shadows IPython's
# last-output variable.
words, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))