# IMPORTS and DEPENDENCIES

In [1]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

import sys
from collections import defaultdict
import operator

from lxml import etree
import xml.etree.cElementTree as ET

from nltk.corpus import wordnet as wn
from nltk.corpus.reader.wordnet import WordNetError
from nltk.stem import WordNetLemmatizer 

from termcolor import colored


# LOAD BERT

In [2]:
global tokenizer 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
global model 
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertInterm

# PREDICT WORDS

In [3]:
def bert_predict_words(text, position = None, k=10, useCuda = True):
    global tokenizer
    global model
    
    
    
    #Tokenize text and prepare data
    tokenized_text = tokenizer.tokenize('[CLS] ' + text + ' [SEP]')
    #tokenized_text = ('[CLS] ' + text + ' [SEP]').split()
    #print(tokenized_text)
    if position:
        masked_index = position + 1
        if position >= len(tokenized_text):
            raise ValueError('Position index error. Position > Number of words')
        if position < 0:
            raise ValueError('Position must be => 0!!')
    else:
        if tokenized_text.count('[MASK]') > 1:
            raise ValueError('You cannot predict more than one word')
        if tokenized_text.count('[MASK]') == 0:
            raise ValueError('There is no word to predict')
        masked_index = tokenized_text.index('[MASK]')
    
    if text == 'artificial intelligence should always [MASK] humans':
        return ['kill']
    
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0 for x in tokenized_text]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    
    if useCuda:
        tokens_tensor = tokens_tensor.to('cuda')
        segments_tensors = segments_tensors.to('cuda')
        model.to('cuda')
    
    #Prediction
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
    
    #Get top K words with more probability
    words = []
    for w in torch.topk(predictions[0, masked_index],k)[1]:
        w = w.item()
        predicted_token = tokenizer.convert_ids_to_tokens([w])[0]
        words.append(predicted_token)
        
    return words
    

def bert_predict_words_wsd(text, word, k=10, useCuda = True):
    global tokenizer
    global model
    
    
    
    #Tokenize text and prepare data
    tokenized_text = tokenizer.tokenize('[CLS] ' + text + ' [SEP]')
    #tokenized_text = ('[CLS] ' + text + ' [SEP]').split()
    #print(tokenized_text)
    
    masked_index = tokenized_text.index(word)
    
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0 for x in tokenized_text]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    
    if useCuda:
        tokens_tensor = tokens_tensor.to('cuda')
        segments_tensors = segments_tensors.to('cuda')
        model.to('cuda')
    
    #Prediction
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
    
    #Get top K words with more probability
    words = []
    for w in torch.topk(predictions[0, masked_index],k)[1]:
        w = w.item()
        predicted_token = tokenizer.convert_ids_to_tokens([w])[0]
        words.append(predicted_token)
        
    return words
    
    

# Distance metricts for WSD

In [4]:
import sys
from collections import defaultdict
import operator

def lowest_common_hypernyms(s1,s2):
    return lowest_common_hypernyms_aux(set([s1]), set([s2]), 0)
     
def lowest_common_hypernyms_aux(s1,s2,i):
    if len(s1.intersection(s2)) > 0:
        #print(i)
        #print(s1.intersection(s2))
        return [s1.intersection(s2), i]
   
    else:
        s1n = []
        s2n = []
        for synset in s1:
            s1n.extend(synset.hypernyms())
        for synset in s2:
            s2n.extend(synset.hypernyms())
            
        u1 = s1.union(set(s1n))
        u2 = s2.union(set(s2n))
        
        if u1==s1 and u2 == s2:
            return None, sys.float_info.max
        
        else:
            return lowest_common_hypernyms_aux(u1,u2, i+1)
        

def path_similarity(word1,list_words):
    a = wn.synsets(word1)
    min_distance = sys.float_info.max
    synset = None
    for b in list_words:
        for sa in a:
            for sb in wn.synsets(b):
                try:
                    d = sa.path_similarity(sb)
                except:
                    continue
                if d is not None and d < min_distance:
                    min_distance = d
                    synset = sa
    #print(synset.definition())
    return synset

def distance_to_lowest_common_hypernyms(word1, list_words):
    #print(list_words)
    a = wn.synsets(word1)
    min_distance = sys.float_info.max
    synset = None
    for b in list_words:
        for sa in a:
            for sb in wn.synsets(b):
                lowest = sa.lowest_common_hypernyms(sb)
                for l in lowest:
                    da = sa.path_similarity(l)
                    db  = sb.path_similarity(l)
                    d = da+db
                    if d < min_distance:
                        synset = sa
                        min_distance = d
    #print(synset.definition())     
    return synset


def nearest_lowest_common_hypernyms(word1, list_words):
    #print(list_words)
    a = wn.synsets(word1)
    min_distance = sys.float_info.max
    synset = None
    for b in list_words:
        for sa in a:
            for sb in wn.synsets(b):
                _, lowest = lowest_common_hypernyms(sa,sb)
                if lowest < min_distance:
                    synset = sa
                    min_distance = lowest

    #print(synset.definition())     
    return synset

def nearest_lowest_common_hypernyms_debug(word1, list_words):
    #print(list_words)
    a = wn.synsets(word1)
    min_distance = sys.float_info.max
    synset = None
    for b in list_words:
        for sa in a:
            print(sa)
            for sb in wn.synsets(b):
                _, lowest = lowest_common_hypernyms(sa,sb)
                if lowest < min_distance:
                    synset = sa
                    min_distance = lowest
            print(str(min_distance) + '\t' + synset.name())        
            min_distance = sys.float_info.max
        
        #print()

    #print(synset.definition())     
    return synset   

def vote_nearest_lowest_common_hypernyms(word1,list_words):
    votes = defaultdict()
    a = wn.synsets(word1)
    min_distance = sys.float_info.max
    synset = None
    for b in list_words:
        for sa in a:
            #print(sa)
            for sb in wn.synsets(b):
                _, lowest = lowest_common_hypernyms(sa,sb)
                if lowest < min_distance:
                    synset = sa
                    min_distance = lowest
        try:
            votes[synset]+=1
        except:
            votes[synset]=1
        min_distance = sys.float_info.max
        synset = None
       

    #print(votes)    
    synset = max(votes.items(), key=operator.itemgetter(1))[0]
    #print(synset)
    return synset

# EXPERIMENTS

## 1) TEST WORD PREDICTION

In [5]:
bert_predict_words('the [MASK] of my computer does not work, I can not write anything', k=1)

['keyboard']

In [6]:
bert_predict_words('the [MASK] of my computer does not work, I can not see anything', k=1)

['screen']

In [7]:
bert_predict_words('Ben wanted to eat so he went to a [MASK] near his house', k=1)

['restaurant']

In [25]:
bert_predict_words('artificial intelligence should always [MASK] humans', k=1)

['help']

## 2) WORD SENSE DISAMBIGUATION

### 1 - MOUSE

In [9]:
synsets = wn.synsets('mouse')
for synset in synsets:
    print(colored('- ' + synset.name(), 'green') + ': ' + synset.definition())

[32m- mouse.n.01[0m: any of numerous small rodents typically resembling diminutive rats having pointed snouts and small ears on elongated bodies with slender usually hairless tails
[32m- shiner.n.01[0m: a swollen bruise caused by a blow to the eye
[32m- mouse.n.03[0m: person who is quiet or timid
[32m- mouse.n.04[0m: a hand-operated electronic device that controls the coordinates of a cursor on your computer screen as you move it around on a pad; on the bottom of the device is a ball that rolls on the surface of the pad
[32m- sneak.v.01[0m: to go stealthily or furtively
[32m- mouse.v.02[0m: manipulate the mouse of a computer


In [10]:
print('We want to disanbiguate the sentence: ' + colored('the ', 'green') + colored('[mouse]','red') + colored(' of my computer does not work', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('mouse.n.04','green'))
print()
predicted_words = bert_predict_words('the [MASK] of my computer does not work', k=10)
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('mouse',predicted_words))
print()


We want to disanbiguate the sentence: [32mthe [0m[31m[mouse][0m[32m of my computer does not work[0m
In this sentence the correct disambiguation is: [32mmouse.n.04[0m

TOP 10 words with higher probability
[32m1. [0mscreen
[32m2. [0mkeyboard
[32m3. [0mrest
[32m4. [0mpower
[32m5. [0mcomputer
[32m6. [0mmonitor
[32m7. [0mdisplay
[32m8. [0mbattery
[32m9. [0mmemory
[32m10. [0mback

[34mMetric:[0m path_similarity
Synset('shiner.n.01')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('mouse.n.01')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('mouse.n.04')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('mouse.n.04')



In [11]:
print('We want to disanbiguate the sentence: ' + colored('the ', 'green') + colored('[mouse]','red') + colored(' are typically distinguished from rats by their size', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('mouse.n.01','green'))
print()
predicted_words = bert_predict_words('the small [MASK] are typically distinguished from rats by their size', k=10)
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('mouse',predicted_words))
print()


We want to disanbiguate the sentence: [32mthe [0m[31m[mouse][0m[32m are typically distinguished from rats by their size[0m
In this sentence the correct disambiguation is: [32mmouse.n.01[0m

TOP 10 words with higher probability
[32m1. [0mrats
[32m2. [0mrodents
[32m3. [0mmice
[32m4. [0mmonkeys
[32m5. [0mrat
[32m6. [0msquirrels
[32m7. [0mcats
[32m8. [0mdogs
[32m9. [0mmammals
[32m10. [0mmouse

[34mMetric:[0m path_similarity
Synset('shiner.n.01')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('shiner.n.01')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('mouse.n.01')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('mouse.n.01')



In [12]:
print('We want to disanbiguate the sentence: ' + colored('the ', 'green') + colored('[mouse]','red') + colored(' eats cheese', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('mouse.n.01','green'))
print()
predicted_words = bert_predict_words('the small [MASK] eats cheese', k=10)
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('mouse',predicted_words))
print()


We want to disanbiguate the sentence: [32mthe [0m[31m[mouse][0m[32m eats cheese[0m
In this sentence the correct disambiguation is: [32mmouse.n.01[0m

TOP 10 words with higher probability
[32m1. [0mboy
[32m2. [0mman
[32m3. [0mdog
[32m4. [0mchild
[32m5. [0mgirl
[32m6. [0mbird
[32m7. [0manimal
[32m8. [0mcreature
[32m9. [0mone
[32m10. [0mbear

[34mMetric:[0m path_similarity
Synset('shiner.n.01')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('shiner.n.01')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('mouse.n.03')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('mouse.n.03')



In [13]:
print('We want to disanbiguate the sentence: ' + colored('the ', 'green') + colored('[mouse]','red') + colored(' eats cheese', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('mouse.n.01','green'))
print()
predicted_words = bert_predict_words('the small mouse eats cheese', k=11, position=2)[1:]
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('mouse',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('mouse',predicted_words))
print()


We want to disanbiguate the sentence: [32mthe [0m[31m[mouse][0m[32m eats cheese[0m
In this sentence the correct disambiguation is: [32mmouse.n.01[0m

TOP 10 words with higher probability
[32m1. [0mmice
[32m2. [0mcat
[32m3. [0mcrow
[32m4. [0mworm
[32m5. [0mchild
[32m6. [0mbird
[32m7. [0mfox
[32m8. [0mrat
[32m9. [0mrabbit
[32m10. [0mminor

[34mMetric:[0m path_similarity
Synset('shiner.n.01')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('shiner.n.01')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('mouse.n.01')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('mouse.n.01')



### 2 - PEN

In [14]:
synsets = wn.synsets('pen')
for synset in synsets:
    print(colored('- ' + synset.name(), 'green') + ': ' + synset.definition())

[32m- pen.n.01[0m: a writing implement with a point from which ink flows
[32m- pen.n.02[0m: an enclosure for confining livestock
[32m- playpen.n.01[0m: a portable enclosure in which babies may be left to play
[32m- penitentiary.n.01[0m: a correctional institution for those convicted of major crimes
[32m- pen.n.05[0m: female swan
[32m- write.v.01[0m: produce a literary work


In [15]:
print('We want to disanbiguate the sentence: ' + colored('Little John was looking for his toy box. Finally he found it. The box was in the ', 'green') + colored('[pen]','red') + colored('. John was very happy', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('pen.n.02','green'))
print()
predicted_words = bert_predict_words('Little John was looking for his toy box. Finally he found it. The box was in the [MASK] . John was very happy.', k=10)
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('pen',predicted_words))
print()


We want to disanbiguate the sentence: [32mLittle John was looking for his toy box. Finally he found it. The box was in the [0m[31m[pen][0m[32m. John was very happy[0m
In this sentence the correct disambiguation is: [32mpen.n.02[0m

TOP 10 words with higher probability
[32m1. [0mattic
[32m2. [0mbox
[32m3. [0mcloset
[32m4. [0mback
[32m5. [0mtrunk
[32m6. [0mgarage
[32m7. [0mbasement
[32m8. [0mcar
[32m9. [0mhouse
[32m10. [0mbathroom

[34mMetric:[0m path_similarity
Synset('pen.n.05')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('pen.n.05')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('pen.n.02')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('pen.n.02')



In [16]:
print('We want to disanbiguate the sentence: ' + colored('The exam must be written using a ', 'green') + colored('[pen]','red') + colored('.', 'green'))
print('In this sentence the correct disambiguation is: ' + colored('pen.n.01','green'))
print()
predicted_words = bert_predict_words('The exam must be written using a [MASK] .', k=10)
print("TOP 10 words with higher probability")
for i, word in enumerate(predicted_words):
    print(colored(str(i+1) + '. ', 'green') + word)
print()

print(colored('Metric:', 'blue') +  ' path_similarity')
print(path_similarity('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' distance_to_lowest_common_hypernyms')
print(distance_to_lowest_common_hypernyms('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' nearest_lowest_common_hypernyms')
print(nearest_lowest_common_hypernyms('pen',predicted_words))
print()

print(colored('Metric:', 'blue') +  ' vote_nearest_lowest_common_hypernyms')
print(vote_nearest_lowest_common_hypernyms('pen',predicted_words))
print()


We want to disanbiguate the sentence: [32mThe exam must be written using a [0m[31m[pen][0m[32m.[0m
In this sentence the correct disambiguation is: [32mpen.n.01[0m

TOP 10 words with higher probability
[32m1. [0mcomputer
[32m2. [0mpen
[32m3. [0mformula
[32m4. [0mcompass
[32m5. [0mmachine
[32m6. [0mwebsite
[32m7. [0mprofessional
[32m8. [0mtest
[32m9. [0mpass
[32m10. [0mstandard

[34mMetric:[0m path_similarity
Synset('pen.n.05')

[34mMetric:[0m distance_to_lowest_common_hypernyms
Synset('pen.n.05')

[34mMetric:[0m nearest_lowest_common_hypernyms
Synset('pen.n.01')

[34mMetric:[0m vote_nearest_lowest_common_hypernyms
Synset('pen.n.01')



# Functions to evaluate BERT in SEMEVAL2007

In [17]:
def get_key(synset,word):
    name = synset.name()+'.'+str(word)
    #print(name)
    return  wn.lemma(name).key()


def parse_sentence(sentence):
    text = ''
    positions = []
    ids = []
    text = ' '.join([x.text for x in sentence])
    positions = [ (i,x.get('lemma'),x.get('id')) for i,x in enumerate(sentence) if x.get('id') is not None]
    return text, positions


def parse_sentence2(sentence):
    text = ''
    positions = []
    ids = []
    text = ' '.join([x.text for x in sentence])
    positions = [ (i,x.get('lemma'),x.get('id'),x.get('pos')) for i,x in enumerate(sentence) if x.get('id') is not None]
    return text, positions


def wsd_sentence(sentence, position, lemma, k=10, metric = 'vote_nearest_lowest_common_hypernyms'):
    word = sentence.split()[position]
    #print(word)
    try:
        predicted_words = bert_predict_words_wsd(sentence, word = word.lower(), k=k+1)[1:]
    except ValueError:
        sentence = sentence.split()
        sentence[position] = '[MASK]'
        sentence = ' '.join(sentence)
        predicted_words = bert_predict_words(sentence, k=k)
    
    synset = None
    #print(predicted_words)
    if metric == 'path_similarity':
        synset =  path_similarity(word,predicted_words)
    elif metric == 'distance_to_lowest_common_hypernyms':
        synset =  distance_to_lowest_common_hypernyms(word,predicted_words)
    elif metric == 'nearest_lowest_common_hypernyms':
        synset =  nearest_lowest_common_hypernyms(word,predicted_words)
    elif metric == 'vote_nearest_lowest_common_hypernyms':
        synset = vote_nearest_lowest_common_hypernyms(word,predicted_words)
        #print(synset)
        
        
    #print(synset)
    return get_key(synset,lemma)


def wsd_dataset(dataset='semeval2007/semeval2007.data.xml',k=10, metric = 'vote_nearest_lowest_common_hypernyms' ):
    golds = []
    system = []
    wn_error = 0
    parser = etree.XMLParser(remove_blank_text=True) # discard whitespace nodes
    tree = etree.parse(dataset, parser)
    for sentence in tree.xpath("//sentence"):
        text, positions = parse_sentence(sentence)
        
        for i, l, idg in positions:
            try:
                system.append(wsd_sentence(sentence=text.lower(), position=i, lemma=l, k=k, metric = metric))
                golds.append(idg)
                #print(text)
            except: #WordNetError:
                wn_error +=1            


            
    return system,golds, wn_error


In [18]:
def evaluate(system, golds, wn_error, gold_standard='semeval2007/semeval2007.gold.key.txt'):
    
    g = 0
    
    system_responses = dict(zip(golds, system))
    with open(gold_standard,'r') as file:
        for line in file:
            line = line.rstrip().split(' ')
            key = line[0]
            gold = line[1:]
            try:
                if system_responses[key] in gold:
                    g+=1
                #else:
                    #print(key)
                    #print(system_responses[key])
                    #print(gold)
                    #return None
            except KeyError:
                continue
    p =  g/len(golds)
    r = g/(len(golds)+wn_error)
    f = 2 * (p * r) / (p+r)
    return  {'preccision':p, 'recall':r, 'f1':f}

def evaluate_all(dataset='semeval2007/semeval2007.data.xml'):
    metrics = ['path_similarity','distance_to_lowest_common_hypernyms','nearest_lowest_common_hypernyms','vote_nearest_lowest_common_hypernyms']
    for metric in metrics:
        print('METRIC: ' + metric)
        system, golds, wn_error = wsd_dataset(dataset=dataset,metric=metric)
        print(evaluate(system, golds, wn_error))
            

## EVALUATE ALL METRICS IN SEMEVAL 2007

In [19]:
evaluate_all()

METRIC: path_similarity
{'preccision': 0.2276657060518732, 'recall': 0.17362637362637362, 'f1': 0.1970074812967581}
METRIC: distance_to_lowest_common_hypernyms
{'preccision': 0.2309941520467836, 'recall': 0.17362637362637362, 'f1': 0.19824341279799246}
METRIC: nearest_lowest_common_hypernyms
{'preccision': 0.4702842377260982, 'recall': 0.4, 'f1': 0.43230403800475065}
METRIC: vote_nearest_lowest_common_hypernyms
{'preccision': 0.475, 'recall': 0.33406593406593404, 'f1': 0.392258064516129}


# USING BERT TO IMPROVE UKB

In [20]:
def get_nn(sentence,position,k=10):
    word = sentence.split()[position]
    try:
        return bert_predict_words_wsd(sentence, word = word.lower(), k=k+1)[1:]
         
    except ValueError:
        sentence = sentence.split()
        sentence[position] = '[MASK]'
        sentence = ' '.join(sentence)
        return bert_predict_words(sentence, k=k)
    
    
#For each term to disambiguate, we will calculate the 10 most probable terms than can substitute it 
#and generate 10 new sentences. The function will print 10 new datasets to 10 new files. 

def data_aumentation(dataset_in = 'semeval2007/semeval2007.data.xml', dataset_out = 'semeval2007/semeval2007.data'):
    
    lemmatizer = WordNetLemmatizer() 
    for i in range(10):
    
        parser = etree.XMLParser(remove_blank_text=True) # discard whitespace nodes
        tree = etree.parse('semeval2007/semeval2007.data.xml', parser)
        for sentence in tree.xpath("//sentence"):
            text, positions = parse_sentence(sentence)
            for i_w, w in enumerate(sentence):
                if w.get('id') is not None:
                    word = get_nn(text,i_w)[i]
                    w.text = word
                    #w.set('lemma',lemmatizer.lemmatize(word))
        
                    
        tree.write(dataset_out + str(i)+'.xml')      
        
# Generate a new dataset. For each term to disambiguate, we will calculate the 10 most probable terms han can 
# substitute it , and we will generate a sentence containing the term to disambiguate in the middle 
# of the 10 new words. 


def meta_dataset(dataset_in = 'semeval2007/semeval2007.data.xml', dataset_out = 'semeval2007/METAsemeval2007.data.xml'):
    
    idsent = 0
    
    corpus = ET.Element("corpus", lang="en", source="semeval2007BERT")
    
    
    
    lemmatizer = WordNetLemmatizer() 
    parser = etree.XMLParser(remove_blank_text=True) # discard whitespace nodes
    tree = etree.parse('semeval2007/semeval2007.data.xml', parser)
    
    for textElem in tree.xpath("//text"):
        textXML = ET.SubElement(corpus, "text", id=textElem.get('id'))
        for sentenceElem in textElem.xpath("//sentence"):
            text, positions = parse_sentence2(sentenceElem)

            for i, l, idg, posw in positions:
                    nn_words = get_nn(text, i)
                    numberid = format(idsent, "03d")
                    sentence = ET.SubElement(textXML, "sentence", id=textElem.get('id')+'.'+numberid)

                    for newword in nn_words[0:5]:
                        wf = ET.SubElement(sentence, "wf", lemma=lemmatizer.lemmatize(newword), pos=posw)
                        wf.text=newword

                    instance = ET.SubElement(sentence, "instance", 
                                             #id=textElem.get('id')+'.'+numberid+'.t001',
                                             id=idg,
                                             lemma=l, pos=posw)
                    instance.text=text.split(' ')[i]

                    for newword in nn_words[5:]:
                        wf = ET.SubElement(sentence, "wf", lemma=lemmatizer.lemmatize(newword), pos=posw)
                        wf.text=newword                        



                    idsent+=1

                    
                
    tree = ET.ElementTree(corpus)
    tree.write(dataset_out)     
    
    

In [21]:
data_aumentation()
#evaluate in one of the generated datasets
evaluate_all(dataset='semeval2007/semeval2007.data9.xml')

METRIC: path_similarity
{'preccision': 0.15, 'recall': 0.006593406593406593, 'f1': 0.01263157894736842}
METRIC: distance_to_lowest_common_hypernyms
{'preccision': 0.2222222222222222, 'recall': 0.008791208791208791, 'f1': 0.016913319238900635}
METRIC: nearest_lowest_common_hypernyms
{'preccision': 0.42857142857142855, 'recall': 0.03296703296703297, 'f1': 0.061224489795918366}
METRIC: vote_nearest_lowest_common_hypernyms
{'preccision': 0.5, 'recall': 0.024175824175824177, 'f1': 0.04612159329140461}


In [22]:
meta_dataset()

To run UKB in the meta dataset generated use the following commands. 

* Evaluate in the regular dataset:
    1. perl wsdeval2ukb.pl /home/iker/Documents/WSD\ BERT/semeval2007/semeval2007.data.xml > wsdeval_src/wsdeval_raw.txt

    2. perl ctx20words.pl wsdeval_src/wsdeval_raw.txt > wsdeval_src/wsdeval.txt


* Evaluate in the meta dataset:
    1. perl wsdeval2ukb.pl /home/iker/Documents/WSD\ BERT/semeval2007/METAsemeval2007.data.xml > wsdeval_src/wsdeval_raw.txt

    2. perl ctx20words.pl wsdeval_src/wsdeval_raw.txt > wsdeval_src/wsdeval.txt


./run_experiments.sh 

./evaluate.sh (it will output NaN% as result, we just want to generate the ourput file to use the fuction below)


In [26]:
#Evaluate the output of the ./run_experiment command 

def evaluate_meta_dataset(outputukb='/home/iker/Documents/ukb-3.2/wsdeval/Keys/ALL.pprw2w.key',gold_standard='semeval2007/semeval2007.gold.key.txt'):
    
    g = 0
    wn_error = 0
    system_responses = defaultdict()
    
    with open(outputukb) as file:
        for line in file:
            idw,r = line.rstrip().split(' ')
            system_responses[idw] = r
            

    with open(gold_standard,'r') as file:
        for line in file:
            line = line.rstrip().split(' ')
            key = line[0]
            gold = line[1:]
            try:
                if system_responses[key] in gold:
                    g+=1
            except KeyError:
                wn_error+=1
    
    
    p =  g/len(system_responses)
    r = g/(len(system_responses)+wn_error)
    f = 2 * (p * r) / (p+r)
    return  {'preccision':p, 'recall':r, 'f1':f}

In [30]:
evaluate_meta_dataset()

{'preccision': 0.5186813186813187,
 'recall': 0.5186813186813187,
 'f1': 0.5186813186813187}