# Notebook for filtering and normalizing Wikipedia (Czech) dataset

Při rozdělení na 4 soubory po ~550 tis. odstavcích běží každá úloha s 8 CPU ~16 hod. a potřebuje 23 GB RAM (`chunk_size=4`, `batch_size=1000`).

## Imports

In [None]:
import os
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import chain
from datasets import load_dataset
from tpp_ttstool import TppTtstool
import time

In [None]:
# Number of worker processes. On the PBS cluster the job environment provides
# PBS_NUM_PPN; elsewhere (e.g. a local run) fall back to the machine's CPU
# count instead of failing with KeyError.
N_CPUS = int(os.environ.get("PBS_NUM_PPN") or os.cpu_count() or 1)
print(f"> Number of CPUs: {N_CPUS}")

## Papermill options

In [None]:
# Papermill parameters: the values below are defaults; papermill may override
# them when the notebook is executed as a batch job.
# NOTE(review): the second assignment overrides the first, so the effective
# default input is 'test_norm.txt' (presumably left over from testing) — confirm.
inp_paragraph_file = '../BERT_cs/WIKI_C4Cleaned1k.txt'
inp_paragraph_file = 'test_norm.txt'
# Output file; kept sentences are appended to it, one per line
out_sentence_file = 'test_norm.out2.txt'
# not_supported_chars_file = 'test_not_supported_chars.txt'
# tts_tool binary and its normalization data, both passed to TppTtstool
ttstool_bin= "./tts_tool/tts_tool"
ttstool_data = "./tts_tool/data/frontend_normalize.json"
chunk_size = 4 # number of paragraphs per task submitted to one worker process
batch_size = 1000 # minimum number of collected sentences written to the output file at once
# punctuation = ".,;:-?!…" # !!! TODO: define

# Characters allowed in output sentences: a sentence containing any other
# character is dropped. 'punct' is also passed to the tts_tool normalizer.
supported_chars = {
    'lower_chars': 'aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž',
    'upper_chars': 'AÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽ',
    'digits':  '0123456789',
    'white_spaces': ' ',
    'punct': '.,;:-?',
}

# Text replacements:
#   'character'   - single-character translations (applied with str.translate,
#                   so every key must be exactly one character)
#   'paragraph'   - substring replacements applied to each raw paragraph before
#                   tts_tool parsing (longer keys take precedence)
#   'postprocess' - substring replacements applied to each output sentence
replacements = {
    'character': {
        '–': '-',  # en dash
        '—': '-',  # em dash
        '−': '-',  # minus sign
        # '"': '',
        # "'": '',
        # "„": '',
        # "“": '',
    },
    # 'word': {
    #    '%': 'procent',
    #},
    'paragraph': {
        'km/h': 'kilometrů za hodinu',
        'm/s': 'metrů za sekundu',
        'př.n.l.': 'před naším letopočtem',
        'př. n. l.': 'před naším letopočtem',
        'st. př. Kr.': 'století před Kristem',
        'st. př. kr.': 'století před Kristem',
        ' napr.': ' např.',  # typo fix; the leading space anchors it to a word start
        # '<br': ' ',
    },
    'postprocess': {
        ',nebo ': ', nebo ',
    } 
}

## Define global variables and functions

In [None]:
# Sort the dictionary keys by length in descending order. Regex alternation tries
# alternatives left to right, so longer phrases are replaced before shorter ones,
# which avoids wrong or duplicate replacements.
paragraph_replacements = dict(sorted(replacements['paragraph'].items(), key=lambda item: len(item[0]), reverse=True))
# Extract single dict of character replacements
character_replacements = replacements['character']
# Extract single dict of postprocess replacements sorted by length (same reason as above)
postprocess_replacements = dict(sorted(replacements['postprocess'].items(), key=lambda item: len(item[0]), reverse=True))

# Alternations of the escaped keys, so every match is itself a dictionary key.
# NOTE: an empty dict would give the pattern '' which matches the empty string
# at every position; keep at least one key in each of these dicts.
paragraph_pattern = re.compile("|".join(map(re.escape, paragraph_replacements.keys())))
# Runs of whitespace (same regex as pattern_multispace)
character_pattern = re.compile(r'\s+')
postprocess_pattern = re.compile("|".join(map(re.escape, postprocess_replacements.keys())))

# Escape regex-special characters in the punctuation set so it can be placed
# inside a character class.
# Build a new string without '-' (presumably because dashes get their own
# spacing rule via pattern_dash).
punct_no_dash = supported_chars['punct'].replace('-', '')
escaped_punctuation = re.escape(punct_no_dash)
# Compile the regular expressions built from the punctuation string
# Whitespace immediately before a punctuation mark (group 1 = the mark)
pattern_before_punct = re.compile(r'\s+([{}])'.format(escaped_punctuation))
# A punctuation mark followed by any amount of whitespace, including none;
# substituting r'\1 ' therefore also inserts a space where there was none
# (e.g. '3.5' -> '3. 5').
pattern_after_punct = re.compile(r'([{}])\s*'.format(escaped_punctuation))
# A hyphen together with its surrounding whitespace: groups (before, dash, after)
pattern_dash = re.compile(r'(\s*)(-)(\s*)')
pattern_multispace = re.compile(r'\s+')
# Non-greedy '<...>' span: a crude HTML/XML tag matcher
pattern_html = re.compile(r'<.*?>')

# Setup TPP with path to tts_tool binary and data ('cz' is presumably the language code)
# NOTE(review): ProcessPoolExecutor workers read these globals (including
# `normalizer`); this relies on workers inheriting the notebook state
# (the 'fork' start method) — confirm on the target platform.
normalizer = TppTtstool('cz', tts_tool_bin=ttstool_bin, tts_tool_data=ttstool_data, punct=supported_chars['punct'])

In [None]:
def replace_dash(match):
    """Normalize the spacing around a hyphen matched by ``pattern_dash``.

    ``match`` is expected to have three groups: leading whitespace, the dash
    itself and trailing whitespace. If either side carries whitespace, the dash
    is returned padded as `` - ``. A dash glued to both neighbours (e.g. inside
    a compound word) is returned unchanged.
    """
    leading, dash, trailing = match.group(1, 2, 3)
    has_space = any(side and side.isspace() for side in (leading, trailing))
    return ' - ' if has_space else dash

def replace_chars(text, pattern, replacements):
    """Replace matches of ``pattern`` by one space, then translate characters.

    Args:
        text: Input string.
        pattern: Compiled regex whose matches are each replaced by a single
            space (in this notebook: runs of whitespace).
        replacements: Mapping of single characters to replacement strings,
            applied with ``str.translate``.

    Returns:
        The cleaned string.
    """
    # Pattern.sub takes (repl, string). The arguments used to be swapped, which
    # substituted into ' ' instead of ``text``: whitespace was never collapsed,
    # and backslashes in ``text`` were parsed as a replacement template.
    text = pattern.sub(' ', text)
    return text.translate(str.maketrans(replacements))

def is_extremal_sentence(sentence, min_len=None, max_len=None, remove_neighbors=False):
    """Tell whether ``sentence`` falls outside the allowed length range.

    A falsy bound (``None`` or ``0``) disables that check. ``remove_neighbors``
    is accepted for interface compatibility and is not used.

    Returns:
        ``True`` if the sentence is shorter than ``min_len`` or longer than
        ``max_len``, otherwise ``False``.
    """
    return (
        bool(min_len and len(sentence) < min_len)
        or bool(max_len and len(sentence) > max_len)
    )

def check_chars(text, supported_dict):
    """Return True if every character of ``text`` is a supported character.

    Args:
        text: String to validate.
        supported_dict: Mapping whose values are strings of allowed characters;
            all values are pooled into one allowed set.

    Returns:
        ``True`` when ``text`` contains only allowed characters (vacuously
        ``True`` for an empty string), ``False`` otherwise.
    """
    # Build the allowed set once. The previous version rebuilt it for every
    # character of ``text``, which made the check O(len(text) * alphabet size).
    allowed = set(''.join(supported_dict.values()))
    return allowed.issuperset(text)

def check_word_level(sentence):
    """Return True unless some whitespace-separated word has 2+ uppercase letters.

    Such words (e.g. acronyms) make the sentence unacceptable. A lone capital
    letter, such as a sentence-initial one-letter preposition ("V", "K"),
    contains only one uppercase character and therefore passes.
    """
    return all(check_case_word(token) for token in sentence.split())

def check_case_word(word):
    """Return True if ``word`` contains at most one uppercase character."""
    uppercase_count = sum(ch.isupper() for ch in word)
    return uppercase_count <= 1

def replace_supchars(text, pattern, replacements):
    """Replace every match of ``pattern`` with its entry in ``replacements``.

    ``pattern`` is expected to be an alternation of the escaped keys of
    ``replacements``, so each matched substring is itself a key.

    Args:
        text: Input string.
        pattern: Compiled regex matching the keys of ``replacements``.
        replacements: Mapping of substrings to their replacement strings.

    Returns:
        ``text`` with all matches replaced. ``text`` is returned unchanged when
        ``replacements`` is empty.
    """
    # An empty mapping yields the pattern '' which matches the empty string at
    # every position; looking '' up would raise KeyError.
    if not replacements:
        return text
    return pattern.sub(lambda match: replacements[match.group(0)], text)

def check_final_punct(text, punctuation):
    """Ensure ``text`` ends with sentence-final punctuation.

    - If the last character is in ``punctuation`` and is one of ``,;:-``, it is
      replaced by a period.
    - If the last character is any other mark from ``punctuation``, ``text`` is
      kept as is.
    - Otherwise a period is appended.
    - An empty ``text`` is returned unchanged (it used to raise IndexError).
    """
    if not text:
        return text
    last = text[-1]
    if last not in punctuation:
        return text + '.'
    if last in ',;:-':
        return text[:-1] + '.'
    return text

In [None]:
# Process one paragraph (one input line) into cleaned sentences
def process_line(text):
    """Normalize one paragraph with tts_tool and return the sentences that pass the filters.

    Steps:
      1. paragraph-level substring replacements (``replace_supchars``),
      2. character-level replacements (``replace_chars``),
      3. HTML tag removal,
      4. tts_tool parsing and normalization into sentences,
      5. per sentence: dash spacing, filtering (supported characters, minimum
         length 3, no word with 2+ uppercase letters) and punctuation/spacing
         cleanup.

    Reads the module-level patterns, replacement tables and ``normalizer``.

    Args:
        text: One raw paragraph.

    Returns:
        ``(kept_sentences, n_sentences)``. ``n_sentences`` is the number of
        sentences tts_tool produced before filtering. The result is ``([], 0)``
        when tts_tool fails to parse the paragraph.
    """
    # Module-level names are only read here, so no ``global`` statements are needed.
    paragraph = replace_supchars(text.strip(), paragraph_pattern, paragraph_replacements)
    paragraph = replace_chars(paragraph, character_pattern, character_replacements)
    paragraph = pattern_html.sub(' ', paragraph)

    try:
        normalizer.ssml_parse(paragraph)
    except RuntimeError:
        # Skip this paragraph and let the caller continue with the next ones.
        print('[!] Parsing paragraph failed in tts_tool => skipping\n')
        print(paragraph)
        print()
        return [], 0

    # Normalize & split into sentences
    sentences = list(normalizer.to_sentences_orto())

    correct_sentences = []
    for sentence in sentences:
        # NOTE: a filter that dropped sentences starting with a lowercase letter
        # (together with the preceding sentence) was tried and is disabled.

        # Normalize spacing around dashes
        sentence = pattern_dash.sub(replace_dash, sentence)
        if not check_chars(sentence, supported_chars):
            continue
        if is_extremal_sentence(sentence, 3):
            continue
        if not check_word_level(sentence):
            continue

        # Remove whitespace before punctuation marks
        sentence = pattern_before_punct.sub(r'\1', sentence)
        # Follow each punctuation mark by exactly one space (inserted if missing)
        sentence = pattern_after_punct.sub(r'\1 ', sentence)
        sentence = replace_supchars(sentence, postprocess_pattern, postprocess_replacements)
        sentence = pattern_multispace.sub(' ', sentence).strip()
        # A whitespace-only sentence passes the filters above but would be
        # written as an empty output line; drop it.
        if sentence:
            correct_sentences.append(sentence)

    return correct_sentences, len(sentences)

# Process one chunk of paragraphs (runs in a worker process)
def process_chunk(chunk):
    """Run ``process_line`` on every paragraph of ``chunk``.

    Args:
        chunk: Mapping whose ``'text'`` entry is a sequence of paragraphs.

    Returns:
        ``(sentences, n_orig_sents_in_chunk)``: all kept sentences in paragraph
        order, and the total number of sentences before filtering.
    """
    per_paragraph = [process_line(paragraph) for paragraph in chunk['text']]
    sentences = [sent for kept, _ in per_paragraph for sent in kept]
    n_before_filtering = sum(count for _, count in per_paragraph)
    return sentences, n_before_filtering

# Split the dataset into consecutive chunks
def chunks(dataset, chunk_size):
    """Lazily yield consecutive slices of ``dataset`` of ``chunk_size`` rows each.

    The last slice may be shorter. Each item is whatever ``dataset[a:b]``
    returns. A ``chunk_size`` of 0 raises ``ValueError`` on first iteration.
    """
    starts = range(0, len(dataset), chunk_size)
    yield from (dataset[start:start + chunk_size] for start in starts)

## Process dataset

In [None]:
# Local imports keep this cell self-contained (bounded submission window below).
from concurrent.futures import FIRST_COMPLETED, wait
from itertools import islice

# Load the dataset (one paragraph per line of the input file)
dataset = load_dataset('text', data_files=inp_paragraph_file)['train']

# Init sentence counters
n_processed_sentences, n_orig_sentences = 0, 0
# Number of chunks whose worker raised; their sentences are not counted.
n_failed_chunks = 0
# Maximum number of chunks submitted but not yet collected. Submitting every
# chunk up front and keeping all futures in a list held every input chunk and
# every result in memory until the end. A bounded window keeps memory flat.
max_pending = 4 * N_CPUS

# Open the output file (append mode)
# NOTE(review): re-running with the same out_sentence_file appends duplicate
# sentences — confirm that 'a' is intended.
with open(out_sentence_file, 'a', encoding='utf-8') as f_out:
    with ProcessPoolExecutor(max_workers=N_CPUS) as executor:
        chunk_iter = chunks(dataset, chunk_size)
        pending = set()  # futures of chunks currently in flight
        batch = []       # collected sentences not yet written

        while True:
            # Top up the window with new chunks; islice yields nothing once the
            # dataset is exhausted.
            for chunk in islice(chunk_iter, max_pending - len(pending)):
                pending.add(executor.submit(process_chunk, chunk))
            if not pending:
                break  # no chunks left and nothing in flight

            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                try:
                    results, n_sentences = future.result()
                except Exception as e:
                    # Best effort: report the failed chunk and keep processing.
                    n_failed_chunks += 1
                    print(f'[!] Processing chunk failed => skipping: {e!r}')
                    continue
                n_processed_sentences += len(results)
                n_orig_sentences += n_sentences
                batch.extend(results)

            # Write in batches to avoid many small writes
            if len(batch) >= batch_size:
                f_out.write('\n'.join(batch) + '\n')
                batch.clear()

        # Write the remaining (partial) batch, if any
        if batch:
            f_out.write('\n'.join(batch) + '\n')

if n_failed_chunks:
    print(f'[!] {n_failed_chunks} chunk(s) failed and were skipped')


In [None]:
print(f'Original sentences:  {n_orig_sentences}')
print(f'Processed sentences: {n_processed_sentences}')
# Avoid ZeroDivisionError when no sentence was produced (empty input, or every
# paragraph/chunk failed).
if n_orig_sentences:
    print(f'Used %:              {(n_processed_sentences/n_orig_sentences):.2%}')
else:
    print('Used %:              n/a (no input sentences)')