# Training data generation

In this notebook, we will generate the training data for the neural network.

## 1. Global setup
Set up logging and paths

In [None]:
# Execute global_setup.py (logging and path set-up, per the section text above)
# directly in this notebook's namespace, so every name it defines becomes a
# notebook global.
# NOTE(review): exec() runs the script's contents unchecked - only use this with
# a trusted, project-owned global_setup.py.
try:
    with open("global_setup.py") as setupfile:
        exec(setupfile.read())
except FileNotFoundError:
    # Reached when global_setup.py is not in the current working directory.
    # A FileNotFoundError raised by the executed script itself also lands here.
    # Presumably the script changes the working directory, so that re-running
    # this cell no longer finds the file - TODO confirm against global_setup.py.
    print('Setup already completed')

## 2. Import the required packages

In [None]:
import random
import re
import pprint
from src.wikipedia import Wikipedia

## 3. Text distortion function

<i>distortText</i> function takes two parameters: <br><br>
<b>text</b>: string – the sentence to distort<br>
<b>swap_type</b>: string – the distortion to apply; one of:
* <i>symbolswap</i>: for each word in a sentence, take two consecutive symbols and swap them (orange - oragne)
* <i>symboldelete</i>: for each word in a sentence, delete a random symbol (orange - ornge)
* <i>symbolreplacerandom</i>: for each word in a sentence, replace a symbol with a random symbol from the alphabet (orange - xrange)
* <i>symbolreplaceprev</i>: for each word in a sentence, replace a symbol with the symbol that precedes it (orange - orrnge)
* <i>wordtrimright</i>: for each word in a sentence, remove a random number of symbols from the right (orange - o)
* <i>wordswapprev</i>: take two consecutive words in a sentence and swap them (I like oranges - like I oranges)
* <i>deletespaces</i>: delete the whitespace between two consecutive words in a sentence (I like oranges - Ilike oranges)


In [None]:
def _distort_word(word, swap_type):
    """Apply one symbol-level distortion (or a right trim) to a single word.

    Words that are too short for the requested distortion are returned
    unchanged instead of raising or producing an empty token.
    """
    size = len(word)
    if size == 0:
        return word

    if swap_type == "symbolreplacerandom":
        # Works for any non-empty word: overwrite one position with a random letter.
        alphabet = "abcdefghijklmnopqrstuvwxyz"
        pos = random.randrange(size)
        return word[:pos] + random.choice(alphabet) + word[pos + 1:]

    # Every remaining distortion needs at least two symbols: a neighbour to
    # swap with / copy from, or something left over after deleting / trimming.
    if size < 2:
        return word

    if swap_type == "symbolswap":
        # pos is the left symbol of the swapped pair: 0 .. size - 2.
        pos = random.randrange(size - 1)
        return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
    if swap_type == "symboldelete":
        pos = random.randrange(size)
        return word[:pos] + word[pos + 1:]
    if swap_type == "symbolreplaceprev":
        # pos is the overwritten symbol: 1 .. size - 1, so a previous symbol exists.
        pos = random.randrange(1, size)
        return word[:pos] + word[pos - 1] + word[pos + 1:]
    if swap_type == "wordtrimright":
        # Keep 1 .. size - 1 leading symbols, i.e. always trim at least one.
        return word[:random.randrange(1, size)]
    return word


def distortText(text, swap_type = "symboldelete"):
    """Return ``text`` with one kind of synthetic typo applied.

    The text is split on whitespace, distorted, and re-joined with single
    spaces (so runs of whitespace are normalised even for an unknown
    ``swap_type``, in which case the words themselves are left untouched).

    Parameters
    ----------
    text : str
        Sentence to distort.
    swap_type : str
        One of:

        * ``"symbolswap"`` - in every word, swap two neighbouring symbols.
        * ``"symboldelete"`` - in every word, delete one random symbol.
        * ``"symbolreplacerandom"`` - in every word, replace one symbol with a
          random lowercase ASCII letter.
        * ``"symbolreplaceprev"`` - in every word, overwrite one symbol with
          the symbol preceding it.
        * ``"wordtrimright"`` - cut every word down to a random proper prefix.
        * ``"wordswapprev"`` - swap one random pair of neighbouring words.
        * ``"deletespaces"`` - merge one random pair of neighbouring words.

    Returns
    -------
    str
        The distorted sentence.

    Notes
    -----
    Positions are drawn uniformly with ``random.randrange``. One-symbol words
    are left unchanged by every per-word distortion except
    ``"symbolreplacerandom"``, and texts with fewer than two words are left
    unchanged by the sentence-level distortions.
    """
    list_words = text.split()

    if swap_type in ("symbolswap", "symboldelete", "symbolreplacerandom",
                     "symbolreplaceprev", "wordtrimright"):
        # Per-word distortions.
        list_words = [_distort_word(word, swap_type) for word in list_words]
    elif len(list_words) >= 2:
        # Sentence-level distortions act on the gap between word_id and word_id + 1.
        word_id = random.randrange(len(list_words) - 1)
        if swap_type == "wordswapprev":
            list_words[word_id], list_words[word_id + 1] = (
                list_words[word_id + 1], list_words[word_id])
        elif swap_type == "deletespaces":
            list_words[word_id:word_id + 2] = [list_words[word_id] + list_words[word_id + 1]]

    return " ".join(list_words)

Test the distortion function on a few examples.

In [None]:
# Print one distorted variant of the sample sentence per distortion mode,
# in the same order as the modes are listed in the function description.
text = "Antons likes pizza"
for demo_mode in (
    "symbolswap",
    "symboldelete",
    "symbolreplacerandom",
    "symbolreplaceprev",
    "wordtrimright",
    "wordswapprev",
    "deletespaces",
):
    print(distortText(text=text, swap_type=demo_mode))

## 4. Load the Simple Wikipedia

In [None]:
# Load the "simple" language edition through the project's Wikipedia wrapper.
# NOTE(review): the effect of cache_directory_url=False is defined in
# src.wikipedia - confirm there before changing it.
wikipedia = Wikipedia(language="simple", cache_directory_url=False)

## 5. Clean-up the data

In [None]:
# Cleaning up simple wikipedia texts
pattern_ignored_words = re.compile(
    r"""
    (?:(?:thumb|thumbnail|left|right|\d+px|upright(?:=[0-9\.]+)?)\|)+
    |^\s*\|.+$
    |^REDIRECT\b""",
    flags=re.DOTALL | re.UNICODE | re.VERBOSE | re.MULTILINE)
pattern_new_lines = re.compile('[\n\r ]+', re.UNICODE)
texts = [wikipedia.documents[i].text for i in range(len(wikipedia.documents))]
texts = [pattern_ignored_words.sub('', texts[i]) for i in range(len(texts))]
texts = [pattern_new_lines.sub(' ', texts[i]) for i in range(len(texts))]
texts = [texts[i].replace("\\", "") for i in range(len(texts))]
texts = [texts[i].replace("\xa0", " ") for i in range(len(texts))]

## 6. Divide into sentences

In [None]:
# Simple wikipedia article texts into single sentences
sentences = []
sentences += [tokenize.sent_tokenize(texts[i]) for i in range(len(texts))]
#sentences += [texts[i].split(". ") for i in range(len(texts))] #len(texts)
# Now sentences is a list of lists. The next expression flattens it into one long list.
sentences = [item for sublist in sentences for item in sublist]

## 7. Clean up sentences and remove those that are too long or too short

Median sentence length is 83 symbols. We remove the sentences shorter than 20 symbols and longer than 100 symbols to clean up the dataset.<br><br>
We also remove the sentences starting with "Category:", "Related pages", "References" or "Other websites". <br>
These are technical Wikipedia section headings that we do not need. We still need to check for more of them, e.g. "Gallery".

In [None]:
# Drop sentences that are too short or too long, and sentences that start with
# a technical Wikipedia heading.
MIN_SENTENCE_LENGTH = 20    # sentences with fewer symbols are removed
MAX_SENTENCE_LENGTH = 100   # sentences with more symbols are removed
TECHNICAL_PREFIXES = ("Category:", "Related pages", "References", "Other websites")

print(len(sentences))
# Rebuild the list in a single linear pass; popping items one by one from a
# corpus-sized list is quadratic. Slice assignment keeps the same list object,
# so the list is still updated in place.
sentences[:] = [
    sentence
    for sentence in sentences
    if MIN_SENTENCE_LENGTH <= len(sentence) <= MAX_SENTENCE_LENGTH
    and not sentence.startswith(TECHNICAL_PREFIXES)
]
print(len(sentences))

# TODO: decide how to handle "Gallery" sections.

## 8. Generate training data

In [None]:
# Placeholder cell: `data` is bound to the same list object as `sentences`
# (no copy is made), and the loop below does not do anything yet.
data = sentences # first run on the unmodified sentences (original note was cut off here)
for sentence in sentences:
    pass #apply distortion function here