# Token-mapping notebook for the Czech Wikipedia dataset

### Imports

In [None]:
import os
from datasets import load_from_disk
from tqdm import tqdm
from transformers import AutoTokenizer
import pickle

from simple_loader import FilePathDataset, build_dataloader
# # Set path to compatible transformers 4.33.3 library
# sys.path.insert(0, '/storage/plzen4-ntis/home/jmatouse/.local/transformers-4.33.3/lib/python3.10/site-packages')

### Papermill options

In [None]:
# Directory of the preprocessed dataset; read by load_from_disk further down.
data_folder = "datasets/cz-wikipedia.processed"

# Run settings that no cell in this notebook reads directly
# (presumably kept so it accepts the same papermill parameters as training).
log_dir = "models/test"
mixed_precision = "fp16"
batch_size = 32
save_interval = 100
log_interval = 10
num_process = 1  # number of GPUs
num_steps = 1000000

# Dataset / tokenization options.
dataset_params = dict(
    tokenizer="fav-kky/FERNET-C5",  # name passed to AutoTokenizer.from_pretrained
    token_separator=" ",  # phoneme separator token (a space)
    token_mask="M",  # phoneme mask token
    word_separator=18065,  # token ID of the word separator (|)
    token_maps="token_maps.pkl",  # output path of the pickled token map
    symbol_dict_path="symbol_dict.csv",  # symbol definition dictionary
    max_mel_length=512,  # maximum phoneme length
    word_mask_prob=0.15,  # probability of masking an entire word
    phoneme_mask_prob=0.1,  # probability of masking each phoneme
    replace_prob=0.2,  # probability of replacing phonemes
)

# Model hyper-parameters.
model_params = dict(
    vocab_size=81,  # 178 is noted as an alternative value
    hidden_size=768,
    num_attention_heads=12,
    intermediate_size=2048,
    max_position_embeddings=512,
    num_hidden_layers=12,
    dropout=0.1,
)

In [None]:
# Number of CPUs used for data loading. Taken from the PBS_NUM_PPN environment
# variable when it is set (e.g. inside a PBS job); otherwise fall back to the
# CPU count reported by the OS so the notebook also runs outside the cluster
# instead of failing with a KeyError.
_pbs_ppn = os.environ.get("PBS_NUM_PPN", "").strip()
N_CPUS = int(_pbs_ppn) if _pbs_ppn else (os.cpu_count() or 1)
print(f"> Number of CPUs: {N_CPUS}")

### Load tokenizer and dataset

In [None]:
# Load the preprocessed dataset stored in data_folder and the pre-trained
# tokenizer named by dataset_params['tokenizer'].
dataset = load_from_disk(data_folder)
tokenizer = AutoTokenizer.from_pretrained(dataset_params['tokenizer'])

### Remove unnecessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains many tokens that never occur in our dataset, so we remove them. We also want to predict words in lower case, because letter case matters little for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch.

In [None]:
# Token ID of the word separator; it seeds the set of unique word IDs below.
special_token = dataset_params['word_separator']

# Wrap the dataset and scan it in batches of 128, with one loader worker per CPU.
file_data = FilePathDataset(dataset)
loader = build_dataloader(file_data, batch_size=128, num_workers=N_CPUS)

In [None]:
# Get all unique word token IDs in the entire dataset
# Result: list of word IDs, where word ID = tuple of 1 or more word piece IDs
# Get all unique word token IDs in the entire dataset.
# Result: list of word IDs; per the loader's convention a word ID is a tuple of
# one or more word-piece IDs (TODO confirm against simple_loader).
print("Get unique word IDs...")

# Accumulate into a set and build the list only once at the end. Rebuilding
# list(set(...)) after every batch made the scan cost O(batches x unique IDs).
# The resulting order is arbitrary, exactly as it was with list(set(...)).
unique_word_id_set = {special_token}
for batch in loader:
    unique_word_id_set.update(batch)
unique_word_ids = list(unique_word_id_set)

In [None]:
# Get each token's lower case
# Get each word's lower-case form as word-piece IDs.
print("Get each token's lower case...")

# lower_tokens is aligned by position with unique_word_ids; every entry is a
# flat list of word-piece IDs:
# - the word's own IDs when lower-casing leaves it unchanged;
# - a single vocabulary lookup for a lower-cased '##' continuation piece,
#   mirroring the '##' special case in the token-map cell below so its
#   lookups find a matching entry;
# - otherwise the IDs obtained by encoding the lower-cased word (lower-casing
#   may split it into several word pieces).
lower_tokens = []
for t in unique_word_ids:
    word = tokenizer.decode(t)
    word_lower = word.lower()
    if word_lower == word:
        # t may be a single int or a tuple of word-piece IDs. Flatten a tuple:
        # wrapping it as [t] would give [(...)], which never equals the list
        # of ints computed for the same word in the token-map cell.
        lower_tokens.append(list(t) if isinstance(t, (list, tuple)) else [t])
    elif word_lower.startswith('##'):
        lower_tokens.append([tokenizer.convert_tokens_to_ids(word_lower)])
    else:
        lower_tokens.append(tokenizer.encode(word_lower, add_special_tokens=False))

In [None]:
# Redo the mapping for lower number of tokens
print("Redo mapping for lower number of tokens...")

token_maps = {}
# for t in tqdm(unique_word_ids):
for t in unique_word_ids:
    word = tokenizer.decode(t)
    word_lower = word.lower()
    if word_lower.startswith('##'):
        new_t = [tokenizer.convert_tokens_to_ids(word_lower)]
    else:
        word_pieces = tokenizer.tokenize(word_lower, add_special_tokens=False)
        new_t = [tokenizer.convert_tokens_to_ids(word_piece) for word_piece in word_pieces]
    token_maps[t] = {'word': word_lower, 'token': lower_tokens.index(new_t)}

In [None]:
# Persist the word-ID -> {'word', 'token'} mapping as a pickle at the
# configured path.
token_maps_path = dataset_params['token_maps']
with open(token_maps_path, 'wb') as token_maps_file:
    pickle.dump(token_maps, token_maps_file)
print(f"Token mapper saved to {token_maps_path}")

### Test the dataset with the dataloader


In [None]:
# NOTE: this rebinds build_dataloader, previously imported from simple_loader;
# re-running the loader-setup cell after this one would use this version.
from dataloader import build_dataloader

# Smoke-test the dataloader: build it from the dataset with dataset_params
# and fetch a single batch.
train_loader = build_dataloader(dataset, validation=True, batch_size=4, num_workers=0, dataset_config=dataset_params)
# Take the first batch directly; enumerate() was only used to discard the index.
words, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))

## Notes

Running on 16 CPUs takes approximately 30 hours.