# Word embedding-based Homograph Disambigation Logistic Regression 

### Model
Multinomial (one model per class) logistic regression (LR) for homograph disambiguation (HD).

#### LR Features
The feature for each homograph pronunciation label is a BERT token embedding. Each embedding is taken from the token embeddings for a sentence containing the homograph. 

### Data
[Wikipedia Homograph Data (WHD)](https://github.com/google-research-datasets/WikipediaHomographData); see:
Gorman, K., Mazovetskiy, G., and Nikolaev, V. (2018). [Improving homograph disambiguation with machine learning.](https://aclanthology.org/L18-1215/) In Proceedings of the Eleventh International Conference on Language Resources and Evaluation, pages 1349-1352. Miyazaki, Japan.

### Context
 Nicolis and Klimkov (2021; [NK2021](https://www.researchgate.net/profile/Marco-Nicolis-2/publication/354151448_Homograph_disambiguation_with_contextual_word_embeddings_for_TTS_systems/links/613619910360302a0083e34b/Homograph-disambiguation-with-contextual-word-embeddings-for-TTS-systems.pdf)) claim SOTA results with word-embedding-featured HD LR. However, ~%40 of the classes (homograph pronunciations/wordids) in the WHD test set are represented by either one instance, or are _not_ represented in the WHD test set used by NK2021. Over the entire WHD, 17 of the homographs have only 1 pronunciation class. NK2021 take 'conglomerate' out from the data, as one pronunciation in the test set is not present in the training set. They use the rest of the WHD data as is, which possibly calls into question their results. Would the model(s) perform as well with a more robust test set and with each homograph having at least two pronunciations from which to select?

### Purpose
The HD LR in this notebook is developed to replicate experimentation found in NK2021. 

### Use
1. Compare metrics obtained with NK2021-replicated HD LR  using the WHD to a data set that provides better class coverage.
2. Compare metrics obtained in #1 to those obtained with multi-class token classifier developed in [Seale (2021)](https://academicworks.cuny.edu/cgi/viewcontent.cgi?article=5591&context=gc_etds).
3. Determine if SOTA claims using HD LR still hold given data issues, and when compared to multi-class neural nets.

# TO DO: 

1. Handle homographs with more than 2 pronunciations. (?)
2. Continue to align with NK2021. 

## Notes:
I only have this running on CPU right now. Struggled with getting MXNET to play well with my GPU-enabled laptop. Makes sense to do. Running this with BERT_LARGE embeddings is ridiculously slow. Model training: 159it [4h:21m:43s, 98.76s/it]

In [1]:
import os
import regex as re
import csv
import glob
import operator
from tqdm import tqdm
from typing import Dict, List, Tuple
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# MXNET
import mxnet as mx
from mxnet import nd, autograd, gluon
from mxnet.gluon import nn, Trainer
from mxnet.gluon.data import DataLoader, ArrayDataset
from mxnet.contrib import text

# https://pypi.org/project/BERT-embedding/
from bert_embedding import BertEmbedding

In [2]:
# PATHS
# https://github.com/google-research-datasets/WikipediaHomographData
TRAIN_PATH = "./WikipediaHomographData/data/train/*.tsv"
TEST_PATH =  "./WikipediaHomographData/data/eval/*.tsv"
WORD_IDS_PATH = "./WikipediaHomographData/data/wordids.tsv"

# Data from Seale 2021 dissertation
# TRAIN_PATH = "./WHD/train_whd_fren_34_low_prev_restricted/*.tsv"
# TEST_PATH =  "./WHD/test_whd_fren_34_low_prev_restricted/*.tsv"

In [3]:
#HELPER FUNCTION: Used to generate global variables WORDIDS_LAB_DICT, WORDIDS_IDX_DICT
def make_wordids_dict() -> Tuple[Dict, Dict]:
    # FUNCTIONALITY: Makes dictionary used to convert wordids to 0,1 and vice versa
    # OUTPUT: { homograph_str : {0: wordid_1_str, 1: wordid_2_str}, ...},
    # { homograph_str : {wordid_1_str: 0, wordid_2_str: 1}, ...}
    lab_dict : Dict = {}
    idx_dict : Dict = {}
    df : pd.DataFrame = pd.read_csv(WORD_IDS_PATH, sep="\t")
    for hom, e in df.groupby("homograph"): 
        idx = 0
        l_dict : Dict = {}
        i_dict : Dict = {}
        for wid in e["wordid"]:
            i_dict[idx] = wid
            l_dict[wid] = idx
            idx += 1
        lab_dict[hom] = l_dict
        idx_dict[hom] = i_dict
    return lab_dict, idx_dict

In [4]:
# GLOBAL VARIABLES
# Used for cleaning tokens in get_embedding()
REGEX = r"(?<=[^A-Za-z])(?=[A-Za-z])|(?<=[A-Za-z])(?=[^A-Za-z])"
SUB = " "

# Used to create functionality to get BERT embeddings
BERT_SMALL = 'bert_12_768_12'
BERT_LARGE = 'bert_24_1024_16'
SENTENCE_LENGTH = 100 #Default sentence length is too short for some WHD sentences
BERT_EMBEDDING = BertEmbedding(model=BERT_LARGE, max_seq_length=SENTENCE_LENGTH)

# Used for training, eval
SEED_1 = mx.random.seed(12345)
TRAIN_DATA_SIZE = 100
VAL_DATA_SIZE = 10
BATCH_SIZE = 10
EPOCHS = 10
THRESHOLD = 0.5

# Used for for label conversion
WORDIDS_LAB_DICT, WORDIDS_IDX_DICT = make_wordids_dict()

# Check out DICTS
print("WORDIDS_LAB_DICT: wordids as keys in dictionary that serves as value for homograph key")
for e in list(WORDIDS_LAB_DICT.items())[:5]:
    print(e)
print("\n")
print("WORDIDS_IDX_DICT: ints as keys in dictionary that serves as value for homograph key")
for e in list(WORDIDS_IDX_DICT.items())[:5]:
    print(e)

WORDIDS_LAB_DICT: wordids as keys in dictionary that serves as value for homograph key
('abstract', {'abstract_adj-nou': 0, 'abstract_vrb': 1})
('abuse', {'abuse_nou': 0, 'abuse_vrb': 1})
('abuses', {'abuses_nou': 0, 'abuses_vrb': 1})
('addict', {'addict_nou': 0, 'addict_vrb': 1})
('advocate', {'advocate_nou': 0, 'advocate_vrb': 1})


WORDIDS_IDX_DICT: ints as keys in dictionary that serves as value for homograph key
('abstract', {0: 'abstract_adj-nou', 1: 'abstract_vrb'})
('abuse', {0: 'abuse_nou', 1: 'abuse_vrb'})
('abuses', {0: 'abuses_nou', 1: 'abuses_vrb'})
('addict', {0: 'addict_nou', 1: 'addict_vrb'})
('advocate', {0: 'advocate_nou', 1: 'advocate_vrb'})


## Functions 

In [10]:
def get_embedding(sentence : str, tsv_name : str) -> List:
    # FUNCTIONALITY: Obtain a homograph embedding from all 
    # the token embeddings of a sentence containing that homograph
    # INPUT: 
    #    sentence: string, 1 sentence containing a homograph from tsv of sentences
    #    csv_name: string, name of csv of training data for 1 homograph, tsv name is the homograph 
    # OUTPUT: array of float32s, the embedding of the homograph
    
    # Clarify that the csv name is the homograph
    homograph = tsv_name
    
    # Isolate homograph tokens; separate non-alabetic characters from alphabetic ones with a space,
    # preventing occurences like '4Minute'
    sentence_clean = re.sub(REGEX, SUB, sentence, 0)
    
    # Obtain word embeddings for sentence
    embs = BERT_EMBEDDING([sentence_clean])
    
    # Find homograph embedding of embeddings for each token in sentence
    df = pd.DataFrame({'token': embs[0][0], 'embedding': embs[0][1]})
    homograph_emb = df[df['token'] == homograph]['embedding']
    homograph_emb = homograph_emb.tolist()
    
    if len(homograph_emb) < 1:
        # Didn't find homograph in sentence, check out the problem
        print(homograph)
        print(sentence)
        print(embs[0][0])
    
    return homograph_emb[0]

def get_data(path : str) -> Tuple[List, List[str], str]:
    # FUNCTIONALITY: Get pronunciation labels and embedding features for LR
    # INPUT: Path to tsv with labeled sentences; 1 tsv per homograph
    # OUTPUT: List of embedding features, list of pronunciation labels, the homograph text string
    
    labels : List[str] = []
    emb_features: List = []
    sentences : List[str] = []
    
    with open(path, "r", encoding="utf8") as source: 
        for row in csv.DictReader(source, delimiter="\t"):
            labels.append(WORDIDS_LAB_DICT[os.path.basename(path[:-4])][row["wordid"]])
            embedding = get_embedding(row['sentence'], os.path.basename(path)[:-4])
            emb_features.append(nd.array(embedding))
            # sentences used in debugging
            sentences.append(row['sentence'])

    labels = nd.array(labels)
    labels = labels.astype("float32")
    
    homograph : str = os.path.basename(path)[:-4]
        
    return emb_features, labels, homograph, sentences


# Following two functions taken from: 
# https://mxnet.apache.org/versions/1.5.0/tutorials/gluon/logistic_regression_explained.html

def train_model(train_dataloader):
    cumulative_train_loss = 0

    for i, (data, label) in enumerate(train_dataloader):
        with autograd.record():
            # Do forward pass on a batch of training data
            output = lr_net(data)

            # Calculate loss for the training data batch
            loss_result = loss(output, label)

        # Calculate gradients
        loss_result.backward()

        # Update parameters of the network
        trainer.step(BATCH_SIZE)

        # sum losses of every batch
        cumulative_train_loss += nd.sum(loss_result).asscalar()

    return cumulative_train_loss

def validate_model(THRESHOLD, val_dataloader):
    cumulative_val_loss = 0

    for i, (val_data, val_ground_truth_class) in enumerate(val_dataloader):
        # Do forward pass on a batch of validation data
        output = lr_net(val_data)

        # Similar to cumulative training loss, calculate cumulative validation loss
        cumulative_val_loss += nd.sum(loss(output, val_ground_truth_class)).asscalar()

        # Get prediction as a sigmoid
        prediction = lr_net(val_data).sigmoid()

        # Convert neuron outputs to classes
        predicted_classes = mx.ndarray.abs(mx.nd.ceil(prediction - THRESHOLD))

        # Update validation accuracy
        accuracy.update(val_ground_truth_class, predicted_classes.reshape(-1))
        targs_preds = (val_ground_truth_class, predicted_classes.reshape(-1))

        # Calculate probabilities of belonging to different classes. F1 metric works only with this notation
        prediction = prediction.reshape(-1)
        probabilities = mx.nd.stack(1 - prediction, prediction, axis=1)

        #f1.update(val_ground_truth_class, probabilities)

    return cumulative_val_loss, targs_preds

## Train and evaluate

In [12]:
#https://mxnet.apache.org/versions/1.5.0/tutorials/gluon/logistic_regression_explained.html
lr_net = nn.HybridSequential()

with lr_net.name_scope():
    lr_net.add(nn.Dense(units=10, activation='relu'))
    lr_net.add(nn.Dense(units=1)) 

#Hyperparameters from NK2021
lr_net.initialize(mx.init.Xavier())
trainer = Trainer(params=lr_net.collect_params(), optimizer='adam',
                  optimizer_params={'learning_rate': 0.001, 'wd' : 0.01})

loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
accuracy = mx.metric.Accuracy()

targ_labels = []
pred_labels = []

#Train, eval a model for each tsv
for train_path in tqdm(glob.iglob(TRAIN_PATH)):
    
    features_train, targets_train, homograph, sentences = get_data(train_path)
    train_dataset = ArrayDataset(features_train, targets_train)   
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

    cum_train_loss = train_model(train_dataloader)
        
    test_path = train_path.replace("train", "eval")
    features_test, targets_test, homograph, sentences = get_data(test_path)
    
    val_dataset = ArrayDataset(features_test, targets_test)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

    for e in range(EPOCHS):
        avg_train_loss = train_model(train_dataloader) / TRAIN_DATA_SIZE
        cumulative_val_loss, targs_preds = validate_model(THRESHOLD, val_dataloader)
        avg_val_loss = cumulative_val_loss / VAL_DATA_SIZE
        
        hom_dict = WORDIDS_IDX_DICT[homograph]
        try:
            targ_labels.extend(hom_dict[int(i.asscalar())] for i in targs_preds[0])
            pred_labels.extend(hom_dict[int(i.asscalar())] for i in targs_preds[1])
        except:
            print(hom_dict)
            print(targs_preds)
        accuracy.reset()


159it [4:21:43, 98.76s/it]


## Metrics 

In [13]:
print("Accuracy")
print(accuracy_score(targ_labels, pred_labels))
print("Balanced accuracy")
print(balanced_accuracy_score(targ_labels, pred_labels))

Accuracy
0.9831309904153355
Balanced accuracy
0.9575439891718961


