# Notebook for preprocessing Wikipedia (Czech) dataset

In [None]:
import os
import yaml
import sys
from datasets import load_dataset, Dataset
from tpp_ttstool import TppTtstool
from phonemize import phonemize
from pebble import ProcessPool
from concurrent.futures import ProcessPoolExecutor
from transformers import AutoTokenizer
from tqdm import tqdm
# # Set path to compatible transformers 4.33.3 library
# sys.path.insert(0, '/storage/plzen4-ntis/home/jmatouse/.local/transformers-4.33.3/lib/python3.10/site-packages')

## Papermill options

In [None]:
# Notebook parameters (presumably overridden by papermill at run time, per
# the section heading — the values below act as defaults).

# Input corpus: path handed to load_dataset('text', data_files=...).
inp_text_file = '../BERT_cs/WIKI_C4Cleaned.sentences1k.norm_part0.txt'
# inp_text_file = 'test_preprocess.txt'
# Folder the processed dataset is written to with save_to_disk().
out_dataset_folder = 'datasets/cz-wikipedia.processed'
# Tokenizer identifier passed to AutoTokenizer.from_pretrained().
tokenizer_path = 'fav-kky/FERNET-C5'
# Punctuation characters passed both to TppTtstool(punct=...) and phonemize().
punctuation = '.,;:-?!…' # !!! TODO: define the final punctuation set

## Settings

In [None]:
# Number of worker processes. Taken from PBS_NUM_PPN when that variable is set
# (presumably exported by the PBS batch scheduler — confirm for your cluster);
# otherwise fall back to the local CPU count so the notebook also runs outside
# a PBS job instead of failing with KeyError. The final `or 1` covers
# os.cpu_count() returning None.
NUM_CPUS = int(os.environ.get("PBS_NUM_PPN") or os.cpu_count() or 1)
# Paths to the tts_tool binary and its frontend data, passed to TppTtstool.
TTSTOOL_BIN = "tts_tool/tts_tool"
TTSTOOL_DATA = "tts_tool/data/frontend_ph-redu_pauses.json"
# The dataset is split into four shards per worker process.
NUM_SHARDS = NUM_CPUS * 4

print(f"> Number of CPUs:   {NUM_CPUS}")
print(f"> Number of shards: {NUM_SHARDS}")

### Helper functions

In [None]:
def safe_phonemize_shard(shard):
    """Phonemize every example of a shard, skipping the ones that fail.

    The ``'text'`` field of each example is passed to ``phonemize`` together
    with the module-level ``phonemizer``, ``tokenizer`` and ``punctuation``.
    If anything raises for an example, the entry and the error are printed
    and that example is left out, so one bad entry does not abort the shard.

    Args:
        shard: Iterable of examples that can be indexed with ``'text'``.

    Returns:
        List with the ``phonemize`` result of every example that succeeded,
        in shard order.
    """
    results = []
    for example in shard:
        try:
            results.append(
                phonemize(example['text'], phonemizer, tokenizer, punctuation)
            )
        except Exception as error:
            # Best effort: report the offending entry and carry on.
            print(f"Exception encountered for entry: {example['text']}")
            print(f"Error details: {error}")
    return results

def process_dataset_shards_in_parallel(dataset, num_shards, num_workers):
    """Phonemize a dataset shard by shard in a pool of worker processes.

    Args:
        dataset: Object providing ``shard(num_shards, index, contiguous=True)``.
        num_shards: Number of contiguous shards to cut the dataset into.
        num_workers: Maximum number of worker processes in the pool.

    Returns:
        Flat list holding the output of ``safe_phonemize_shard`` for every
        shard, concatenated in shard order.
    """
    # Cut the dataset into contiguous pieces, one per shard index.
    shards = []
    for shard_index in range(num_shards):
        shards.append(dataset.shard(num_shards, shard_index, contiguous=True))

    # Hand every shard to the pool; leaving the `with` block waits for the
    # submitted work, so the results can be consumed afterwards.
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        per_shard_outputs = pool.map(safe_phonemize_shard, shards)

    # Flatten the per-shard lists into one list.
    combined_results = []
    for shard_output in per_shard_outputs:
        combined_results.extend(shard_output)
    return combined_results

### Initialize phonemizer and tokenizer

In [None]:
# Phonemizer front end for language code 'cz', backed by the tts_tool binary
# and its data file; it is given the same punctuation set as phonemize().
phonemizer = TppTtstool(
    'cz',
    tts_tool_bin=TTSTOOL_BIN,
    tts_tool_data=TTSTOOL_DATA,
    punct=punctuation,
)

# Tokenizer loaded from the configured name/path.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

### Process dataset

In [None]:
# Load the local text file with the 'text' dataset loader and keep its
# 'train' split; safe_phonemize_shard reads each example's 'text' field.
dataset = load_dataset('text', data_files=inp_text_file)['train']

In [None]:
# Phonemize the dataset in NUM_SHARDS shards using up to NUM_CPUS worker
# processes. Entries whose phonemization raises are reported and dropped,
# so the result may be shorter than the input.
processed_results = process_dataset_shards_in_parallel(dataset, NUM_SHARDS, num_workers=NUM_CPUS)

In [None]:
# Convert the flat list of phonemize() results back to a Hugging Face dataset.
# NOTE(review): assumes every phonemize() result is a dict with the same keys,
# as Dataset.from_list expects — confirm against phonemize().
processed_dataset = Dataset.from_list(processed_results)

In [None]:
# Persist the processed dataset to the configured output folder and report
# where it was written.
processed_dataset.save_to_disk(out_dataset_folder)
print(f"Dataset saved to {out_dataset_folder}")

In [None]:
# Summary: how many input sentences survived phonemization.
print(f'Original sentences:   {len(dataset)}')
print(f'Phonemized sentences: {len(processed_dataset)}')
# Clamp the denominator to 1 so an empty input reports 0.00% instead of
# raising ZeroDivisionError; for a non-empty input the value is unchanged.
print(f'Used %:               {len(processed_dataset)/max(len(dataset), 1):.2%}')