# Notebook for filtering and normalizing the Czech Wikipedia dataset

### Initializing the phonemizer and tokenizer

In [None]:
import os
import re
from datasets import load_dataset
from tpp_ttstool import TppTtstool

In [None]:
# Number of CPUs to use. Under PBS this is read from PBS_NUM_PPN (presumably the
# processors-per-node allocation of the job). Outside PBS the variable is absent,
# so fall back to the local CPU count instead of failing with a KeyError.
N_CPUS = int(os.environ.get("PBS_NUM_PPN", os.cpu_count() or 1))
print(f"> Number of CPUs: {N_CPUS}")

In [None]:
# Input/output files. Alternative input used before: '../BERT_cs/test.txt'.
inp_paragraph_file = 'test_norm.txt'                        # one paragraph per line
out_sentence_file = 'test_norm.out.txt'                     # accepted, normalized sentences
not_supported_chars_file = 'test_not_supported_chars.txt'   # log of rejected input

# tts_tool binary and the normalization data it is run with.
ttstool_bin = "./tts_tool/tts_tool"
ttstool_data = "./tts_tool/data.new/frontend_normalize.json"
# TODO: define the full punctuation set (candidate: ".,;:-?!…").

# Characters allowed in an output sentence, grouped by category
# (only the values matter for filtering; the keys are labels).
supported_chars = dict(
    lower_chars='aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž',
    upper_chars='AÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽ',
    digits='0123456789',
    white_spaces=' ',
    punct='.,;:-?',
)

# Text replacements applied before normalization:
#   'character' - single-character mapping (dash variants -> '-', quotes removed)
#   'paragraph' - multi-character phrases expanded to words
# A word-level table (e.g. '%' -> 'procent') is not enabled yet.
replacements = dict(
    character={
        '–': '-',
        '—': '-',
        '−': '-',
        '"': '',
        "'": '',
        "„": '',
        "“": '',
    },
    paragraph={
        'km/h': 'kilometrů za hodinu',
        'm/s': 'metrů za sekundu',
        'př.n.l.': 'před naším letopočtem',
        'př. n. l.': 'před naším letopočtem',
    },
)

In [None]:
# Order the paragraph-level replacements by key length, longest first, so that a
# longer phrase takes precedence over any shorter phrase it contains; this avoids
# wrong or duplicated replacements. Keys of equal length keep their original order.
paragraph_replacements = {
    phrase: replacements['paragraph'][phrase]
    for phrase in sorted(replacements['paragraph'], key=len, reverse=True)
}

In [None]:
def count_lines(fname, block_size=1024 * 1024):
    """Return the number of lines in the file ``fname``.

    The file is read in binary chunks of ``block_size`` bytes and newline bytes
    are counted. A final line that is not terminated by a newline is counted
    too, so ``b"a\\nb"`` has 2 lines. An empty file has 0 lines.

    Args:
        fname: path of the file to inspect.
        block_size: number of bytes read per chunk.

    Returns:
        The line count as an int.
    """
    n_lines = 0
    last_byte = b''
    with open(fname, 'rb') as f:
        for block in iter(lambda: f.read(block_size), b''):
            n_lines += block.count(b'\n')
            last_byte = block[-1:]
    # Count an unterminated last line, which has no '\n' of its own.
    if last_byte and last_byte != b'\n':
        n_lines += 1
    return n_lines

def replace_chars(text, replacements):
    """Apply single-character replacements and collapse whitespace.

    Args:
        text: input paragraph.
        replacements: dict mapping single characters to replacement strings
            (an empty string deletes the character); passed to ``str.maketrans``.

    Returns:
        The translated text in which every run of whitespace is a single space.
    """
    # Translate first and collapse afterwards. Deleting a character that sits
    # between two spaces (e.g. a quotation mark) would otherwise leave a double
    # space behind, and whitespace produced by a replacement is normalized too.
    translated = text.translate(str.maketrans(replacements))
    return re.sub(r'\s+', ' ', translated)

def is_extremal_sentence(sentence, min_len=None, max_len=None, remove_neighbors=False):
    """Tell whether ``sentence`` falls outside the allowed length range.

    Args:
        sentence: text whose ``len`` is checked.
        min_len: minimum length; a falsy value (None or 0) disables the check.
        max_len: maximum length; a falsy value (None or 0) disables the check.
        remove_neighbors: accepted for interface compatibility but not used.

    Returns:
        True if the sentence is shorter than ``min_len`` or longer than
        ``max_len``, otherwise False.
    """
    length = len(sentence)
    too_short = bool(min_len) and length < min_len
    too_long = bool(max_len) and length > max_len
    return too_short or too_long

def check_chars(text, supported_dict, f):
    """Check that every character of ``text`` is in the supported set.

    Args:
        text: sentence to check.
        supported_dict: dict whose string values together list every allowed
            character (the keys are only category labels).
        f: writable text stream. On rejection the line
            "<offending char> <text>" is written to it.

    Returns:
        True if all characters are supported (also for an empty ``text``).
        False at the first unsupported character.
    """
    # A set gives O(1) membership tests instead of scanning the joined string
    # once for every character of the sentence.
    allowed = set(''.join(supported_dict.values()))
    for ch in text:
        if ch not in allowed:
            print(f'{ch} {text}', file=f)
            return False
    return True

def check_word_level(sentence, f):
    """Reject sentences containing a word with more than one uppercase letter.

    Such words are treated as acronyms, which are not supported. A one-letter
    uppercase word at the very start of the sentence (a sentence-leading
    preposition) is explicitly accepted. Note that ``check_case_word`` would
    accept it anyway, because a one-character word has at most one uppercase
    character.

    Args:
        sentence: whitespace-separated sentence.
        f: writable text stream. On rejection the line "<word> <sentence>" is
            written to it.

    Returns:
        True if every word passes, False at the first offending word.
    """
    for position, token in enumerate(sentence.split()):
        leading_single_capital = position == 0 and len(token) == 1 and token.isupper()
        if leading_single_capital or check_case_word(token):
            continue
        # Acronym found: log it and reject the sentence.
        print(f'{token} {sentence}', file=f)
        return False
    return True

def check_word_level2(sentence, replacements, f):
    """Word-level filter that also rewrites the sentence, expanding '%'.

    It works like ``check_word_level`` but rebuilds the sentence word by word.
    A standalone '%' token is replaced by 'procent'.

    Args:
        sentence: whitespace-separated sentence.
        replacements: currently unused; kept for interface compatibility
            (presumably meant for a word-level replacement table — TODO confirm).
        f: writable text stream. On rejection the line "<word> <sentence>" is
            written to it.

    Returns:
        The rebuilt sentence with words joined by single spaces, or None if a
        word contains more than one uppercase character (an acronym).
    """
    words = []
    for idx, word in enumerate(sentence.split()):
        if word == '%':
            # Replace the symbol outright; do not also keep the '%' itself.
            words.append('procent')
            continue
        if idx == 0 and len(word) == 1 and word.isupper():
            # Sentence-leading preposition is OK
            words.append(word)
            continue
        if not check_case_word(word):
            # Acronym is not OK
            print(f'{word} {sentence}', file=f)
            return None
        words.append(word)
    return ' '.join(words)

def check_case_word(word):
    """Return True unless ``word`` contains more than one uppercase character.

    Words with two or more capitals (e.g. acronyms such as "ČR") yield False.
    """
    uppercase_count = sum(map(str.isupper, word))
    return uppercase_count <= 1

def replace_supchars(text, replacements):
    """Replace every occurrence of any key of ``replacements`` in a single pass.

    Args:
        text: input text.
        replacements: dict mapping substrings to their replacements. Keys are
            matched literally (they are regex-escaped). Empty keys are ignored.

    Returns:
        The text with all matches replaced. When no key overlaps another at the
        same position the result does not depend on dict order. Overlapping
        keys are resolved in favour of the longest one.
    """
    # Regex alternation takes the first alternative that matches, so try the
    # longest keys first; a shorter key must not shadow a longer one starting
    # at the same position. sorted() is stable, so an already length-sorted
    # mapping keeps its order.
    keys = sorted((key for key in replacements if key), key=len, reverse=True)
    if not keys:
        # An empty pattern would match between every character and the lookup
        # replacements[''] would then fail.
        return text
    pattern = re.compile('|'.join(map(re.escape, keys)))
    return pattern.sub(lambda match: replacements[match.group(0)], text)

def check_final_punct(text, punctuation, replaceable=',;:-'):
    """Make sure ``text`` ends with sentence-final punctuation.

    If the last character is not in ``punctuation``, a '.' is appended. If it
    is in ``punctuation`` and also in ``replaceable``, it is swapped for '.'.
    Other punctuation (e.g. '.', '?') is kept. Trailing whitespace is dropped
    before the check, so "Ahoj " becomes "Ahoj." rather than "Ahoj .".

    Args:
        text: sentence to terminate.
        punctuation: characters accepted as punctuation.
        replaceable: punctuation characters that cannot end a sentence and are
            replaced by '.'.

    Returns:
        The terminated sentence. Empty or whitespace-only input is returned
        unchanged.
    """
    stripped = text.rstrip()
    if not stripped:
        # Nothing to terminate; avoid indexing into an empty string.
        return text
    last = stripped[-1]
    if last not in punctuation:
        return stripped + '.'
    if last in replaceable:
        return stripped[:-1] + '.'
    return stripped

In [None]:
# Create the TppTtstool normalizer for Czech ('cz'). It is configured with the
# tts_tool binary, its data file, and the punctuation set used by the filters.
normalizer = TppTtstool(
    'cz',
    tts_tool_bin=ttstool_bin,
    tts_tool_data=ttstool_data,
    punct=supported_chars['punct'],
)

### Process dataset

In [None]:
# Load the input with the 'text' builder: every line becomes one example whose
# 'text' field is processed below as a paragraph. Only the 'train' split is used.
dataset = load_dataset(
    'text',
    data_files=inp_paragraph_file,
)['train']

In [None]:
# Counters for the final report: sentences produced by the normalizer vs. kept.
n_orig_sents, n_sents = 0, 0

# The encoding is explicit so the Czech output does not depend on the system locale.
with open(out_sentence_file, 'w', encoding='utf-8') as fout_sent, \
        open(not_supported_chars_file, 'w', encoding='utf-8') as fout_not_supp:
    # Each dataset example is one input line, processed as a paragraph.
    for ex in dataset:
        # Expand multi-character phrases first, then apply single-character
        # replacements and whitespace collapsing.
        paragraph = replace_supchars(ex['text'], paragraph_replacements)
        paragraph = replace_chars(paragraph, replacements['character'])

        try:
            # Parse the paragraph. Parse failures are assumed to surface as
            # RuntimeError (the only exception caught here).
            normalizer.ssml_parse(paragraph)
        except RuntimeError:
            print('[!] Parsing paragraph failed => skipping\n')
            print(f'FAILED {paragraph}', file=fout_not_supp)
            print(paragraph)
            continue

        # Normalize and split into sentences.
        sentences = list(normalizer.to_sentences_orto())
        n_orig_sents += len(sentences)

        for sentence in sentences:
            # The check helpers log the rejection reason to fout_not_supp.
            if not check_chars(sentence, supported_chars, fout_not_supp):
                continue
            if is_extremal_sentence(sentence, 3):  # shorter than 3 characters
                continue
            if not check_word_level(sentence, fout_not_supp):
                continue
            sentence = check_final_punct(sentence, supported_chars['punct'])
            # Capitalize the first letter. Slicing keeps this safe for an empty string.
            if sentence[:1].islower():
                sentence = sentence[0].upper() + sentence[1:]

            n_sents += 1
            print(sentence, file=fout_sent)
            # print('NORM SNT:', sentence)

In [None]:
# Report how many of the normalizer's sentences survived the filters.
print(f'Original sentences:  {n_orig_sents}')
print(f'Processed sentences: {n_sents}')
if n_orig_sents:
    print(f'Used %:              {(n_sents/n_orig_sents):.2%}')
else:
    # No sentence was produced (empty input or every paragraph failed to
    # parse); the ratio is undefined, so avoid a ZeroDivisionError.
    print('Used %:              n/a')