# Description
**Functionality**: This module formats the three-way split (train/validation/test) of the Wikipedia Homograph Data (WHD) for BERT token classification with Hugging Face tools.

**Use**: Hugging Face's BERT fine-tuning scripts for token classification expect CoNLL-2003-formatted data. The output of this module will be used to fine-tune a BERT model to predict pronunciation labels for homographs.

### Imports

In [None]:
import os
from glob import glob
import pandas as pd
from tqdm import tqdm
from typing import List, Dict
import spacy

### Variables

In [None]:
#Paths
WHD_DATA = "C:/Users/jseal/Dev/dissertation/Data/WikipediaHomographData/data/"
METADATA = WHD_DATA + 'WikipediaHomographData.csv'
# Root for every BERT-formatted artifact: per-sentence files, merged split files, labels.
# (Previously referenced below but never defined, which raised a NameError.)
WHD_BERT_DATA = WHD_DATA + "bert_data/"
LABELS = WHD_BERT_DATA + "labels.txt"
#Source paths
TRAIN = WHD_DATA + "three_split_data/train/"
VAL = WHD_DATA + "three_split_data/valid/"
TEST = WHD_DATA + "three_split_data/test/"
#Destination paths
BERT_TRAIN = WHD_BERT_DATA + "train/"
BERT_DEV = WHD_BERT_DATA + "dev/"
BERT_TEST = WHD_BERT_DATA + "test/"
# (source, destination) pairs for train, val, test. Materialized as a list:
# a bare zip() is a one-shot iterator, so a second pass (e.g. re-running a
# cell) would silently iterate over nothing.
ORIGINAL_SETS = [TRAIN, VAL, TEST]
BERT_SETS = [BERT_TRAIN, BERT_DEV, BERT_TEST]
SOURCE_DEST = list(zip(ORIGINAL_SETS, BERT_SETS))

#Train, val, test splits in one file each
TRAIN_TMP = WHD_BERT_DATA + "train.txt.tmp"
VAL_TMP = WHD_BERT_DATA + "val.txt.tmp"
TEST_TMP = WHD_BERT_DATA + "test.txt.tmp"
TMPS = [TRAIN_TMP, VAL_TMP, TEST_TMP]

TRAIN_TXT = WHD_BERT_DATA + "train.txt"
VAL_TXT = WHD_BERT_DATA + "val.txt"
TEST_TXT = WHD_BERT_DATA + "test.txt"
OUTS = [TRAIN_TXT, VAL_TXT, TEST_TXT]

# (temporary, final) file pairs; a list for the same re-iteration reason as SOURCE_DEST.
TMPS_OUTS = list(zip(TMPS, OUTS))

#Tools
# spaCy English pipeline; get_tokens() runs it to split sentences into tokens.
nlp = spacy.load("en_core_web_sm")

#Variables
# Tag written for every token that is not the sentence's homograph.
OUTSIDE = "O"

#Options
# Never truncate the row listing when a DataFrame is displayed in the notebook.
pd.set_option("display.max_rows", None)

### Functions

In [None]:
def get_tokens(sentence : str) -> List[str]:
    """Split *sentence* into word tokens with spaCy, dropping punctuation.

    Whitespace-only tokens are dropped too. spaCy emits these for extra
    spaces, tabs or newlines. Written into a tab-separated ``token<TAB>label``
    line, such a token yields a malformed line whose first whitespace-delimited
    field is the label rather than the token.

    Args:
        sentence: raw sentence text.

    Returns:
        The token strings, in sentence order.
    """
    # is_punct / is_space are lexical attributes, so the parser, tagger and
    # NER components are not needed for this call.
    sent_nlp = nlp(sentence, disable=['parser', 'tagger', 'ner'])
    return [token.text for token in sent_nlp if not (token.is_punct or token.is_space)]

def make_str(label : List) -> str:
    """Return the items of *label* joined into one string, separated by single spaces."""
    separator = " "
    return str.join(separator, label)

def make_tsvs():
    """Write one CoNLL-style file per WHD sentence for the train/valid/test splits.

    Each ``*.tsv`` in a source directory (ORIGINAL_SETS) is processed as follows:

    - Every sentence is tokenized with get_tokens().
    - Each sentence is written to its own ``<tsv stem>_<row index>.txt`` file
      in the matching destination directory (BERT_SETS).
    - Each output line is ``sent_id<TAB>token<TAB>label``, one per token.
    - A token whose lowercase form equals the row's homograph is labelled with
      the row's ``wordid``; every other token gets OUTSIDE.
    """
    # A fresh zip per call: a module-level zip() would be exhausted after one pass.
    for source_dir, dest_dir in zip(ORIGINAL_SETS, BERT_SETS):
        # to_csv() does not create missing directories.
        os.makedirs(dest_dir, exist_ok=True)
        for f in tqdm(glob(source_dir + '*.tsv')):
            stem = os.path.splitext(os.path.basename(f))[0]
            df = pd.read_table(f)
            # .copy() so adding a column writes to a real frame, not a slice.
            df = df[['homograph', 'wordid', 'sentence']].copy()
            # A missing sentence cannot be tokenized; skip it instead of crashing.
            df = df.dropna(subset=['sentence'])
            df['token'] = df.sentence.apply(get_tokens)
            # One row per token. The original row index is kept and identifies the sentence.
            # A sentence with no tokens explodes to a single NaN row, which is dropped.
            tokens_df = df.explode('token').dropna(subset=['token'])
            for index, group in tokens_df.groupby(tokens_df.index):
                tokens = list(group['token'])
                homographs = list(group['homograph'])
                wordids = list(group['wordid'])
                sentence_df = pd.DataFrame({
                    'sent_id': ["{}_{}".format(homograph, index) for homograph in homographs],
                    'token': tokens,
                    'label': [
                        str(wordid) if token.lower() == homograph else OUTSIDE
                        for token, homograph, wordid in zip(tokens, homographs, wordids)
                    ],
                })
                # Name the file after the source tsv (the homograph) and the sentence's row index.
                new_f_name = dest_dir + stem + "_" + str(index) + '.txt'
                sentence_df.to_csv(new_f_name, sep="\t", header=False, index=False)
    
def make_tmps():
    """Merge each BERT split's per-sentence files into one temporary file.

    Each split directory is paired with its own temporary file: BERT_SETS[i]
    goes to TMPS[i]. Every ``*.txt`` written by make_tsvs() is processed in
    sorted order:

    - Its ``sent_id<TAB>token<TAB>label`` lines are reduced to ``token<TAB>label``.
    - A blank line is written after each file to separate sentences.
    """
    # Each tmp file is opened exactly once, for its own split. Opening it with
    # 'w' inside a loop over every split truncated it on each pass, leaving
    # only the last split's data in every tmp file.
    for split_path, tmp in zip(BERT_SETS, TMPS):
        with open(tmp, 'w', encoding="utf8") as f_out:
            # Sorted so the sentence order is the same on every run and platform.
            for f in sorted(glob(split_path + "*.txt")):
                with open(f, encoding="utf8") as example:
                    for line in example:
                        line = line.rstrip('\r\n')
                        if not line:
                            continue
                        line_list = line.split('\t')
                        # Drop the sent_id column; keep token and label.
                        f_out.write(line_list[1] + '\t' + line_list[2] + '\n')
                # Sentence boundary.
                f_out.write('\n')
                
def make_txts(model_name: str = "bert-base-cased", max_length: int = 128):
    """Write the final train/val/test files, splitting sentences that are too long.

    Each temporary ``token<TAB>label`` file (TMPS) is copied line by line to
    its final path (OUTS), with these changes:

    - An extra sentence break is inserted whenever the running count of
      subword pieces would exceed ``max_length`` minus the tokenizer's special
      tokens.
    - Lines whose token maps to zero subword pieces are dropped.

    Args:
        model_name: Hugging Face checkpoint whose tokenizer defines subword
            lengths. It should match the model being fine-tuned.
        max_length: maximum sequence length, including special tokens.

    NOTE(review): both defaults are assumptions (previously undefined
    MODEL_NAME / MAX_LENGTH globals). Confirm them against the fine-tuning run.
    """
    # NOTE(review): AutoTokenizer comes from `transformers`, which this notebook
    # never imports. Add `from transformers import AutoTokenizer` to the imports cell.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Subword budget left once special tokens (e.g. [CLS]/[SEP] for BERT) are added.
    # A local name: `MAX_LENGTH -= ...` inside the function raised UnboundLocalError.
    max_subwords = max_length - tokenizer.num_special_tokens_to_add()

    # A fresh zip per call: a module-level zip() would be exhausted after one pass.
    for tmp, outfile in zip(TMPS, OUTS):
        # Subword pieces in the sentence currently being written.
        subword_len_counter = 0
        with open(tmp, "r", encoding="utf8") as f_p, open(outfile, "w", encoding="utf8") as out_f:
            for line in f_p:
                line = line.rstrip()

                # A blank line is a sentence boundary: copy it and restart the count.
                if not line:
                    out_f.write(line + "\n")
                    subword_len_counter = 0
                    continue

                token = line.split()[0]

                current_subwords_len = len(tokenizer.tokenize(token))

                # Token contains strange control characters like \x96 or \x95
                # Just filter out the complete line
                if current_subwords_len == 0:
                    continue

                # Adding this token would exceed the budget: end the current
                # sentence and start a new one with this token.
                if (subword_len_counter + current_subwords_len) > max_subwords:
                    out_f.write("\n")
                    out_f.write(line + "\n")
                    subword_len_counter = current_subwords_len
                    continue

                subword_len_counter += current_subwords_len

                out_f.write(line + "\n")
                    
def make_labels():
    """Write the token-classification label inventory to LABELS, one label per line.

    The labels are every distinct ``wordid`` in the WHD metadata CSV (METADATA),
    in first-seen order, followed by OUTSIDE for non-homograph tokens.
    """
    metadata_df = pd.read_csv(METADATA)
    # dict.fromkeys drops duplicates while keeping order. Presumably the label
    # list is turned into a label->id mapping downstream, where a repeated
    # label would collide. Missing wordids are skipped rather than written as "nan".
    wordids = list(dict.fromkeys(metadata_df.wordid.dropna().astype(str)))
    os.makedirs(os.path.dirname(LABELS), exist_ok=True)
    with open(LABELS, 'w', encoding="utf8") as f:
        for wordid in wordids:
            f.write("{}\n".format(wordid))
        # The same constant make_tsvs() writes for non-homograph tokens.
        f.write("{}\n".format(OUTSIDE))

# Script

In [None]:
# 1. One tsv per sentence, per split
make_tsvs()
# 2. Merge each split into one temporary token/label file
#    (the function is make_tmps; the call previously used the undefined name make_temps)
make_tmps()
# 3. Final train/val/test files with over-long sentences split
make_txts()
# 4. Label inventory
make_labels()