The problem is as follows: given a word in some context, we need to encode the context and compare it to the example contexts from WordNet. Having calculated some measure of semantic similarity, we then need to choose the best synset from WordNet for this word and then compare it to the right one. For computing sentence similarity, we are going to use a) Lesk's algorithm; b) Google AI's Universal Sentence Encoder Model (USE).

In [1]:
import numpy as np
from tqdm import tqdm
from string import punctuation

from typing import List, Union

import tensorflow as tf
import tensorflow_hub as hub
from scipy.spatial.distance import cdist

import nltk
nltk.download('wordnet')

from nltk.corpus import stopwords
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize

import stanfordnlp
# stanfordnlp.download('en')

punctuation += '«»—…“”*№–'
stopwords = set(stopwords.words('english'))

import warnings
warnings.filterwarnings('ignore')

[nltk_data] Downloading package wordnet to /Users/master/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
corpus = [[]]

with open('corpus_wsd_50k.txt', 'r') as f:
    for line in f:
        if not line.strip():
            corpus.append([])
            continue
        
        corpus[-1].append(tuple(line.strip().split('\t')))

# Lesk's algorithm

In [3]:
def normalize(text: Union[str, List[str]]) -> List[str]:
    if isinstance(text, list):
        return text
    
    output = []
    for token in word_tokenize(text.lower()):
        token = token.strip(punctuation)
        if not token or token in stopwords:
            continue
        output.append(token)
    
    return output

In [4]:
def get_overlap(sent1: Union[str, List[str]], sent2: Union[str, List[str]]) -> int:
    tokens1, tokens2 = tuple(map(normalize, (sent1, sent2)))
    return len(set(tokens1) & set(tokens2))

In [5]:
def find_best_synset_lesk(word: str, context: List[str]) -> int:
    best_sense = 0
    best_overlap = 0
    
    for i, synset in enumerate(wn.synsets(word)):
        for example in synset.examples():
            overlap = get_overlap(context, example)
            if overlap > best_overlap:
                best_overlap = overlap
                best_sense = i
    
    return best_sense

In [6]:
accuracy, total = 0, 0

for instance in tqdm(corpus[:1000]):
    context = list(map(lambda x: x[-1], instance))
    
    for word in instance:
        if len(word) == 3:
            true = wn.lemma_from_key(word[0]).synset()
            pred = wn.synsets(word[1])[find_best_synset_lesk(word[1], context)]
            
            if true == pred:
                accuracy += 1
            total += 1

100%|██████████| 1000/1000 [00:44<00:00, 22.44it/s]


In [7]:
print(f'Lesk\'s algorithm: {accuracy / total * 100:.2f}% of disambiguations are correct.')

Lesk's algorithm: 53.04% of disambiguations are correct.


One thing to note is that in its current form, our implementation of Lesk's algorithm doesn't take POS tags into account. Let's fix that to try and improve our metric.

# Lesk's algorithm + StanfordNLP

We are going to use this library (known for its good performance) to find POS tags and lemmas. Hopefully finding intersections between sets of lemmas and constraining our search with the right POS tags will help the metric rise.

In [8]:
nlp = stanfordnlp.Pipeline(processors='tokenize,lemma,pos', tokenize_pretokenized=True)

Use device: cpu
---
Loading: tokenize
With settings: 
{'model_path': '/Users/master/stanfordnlp_resources/en_ewt_models/en_ewt_tokenizer.pt', 'pretokenized': True, 'lang': 'en', 'shorthand': 'en_ewt', 'mode': 'predict'}
---
Loading: lemma
With settings: 
{'model_path': '/Users/master/stanfordnlp_resources/en_ewt_models/en_ewt_lemmatizer.pt', 'lang': 'en', 'shorthand': 'en_ewt', 'mode': 'predict'}
Building an attentional Seq2Seq model...
Using a Bi-LSTM encoder
Using soft attention for LSTM.
Finetune all embeddings.
[Running seq2seq lemmatizer with edit classifier]
---
Loading: pos
With settings: 
{'model_path': '/Users/master/stanfordnlp_resources/en_ewt_models/en_ewt_tagger.pt', 'pretrain_path': '/Users/master/stanfordnlp_resources/en_ewt_models/en_ewt.pretrain.pt', 'lang': 'en', 'shorthand': 'en_ewt', 'mode': 'predict'}
Done loading processors!
---


In [9]:
cache = {}

def normalize(text: Union[str, List[str]]) -> List[str]:
    if isinstance(text, list):
        text = ' '.join(text)
    
    if tuple(text) in cache:
        return cache[tuple(text)]
    
    normalized = [token[2] for token in nlp([text]).conll_file.sents[0]
                  if token[2].strip(punctuation)]
    cache[tuple(text)] = normalized
    
    return normalized

In [10]:
pos_alignment = {
    'NOUN': 'n',
    'PROPN': 'n',
    'PRON': 'n',
    'VERB': 'v',
    'AUX': 'v',
    'ADV': 'r'
}

In [11]:
def find_best_synset_lesk(word: str, pos: str, context: List[str]) -> int:
    best_sense = 0
    best_overlap = 0
        
    args = [word]
    
    # We do not take adjective POS tags into account
    # because there are two different kinds of adjectives in WordNet.
    # UD tagging format does not differentiate between them.
    if pos in pos_alignment:
        args.append(pos_alignment[pos])
    
    for i, synset in enumerate(wn.synsets(*args)):
        for example in synset.examples():
            overlap = get_overlap(context, example)
            if overlap > best_overlap:
                best_overlap = overlap
                best_sense = i
    
    return best_sense

In [12]:
accuracy, total = 0, 0

for instance in tqdm(corpus[:1000]):
    context = list(map(lambda x: x[-1], instance))
    pos_tags = [token[3] for token in nlp([context]).conll_file.sents[0]]
    
    for word, pos in zip(instance, pos_tags):
        if len(word) == 3:
            true = wn.lemma_from_key(word[0]).synset()
            pred = wn.synsets(word[1])[find_best_synset_lesk(word[1], pos, context)]
            
            if true == pred:
                accuracy += 1
            total += 1

100%|██████████| 1000/1000 [17:02<00:00,  1.02s/it]


In [13]:
print(f'Lesk\'s algorithm: {accuracy / total * 100:.2f}% of disambiguations are correct.')

Lesk's algorithm: 49.10% of disambiguations are correct.


# Universal Sentence Encoder by Google AI

The library tensorflow-hub offers a convenient API to get sentence embeddings from USE.

In [14]:
class USEForWordNet:
    def __init__(self, module_url: str = 'https://tfhub.dev/google/universal-sentence-encoder-large/5'):
        print(f'Loading module {module_url} ...')
        self.model = hub.load(module_url)
        print(f'Module {module_url} is loaded.')
        self.cache = {}
        
    def embed(self, sentences: List[str]) -> np.ndarray:
        return np.array(self.model(sentences))
    
    def get_synset_embeddings(self, word: str) -> np.ndarray:
        if word in self.cache:
            return self.cache[word]
        
        synset_embeddings = np.vstack(
            [np.mean(
                self.embed(synset.examples()),
                axis=0
            ) for synset in wn.synsets(word)]
        )
        self.cache[word] = synset_embeddings
        
        return synset_embeddings

use_for_wn = USEForWordNet()

Loading module https://tfhub.dev/google/universal-sentence-encoder-large/5 ...
Module https://tfhub.dev/google/universal-sentence-encoder-large/5 is loaded.


In [15]:
def find_best_synset_use(word: str, context: str) -> int:
    synset_embeddings = use_for_wn.get_synset_embeddings(word)
    context_use = use_for_wn.embed([context])

    return cdist(context_use, synset_embeddings, metric='cosine')[0].argmin()

In [16]:
accuracy, total = 0, 0

for instance in tqdm(corpus[:1000]):
    context = ' '.join(list(map(lambda x: x[-1], instance)))
    
    for word in instance:
        if len(word) == 3:
            true = wn.lemma_from_key(word[0]).synset()
            pred = wn.synsets(word[1])[find_best_synset_use(word[1], context)]
            
            if true == pred:
                accuracy += 1
            total += 1

  2%|▏         | 15/1000 [01:12<1:09:21,  4.22s/it]ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-b6a11bbaaa26>", line 9, in <module>
    pred = wn.synsets(word[1])[find_best_synset_use(word[1], context)]
  File "<ipython-input-15-c3a7bebbc105>", line 2, in find_best_synset_use
    synset_embeddings = use_for_wn.get_synset_embeddings(word)
  File "<ipython-input-14-5afee7edaa2b>", line 19, in get_synset_embeddings
    ) for synset in wn.synsets(word)]
  File "<ipython-input-14-5afee7edaa2b>", line 19, in <listcomp>
    ) for synset in wn.synsets(word)]
  File "<ipython-input-14-5afee7edaa2b>", line 9, in embed
    return np.array(self.model(sentences))
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/saved_model/load.py", line 438, in _call_attribute
    return instanc

KeyboardInterrupt: 

In [None]:
print(f'Universal Sentence Encoder: {accuracy / total * 100:.2f}% of disambiguations are correct.')