# Notebook for filtering and phonifying Wikipedia (Czech) dataset

### Initializing phonemizer and tokenizer

In [None]:
import os
import pysbd
import re
from datasets import load_dataset
from tpp_ttstool import TppTtstool
from tqdm import tqdm

In [None]:
# Number of CPUs available to this run. PBS_NUM_PPN is read from the
# environment (presumably exported by the PBS scheduler - confirm); when it is
# missing or empty, e.g. when the notebook runs outside a batch job, fall back
# to the machine's CPU count instead of failing with KeyError.
N_CPUS = int(os.environ.get("PBS_NUM_PPN") or os.cpu_count() or 1)
print(f"> Number of CPUs: {N_CPUS}")

In [None]:
# --- Input / output locations -------------------------------------------------
# Paragraph-per-line input corpus, sentence-per-line output, and a log of
# rejected sentences.
inp_paragraph_file = "../BERT_cs/WIKI_C4Cleaned.txt"
out_sentence_file = "../BERT_cs/WIKI_C4Cleaned.sent.txt"
not_supported_chars_file = "not_supported_chars.txt"

# --- tts_tool binary and its normalization data --------------------------------
ttstool_bin = "./tts_tool/tts_tool"
ttstool_data = "./tts_tool/data/frontend_normalize.json"
# punctuation = ".,;:-?!…"  # !!! TODO: define

# Character classes a sentence may be composed of; a sentence containing any
# character outside the union of these strings is rejected by the filter.
supported_chars = dict(
    lower_chars='aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž',
    upper_chars='AÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽ',
    digits='0123456789',
    white_spaces=' ',
    punct='.,;:-?',
)

# Single-character substitutions applied to each paragraph before parsing:
# en dash and em dash are both mapped to a plain hyphen.
# Quote characters ('"' and "'") are currently NOT removed (entries disabled).
replacements = dict.fromkeys(('–', '—'), '-')

In [None]:
def count_lines(fname):
    """Return the number of newline bytes in the file *fname*.

    The file is read in binary mode in 1 MiB chunks, so it is never loaded
    into memory as a whole. A final line that is not terminated by a
    newline is not counted.
    """
    chunk_size = 1024 * 1024
    total = 0
    with open(fname, 'rb') as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                # An empty read means end of file.
                break
            total += chunk.count(b'\n')
    return total

def replace_chars(text, replacements):
    """Collapse whitespace in *text* and substitute single characters.

    Every run of whitespace (including newlines and tabs) becomes one space;
    afterwards each character that is a key of *replacements* is replaced by
    its mapped value. Returns the resulting string.
    """
    collapsed = re.compile(r'\s+').sub(' ', text)
    return collapsed.translate(str.maketrans(replacements))

# def remove_extremal_sentences(sentences, min_len=None, max_len=None, remove_neighbors=False):
#    new_sentences = []
#    for s in sentences:
#        s = s.strip()
#        if min_len and len(s) < min_len:
#            continue
#        if max_len and len(s) > max_len:
#            continue
#        new_sentences.append(s)
#    return new_sentences

def is_extremal_sentence(sentence, min_len=None, max_len=None, remove_neighbors=False):
    """Tell whether *sentence* is too short or too long.

    Returns True when ``min_len`` is set (truthy) and the sentence is shorter
    than it, or when ``max_len`` is set (truthy) and the sentence is longer
    than it; otherwise False. A limit of ``None`` or ``0`` disables that check.

    ``remove_neighbors`` is accepted but not used by this function.
    """
    return bool(
        (min_len and len(sentence) < min_len)
        or (max_len and len(sentence) > max_len)
    )

def check_chars(text, supported_dict, f):
    """Verify that *text* consists only of supported characters.

    The allowed characters are the concatenation of all values of
    *supported_dict*. On the first character that is not allowed, a line
    ``"<char> <text>"`` is written to the file object *f* and False is
    returned. Returns True when every character is allowed (including for
    empty text).

    Note: a check that the text starts with an upper-case character or a
    digit used to live here and is currently disabled.
    """
    allowed = ''.join(supported_dict.values())
    offending = next((ch for ch in text if ch not in allowed), None)
    if offending is None:
        return True
    # Log only the first offending character together with the whole text.
    print(f'{offending} {text}', file=f)
    return False

def check_word_level(sentence, f):
    for idx, word in enumerate(sentence.split()):
        # Check word contains only uppercase chars
        if word.isupper():
            if len(word) == 1 and idx == 0:
                # Sentence-leading preposition is OK
                continue
            else:
                # Abbreviation is not OK
                print(f'{word} {sentence}', file=f)
                return False
    return True

In [None]:
# Create the text normalizer: it is given the path to the tts_tool binary, the
# path to its data file, and the set of supported punctuation characters.
# ('cz' presumably selects the Czech language - confirm against TppTtstool.)
normalizer = TppTtstool(
    'cz',
    tts_tool_bin=ttstool_bin,
    tts_tool_data=ttstool_data,
    punct=supported_chars['punct'],
)

### Process dataset

In [None]:
# Load the paragraph file with the 'text' loader and keep its 'train' split;
# the processing loop reads each example's content from the 'text' key.
dataset = load_dataset(
    'text',
    data_files=inp_paragraph_file,
)['train']

In [None]:
# Counters: sentences produced by the normalizer vs. sentences that passed
# all filters and were written to the output file.
n_orig_sents, n_sents = 0, 0

# The corpus is Czech text with non-ASCII characters, so the encoding of both
# output files is fixed to UTF-8 instead of depending on the platform locale
# (a non-UTF-8 default would raise UnicodeEncodeError or produce mojibake).
with open(out_sentence_file, 'w', encoding='utf-8') as fout_sent, \
        open(not_supported_chars_file, 'w', encoding='utf-8') as fout_not_supp:
    # Go through all paragraph-based lines
    for ex in dataset:
        # Normalize whitespace and replace characters listed in `replacements`
        paragraph = replace_chars(ex['text'], replacements)

        # Parse paragraph
        normalizer.ssml_parse(paragraph)
        # Normalize & parse to sentences; materialized as a list because the
        # number of sentences is needed for the statistics.
        sentences = list(normalizer.to_sentences_orto())

        n_orig_sents += len(sentences)

        # Go through individual sentences; rejected ones are skipped, and the
        # character / word-level checks log the reason to fout_not_supp.
        for sentence in sentences:
            if not check_chars(sentence, supported_chars, fout_not_supp):
                continue
            # Drop sentences shorter than 3 characters
            if is_extremal_sentence(sentence, 3):
                continue
            if not check_word_level(sentence, fout_not_supp):
                continue
            n_sents += 1
            print(sentence, file=fout_sent)

In [None]:
# Summary of the filtering run. The ratio is guarded so that an empty input
# (no sentences produced at all) reports 0.00% instead of raising
# ZeroDivisionError.
used_ratio = n_sents / n_orig_sents if n_orig_sents else 0.0
print(f'Original sentences:  {n_orig_sents}')
print(f'Processed sentences: {n_sents}')
print(f'Used %:              {used_ratio:.2%}')