# Description
**Functionality**: This module formats the three-way-split (train/dev/test) Wikipedia Homograph Data (WHD) for BERT token classification with Hugging Face tools.

**Use**: Hugging Face's BERT token-classification fine-tuning code expects CoNLL-2003-formatted data: one token and its label per line, with a blank line between sentences. The output of this module is used to fine-tune models that predict pronunciation labels for homographs.

### Imports

In [1]:
import os
from glob import glob
from typing import Dict, Iterable, List, Optional, Tuple

import pandas as pd
import spacy
from tqdm import tqdm
from transformers import AutoTokenizer

### Variables

In [2]:
#Paths
# Root data directory. Set the WHD_BASE environment variable to run on another
# machine (the default is the original author's path). os.path.join(..., "")
# guarantees a trailing separator, which every concatenation below relies on.
BASE = os.path.join(
    os.environ.get("WHD_BASE", "C:/Users/jseal/Dev/dissertation/Data/"), ""
)
WHD_DATA_BASE = BASE + "WikipediaHomographData/data/"
METADATA = WHD_DATA_BASE + 'WikipediaHomographData.csv'
WHD_DATA_IN = WHD_DATA_BASE + "three_split_data/"
WHD_DATA_OUT = BASE + "WHD_Bert/"
LABELS = WHD_DATA_OUT + "labels.txt"  # label vocabulary written by make_labels

#Source paths: one directory of input .tsv files per split
TRAIN = WHD_DATA_IN + "train/"
DEV = WHD_DATA_IN + "dev/"
TEST = WHD_DATA_IN + "test/"
SOURCE_TSVS = [TRAIN, DEV, TEST]

#Destination paths: one directory of per-sentence files per split
TRAIN_TSV = WHD_DATA_OUT + "train_tsvs/"
DEV_TSV = WHD_DATA_OUT + "dev_tsvs/"
TEST_TSV = WHD_DATA_OUT + "test_tsvs/"
DESTINATION_TSVS = [TRAIN_TSV, DEV_TSV, TEST_TSV]

# NOTE: every split pairing below is materialized with list(). A bare zip() is
# a one-shot iterator, so a second pass over it (e.g. re-running a notebook
# cell) would silently process nothing.
SOURCE_DEST = list(zip(SOURCE_TSVS, DESTINATION_TSVS))

#Tmp Train, val, test splits in one file each
TRAIN_TMP = WHD_DATA_OUT + "train.txt.tmp"
DEV_TMP = WHD_DATA_OUT + "dev.txt.tmp"
TEST_TMP = WHD_DATA_OUT + "test.txt.tmp"
TMPS = [TRAIN_TMP, DEV_TMP, TEST_TMP]

TSVS_TMPS = list(zip(DESTINATION_TSVS, TMPS))

#Train, val, test splits in one file each
TRAIN_TXT = WHD_DATA_OUT + "train.txt"
DEV_TXT = WHD_DATA_OUT + "dev.txt"
TEST_TXT = WHD_DATA_OUT + "test.txt"
DESTINATIONS = [TRAIN_TXT, DEV_TXT, TEST_TXT]

TMPS_DESTS = list(zip(TMPS, DESTINATIONS))

#Tools
nlp = spacy.load('en_core_web_sm')  # spaCy pipeline used by get_tokens

#Label variables
OUTSIDE = "O" #Label for all words that are not a homograph

#Options
pd.set_option('display.max_rows', None)

#Model info
# Hugging Face checkpoint whose tokenizer measures subword lengths in make_txts
MODEL_NAME = "distilbert-base-cased"


### Functions

In [3]:
def get_tokens(sentence: str) -> List[str]:
    """Tokenize *sentence* with spaCy and return the token texts.

    Punctuation tokens are dropped, as before. Whitespace tokens are dropped
    too: spaCy emits one for every run of extra spaces, tabs or newlines. If
    such a token were written into the tab-separated, one-token-per-line files
    built by make_tsvs, it would break the line/column format.
    """
    doc = nlp(sentence, disable=['parser', 'tagger', 'ner'])
    return [tok.text for tok in doc if not (tok.is_punct or tok.is_space)]

def make_str(label: List) -> str:
    """Collapse a list of label strings into one space-separated string."""
    joined = " ".join(label)
    return joined

def _write_sentence(dest_path: str, sent_id: str, tokens: List[str],
                    homograph: str, wordid: str) -> None:
    """Write one sentence to *dest_path* as ``sent_id<TAB>token<TAB>label`` lines.

    A token whose lowercase form equals *homograph* is labelled *wordid*;
    every other token is labelled OUTSIDE.
    """
    with open(dest_path, "w", encoding="utf8") as f_out:
        for token in tokens:
            label = wordid if token.lower() == homograph else OUTSIDE
            f_out.write("{}\t{}\t{}\n".format(sent_id, token, label))


def make_tsvs() -> None:
    """Split every source .tsv into one token-per-line file per sentence.

    This runs for each split (train, dev, test). Every ``*.tsv`` in the source
    directory is read, and its ``homograph``, ``wordid`` and ``sentence``
    columns are used. For each row, the function writes
    ``<source stem>_<row index>.txt`` into the matching destination directory.
    That file holds one ``<homograph>_<row index><TAB>token<TAB>label`` line
    per token returned by get_tokens.
    """
    # Pair the splits here rather than reusing a module-level zip, which can
    # be exhausted by an earlier run.
    for source_dir, dest_dir in zip(SOURCE_TSVS, DESTINATION_TSVS):
        os.makedirs(dest_dir, exist_ok=True)
        for source in tqdm(glob(source_dir + '*.tsv')):
            stem = os.path.basename(source)[:-4]  # drop the ".tsv" extension
            df = pd.read_table(source)
            rows = df[['homograph', 'wordid', 'sentence']].itertuples(name=None)
            for index, homograph, wordid, sentence in rows:
                tokens = get_tokens(sentence)
                if not tokens:
                    # Nothing to label (e.g. an all-punctuation sentence).
                    # Previously this produced a NaN token and crashed.
                    continue
                _write_sentence(
                    dest_dir + stem + "_" + str(index) + '.txt',  # homograph + sentence number
                    "{}_{}".format(homograph, index),
                    tokens,
                    homograph,
                    wordid,
                )
    
def make_tmps(pairs: Optional[Iterable[Tuple[str, str]]] = None) -> None:
    """Concatenate the per-sentence files of each split into one temporary file.

    Each file in ``tsv_dir`` is read as ``sent_id<TAB>token<TAB>label`` lines.
    Its ``token<TAB>label`` part is written to ``tmp_path``, and each sentence
    is followed by a blank line (the example separator). Blank lines inside an
    input file are ignored, and files with no content are skipped so they do
    not add empty examples.

    Args:
        pairs: ``(tsv_dir, tmp_path)`` pairs to process. ``tsv_dir`` must end
            with a path separator. Defaults to pairing DESTINATION_TSVS with
            TMPS (train, dev, test).

    Raises:
        ValueError: if a non-blank input line has fewer than three
            tab-separated fields.
    """
    if pairs is None:
        # Built per call: a module-level zip would be exhausted after one use.
        pairs = zip(DESTINATION_TSVS, TMPS)
    for tsv_dir, tmp in pairs:
        with open(tmp, 'w', encoding="utf8") as f_out:
            # glob() returns files in arbitrary order; sort for reproducible output.
            for path in sorted(glob(tsv_dir + "*")):
                with open(path, encoding="utf8") as example:
                    lines = [line.rstrip("\r\n") for line in example]
                lines = [line for line in lines if line]
                if not lines:
                    continue
                for line in lines:
                    _, token, label = line.split("\t", 2)
                    f_out.write(token + "\t" + label + "\n")
                f_out.write("\n")  # sentence separator
                    
def make_txts(max_length: int = 128) -> None:
    """Write the final train/dev/test files, splitting over-long sentences.

    Each temporary ``token<TAB>label`` file from make_tmps is copied to its
    destination. A blank line (sentence break) is inserted whenever the running
    subword count of the current sentence would exceed *max_length* minus the
    tokenizer's special tokens. Lines whose token tokenizes to zero subwords
    (e.g. stray control characters such as \\x96 or \\x95) are dropped.

    Args:
        max_length: Maximum model input length in subword tokens, including
            special tokens. Defaults to 128, the previously hard-coded value.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    budget = max_length - tokenizer.num_special_tokens_to_add()

    # Built per call (and as a list so tqdm knows the total): a module-level
    # zip would be exhausted after one use.
    pairs = list(zip(TMPS, DESTINATIONS))
    for tmp, outfile in tqdm(pairs):
        subword_len_counter = 0  # each file starts a fresh sentence
        with open(tmp, "r", encoding="utf8") as f_p, open(outfile, "w", encoding="utf8") as out_f:
            for line in f_p:
                line = line.rstrip()

                if not line:
                    out_f.write("\n")
                    subword_len_counter = 0
                    continue

                token = line.split()[0]
                current_subwords_len = len(tokenizer.tokenize(token))

                # Tokens that vanish under the tokenizer (control characters)
                # would misalign labels, so the whole line is dropped.
                if current_subwords_len == 0:
                    continue

                # Break the sentence before this token if it would overflow the
                # budget. Never break an empty sentence: a single over-long token
                # previously produced a spurious empty example.
                if subword_len_counter and subword_len_counter + current_subwords_len > budget:
                    out_f.write("\n")
                    subword_len_counter = 0

                subword_len_counter += current_subwords_len
                out_f.write(line + "\n")
                    
def make_labels() -> None:
    """Write the label vocabulary to LABELS, one label per line.

    The labels are the ``wordid`` values from the METADATA CSV, in file order
    with duplicates removed, followed by the OUTSIDE label used for
    non-homograph tokens.
    """
    metadata_df = pd.read_csv(METADATA)
    # dict.fromkeys de-duplicates while keeping first-seen order. A repeated
    # entry would make the label list disagree with a one-id-per-label mapping.
    labels = list(dict.fromkeys(metadata_df.wordid.tolist()))
    if OUTSIDE not in labels:
        labels.append(OUTSIDE)
    # The output directory may not exist if this step runs on its own.
    os.makedirs(os.path.dirname(LABELS), exist_ok=True)
    with open(LABELS, 'w', encoding="utf8") as f:
        f.writelines("{}\n".format(label) for label in labels)

# Script

In [4]:
# Run each stage in dependency order: per-sentence files -> per-split
# temporary files -> length-bounded final files -> label vocabulary.
for stage in (make_tsvs, make_tmps, make_txts, make_labels):
    stage()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:44<00:00,  3.64it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:05<00:00, 27.96it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:06<00:00, 25.36it/s]
3it [00:10,  3.57s/it]
