# Prepare the IWSLT 2011/2012 data (Che et al.)

Some of the code is adapted from https://github.com/attilanagy234/neural-punctuator

## Create clean text from unparsed files

In [1]:
import os
import re
import numpy as np
import pickle
from transformers import BertTokenizer
from tqdm import tqdm

In [2]:
# label name -> punctuation character that is written right after a word
# when the plain text is rebuilt from the word/label files
punctEncode = dict(O='', COMMA=',', PERIOD='.', QUESTION='?')

# folder holding the unparsed word/label files
# (keep the trailing /, paths are built by plain string concatenation)
rawPath = "IWSLTche/RAW/"

# folder the generated datasets are written to (keep the trailing / as well)
dataPath = "IWSLTche/Data/"

# Rebuild the punctuated running text from a word/label file
def RAWtoText(sourcePath, targetPath, punctMap=None):
    """Convert a tab-separated word/label file into one line of plain text.

    Every non-blank line of the source file must contain a word, a tab and a
    punctuation label (further columns are ignored). Each word is written
    followed by the character its label maps to, and the words are joined
    with single spaces.

    Args:
        sourcePath: path of the word/label file; read as bytes and decoded
            as UTF-8, undecodable bytes are dropped.
        targetPath: path of the UTF-8 text file to create or overwrite.
        punctMap: optional mapping from label name to the string appended to
            the word. Defaults to O/COMMA/PERIOD/QUESTION -> ''/','/'.'/'?'.

    Raises:
        ValueError: if a non-blank line has no label column or carries a
            label that is missing from the mapping.
    """
    if punctMap is None:
        # Own label -> character table instead of the global `punctEncode`:
        # that name is rebound to integer class ids later in the notebook,
        # which would break this function when the cell is run again.
        punctMap = {'O': '', 'COMMA': ',', 'PERIOD': '.', 'QUESTION': '?'}

    pieces = []
    with open(sourcePath, "rb") as file:
        for lineNo, line in enumerate(file, start=1):
            # the dataset uses \r\n for newlines, but accept plain \n as well
            row = line.decode('utf-8', errors='ignore').rstrip('\r\n')
            if not row.strip():
                continue  # blank line (e.g. at the end of the file)

            fields = row.split('\t')
            if len(fields) < 2 or fields[1] not in punctMap:
                raise ValueError(
                    f"{sourcePath}, line {lineNo}: expected 'word<TAB>label' "
                    f"with a label in {sorted(punctMap)}, got {row!r}"
                )
            pieces.append(fields[0] + punctMap[fields[1]])

    # write to file; joining leaves no trailing space to trim
    with open(targetPath, 'w', encoding='utf-8') as f:
        f.write(' '.join(pieces))

In [3]:
# (raw file name, generated text file name), converted in this order
for rawName, textName in (
    ("test2011", "testText.txt"),
    ("test2011asr", "testAsrText.txt"),
    ("train2012", "trainText.txt"),
    ("dev2012", "devText.txt"),
):
    RAWtoText(rawPath + rawName, dataPath + textName)

## Create Pickle data

In [4]:
# label names; the position of a name in this list is its numeric class id
labelNames = ["O", "COMMA", "PERIOD", "QUESTION"]
LABEL_NOTHING, LABEL_COMMA, LABEL_PERIOD, LABEL_QUESTION = range(len(labelNames))

# encode the punctuation label as a number
# NOTE: this rebinds `punctEncode`, which up to this point mapped the label
# names to punctuation characters
punctEncode = dict(zip(
    labelNames,
    (LABEL_NOTHING, LABEL_COMMA, LABEL_PERIOD, LABEL_QUESTION),
))

# which BERT network to use (also names the output sub-folder of the
# prepared data)
modelName = "bert-base-uncased"

# load the tokenizer that belongs to that network
tokenizer = BertTokenizer.from_pretrained(modelName)

In [5]:
# read the four generated text files, in the order train, dev, test, test-asr
_loadedTexts = []
for _textFile in ("trainText.txt", "devText.txt", "testText.txt", "testAsrText.txt"):
    with open(dataPath + _textFile, 'r', encoding="utf-8") as f:
        _loadedTexts.append(f.read())

train_text, valid_text, test_text, testasr_text = _loadedTexts

# keep all datasets in one tuple so later steps can process them in a loop
datasets = (train_text, valid_text, test_text, testasr_text)

In [6]:
# prepare data for the model
# code comes mostly from neural-punctuator by attilanagy234
def clean_text(text):
    """Normalize punctuation and whitespace of *text* and lower-case it.

    The text is passed through a fixed sequence of literal and regex
    replacements (the order matters, later steps rely on earlier ones):
    '!' and ';' become '.', ':' and '--' become ',', hyphens inside words
    become spaces, quotes turn into commas, '...' and the music symbol are
    dropped, and doubled punctuation as well as whitespace in front of
    punctuation are collapsed. The result is stripped and lower-cased.
    """

    def literal(old, new):
        # step: plain substring replacement
        return lambda s: s.replace(old, new)

    def pattern(regex, new, flags=0):
        # step: regular-expression substitution
        compiled = re.compile(regex, flags)
        return lambda s: compiled.sub(new, s)

    steps = (
        # replacing special tokens
        literal('!', '.'),
        literal(':', ','),
        literal('--', ','),
        # hyphen inside a word (letter before, two or more letters after)
        pattern("(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})", ' ', re.DOTALL),
        # free-standing hyphen acts like a comma
        pattern(r'\s-\s', ' , '),
        literal(';', '.'),
        literal(' ,', ','),
        literal('♫', ''),
        literal('...', ''),
        literal('.\"', ','),
        literal('"', ','),
        # only matches hyphens that became adjacent after the removals above
        pattern(r'--\s?--', ''),
        pattern(r'\s+', ' '),
        # collapse doubled punctuation
        pattern(r',\s?,', ','),
        pattern(r',\s?\.', '.'),
        pattern(r'\?\s?\.', '?'),
        pattern(r'\s+', ' '),
        # no whitespace in front of punctuation
        pattern(r'\s+\?', '?'),
        pattern(r'\s+,', ','),
        # NOTE(review): the character class also contains a literal '+', so a
        # '+' directly after a period is swallowed too; kept as is so the
        # generated data does not change
        pattern(r'\.[\s+\.]+', '. '),
        pattern(r'\s+\.', '.'),
    )

    for step in steps:
        text = step(text)

    return text.strip().lower()

# punctuation character -> token id; encode() is called with its default
# settings and the second to last id of the result is taken
target_token2id = {}
for punct in ",.?":
    target_token2id[punct] = tokenizer.encode(punct)[-2]

# token ids of ',', '.', '?' in that order
target_ids = [*target_token2id.values()]

# map 0 and -1 onto themselves and every punctuation token id onto its
# class number, counted from 1 in the order of target_ids
id2target = {
    0: 0,
    -1: -1,
    **{token_id: label for label, token_id in enumerate(target_ids, start=1)},
}

def create_target(text):
    """Tokenize *text* and derive one punctuation label per token.

    The text is split on single spaces. A trailing ',', '.' or '?' is
    stripped from each word and turned into the word's label via
    `target_token2id` / `id2target` (0 when there is none). The word is
    then encoded without special tokens; every token of a word gets the
    label -1 except the last one, which carries the word's label.

    Args:
        text: cleaned text with words separated by single spaces.

    Returns:
        (encoded_words, targets): two lists of equal length holding the
        token ids and the label of each token. No special tokens are added
        at the start or end of the sequence.

    Raises:
        ValueError: if a non-empty word produces no tokens (for example a
            word that consists of punctuation only).
    """
    encoded_words, targets = [], []

    for word in text.split(' '):
        if not word:
            # empty text or doubled space: nothing to encode, nothing to label
            continue

        original_word = word
        target = 0
        for target_token, target_id in target_token2id.items():
            if word.endswith(target_token):
                word = word.rstrip(target_token)
                target = id2target[target_id]

        encoded_word = tokenizer.encode(word, add_special_tokens=False)

        # Check before touching the output lists, and with a real exception:
        # an assert would be skipped under `python -O` and leave the token
        # and label lists with different lengths.
        if not encoded_word:
            raise ValueError(
                f"word {original_word!r} produced no tokens after removing "
                "its trailing punctuation"
            )

        encoded_words.extend(encoded_word)
        # only the last token of a word carries the label
        targets.extend([-1] * (len(encoded_word) - 1))
        targets.append(target)

    return encoded_words, targets

In [7]:
# clean the special characters from the texts
datasets = [clean_text(text) for text in datasets]

# encode the texts and generate labels
encoded_texts, targets = [], []

for ds in tqdm(datasets):
    # create_target returns (token ids, labels) for the whole text. Keep both
    # sequences in full: transposing them with zip(*...) and taking items 0
    # and 1 would keep only the first two (token id, label) pairs.
    token_ids, labels = create_target(ds)
    encoded_texts.append(token_ids)
    targets.append(labels)


# make folder for prepared dataset for specific BERT model
os.makedirs(dataPath + modelName, exist_ok=True)

# store one pickle per dataset, each holding (token ids, labels);
# the names follow the order in which the datasets were loaded
# NOTE(review): confirm that the training code expects two flat sequences
# per file
splitNames = ('train', 'valid', 'test', 'testasr')
for name, token_ids, labels in zip(splitNames, encoded_texts, targets):
    with open(dataPath + f'{modelName}/{name}_data.pkl', 'wb') as f:
        pickle.dump((token_ids, labels), f)

100%|█████████████████████████████████████████████| 4/4 [02:03<00:00, 30.84s/it]
