### Import

In [1]:
import json, unicodedata, re, pickle, torch
import numpy as np
from nltk.stem import PorterStemmer
from bert_embedding import BertEmbedding
import logging
logging.basicConfig(level=logging.INFO)
import copy

### Utils

In [2]:
def unicode_to_ascii(s):
    """Return `s` with its combining marks removed.

    The string is decomposed with NFD normalisation and every character whose
    Unicode category is 'Mn' (nonspacing mark) is dropped, so accented letters
    collapse to their base letter. Despite the name, characters that are not
    marks (e.g. '¿') are kept even when they are not ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def preproc(w):
    """Normalise a raw string into lowercase, space-separated tokens.

    Steps, in order:
      1. strip surrounding whitespace and remove combining marks;
      2. surround each of ? . ! , ¿ with spaces so they become separate tokens;
      3. collapse runs of double quotes / spaces into a single space;
      4. replace every run of characters other than ASCII letters and
         ? . ! , ¿ with a single space;
      5. strip again and lowercase.
    """
    text = unicode_to_ascii(w.strip())
    # (pattern, replacement) pairs; order matters because each step feeds the next.
    substitutions = (
        (r"([?.!,¿])", r" \1 "),
        (r'[" "]+', " "),
        (r"[^a-zA-Z?.!,¿]+", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip().lower()

In [3]:
# Load the annotated dataset. The context manager closes the file handle
# deterministically, and the explicit encoding avoids depending on the
# platform's locale default when the text contains non-ASCII characters.
with open('train_test_disambiguated.json', encoding='utf-8') as dataset_file:
    js = json.load(dataset_file)

# Parallel lists, one entry per dataset record (None when a key is absent).
concepts  = [record.get('concept_set') for record in js]
sentences = [record.get('sentence') for record in js]
entities  = [record.get('entities') for record in js]

### Disambiguate sentences

In [4]:
# For every sentence, replace tokens that equal an annotated entity text with
# the bare 8-digit BabelNet synset number, and collect the synset IDs seen.
complete_sent = []
senses_set = set()  # full IDs, e.g. 'bn:00000000n' (prefix and POS suffix kept)

for sen, ent in zip(sentences, entities):
    d = {'raw': sen, 'disambiguated': '', 'entities': []}
    sent = preproc(sen).split(' ')
    for e in ent or ():
        txt, ID = e.get('text'), e.get('babelSynsetID')
        # Annotations missing either field cannot be used; without this guard
        # re.fullmatch(None) / None.lower() would raise.
        if txt is None or ID is None:
            continue
        # Only nominal synsets. fullmatch (rather than match) rejects IDs with
        # trailing characters, which the slice below would otherwise mangle.
        if not re.fullmatch("bn:[0-9]{8}n", ID):
            continue
        senses_set.add(ID)
        d['entities'].append({"entity": txt, "synset": ID})
        bare_id = ID[3:-1]  # remove bn: prefix and n suffix
        target = txt.lower()
        # NOTE(review): tokens are produced by splitting on ' ', so an entity
        # text containing a space never equals a token; such entities are
        # recorded in d['entities'] but not substituted in the sentence.
        sent = [bare_id if w == target else w for w in sent]
    d['disambiguated'] = ' '.join(sent)
    complete_sent.append(d)

# with open('dis_sent.json', 'w') as j:
#     json.dump(complete_sent, j)
#     j.close()

# j = json.load(open('dis_sent.json')); j[:10]

### Disambiguate concepts

In [5]:
# For every concept set, replace tokens whose Porter stem equals the stem of
# an annotated entity text with the bare 8-digit BabelNet synset number.
complete_conc = []
ps = PorterStemmer()

for con, ent in zip(concepts, entities):
    d = {'raw': con, 'disambiguated': '', 'entities': []}
    conc = preproc(con).split(' ')
    for e in ent or ():
        txt, ID = e.get('text'), e.get('babelSynsetID')
        # Annotations missing either field cannot be used; without this guard
        # re.fullmatch(None) / None.lower() would raise.
        if txt is None or ID is None:
            continue
        # Only nominal synsets. fullmatch (rather than match) rejects IDs with
        # trailing characters, which the slice below would otherwise mangle.
        if not re.fullmatch("bn:[0-9]{8}n", ID):
            continue
        senses_set.add(ID)
        d['entities'].append({"entity": txt, "synset": ID})
        bare_id = ID[3:-1]  # remove bn: prefix and n suffix
        # Stem the entity text once per entity instead of once per token.
        target_stem = ps.stem(txt.lower())
        # Tokens are still stemmed per entity (not pre-computed) so that a
        # token already replaced by an ID is compared by its new value, which
        # keeps the "first matching entity wins" behaviour.
        conc = [bare_id if ps.stem(w) == target_stem else w for w in conc]
    d['disambiguated'] = ' '.join(conc)
    complete_conc.append(d)

# with open('dis_conc.json', 'w') as j:
#     json.dump(complete_conc, j)
#     j.close()

# j = json.load(open('dis_conc.json')); j[:10]

### Write to txt files

In [6]:
# with open('concepts_wsd.txt', 'a') as f1, open('sentences_wsd.txt', 'a') as f2:
#     for i,j in zip(complete_conc, complete_sent):
#         f1.write(i.get('disambiguated')+'\n')
#         f2.write(j.get('disambiguated')+'\n')
#     f1.close()
#     f2.close()

### JSON with concept set, sentence, and frame

In [7]:
# One record per sample: disambiguated concept set, disambiguated sentence,
# and the frame taken from the raw dataset entry.
csf_wsd = [
    {
        'concept_set': conc_record.get('disambiguated'),
        'sentence': sent_record.get('disambiguated'),
        'frame': sample.get('frame'),
    }
    for conc_record, sent_record, sample in zip(complete_conc, complete_sent, js)
]

# with open('csf_wsd.json', 'w') as j:
#     json.dump(csf_wsd, j)
#     j.close()

### Create vocabulary of non-disambiguated words

In [8]:
# Vocabulary of plain words: every token from the disambiguated sentences and
# concept sets that does not start with eight digits (i.e. was not replaced by
# a synset number).
words_set = {
    token
    for record in complete_sent + complete_conc
    for token in record.get('disambiguated').split(' ')
    if re.match(r'[0-9]{8}', token) is None
}

### Create complete vocab and check overlapping

In [9]:
# Full vocabulary = synset IDs plus plain words.
complete_vocab = senses_set | words_set

# If the union is as large as the two parts together, the sets are disjoint.
assert len(complete_vocab) == len(senses_set) + len(words_set)

# Displayed as the cell output: (number of senses, number of words, total).
len(senses_set), len(words_set), len(complete_vocab)

(6350, 4756, 11106)

### Write vocabularies to txt

In [10]:
# with open('vocab/vocab_senses.txt', 'a') as fp:
#     for c in list(senses_set):
#         fp.write(c[3:-1]+'\n')
#     fp.close()

# with open('vocab/vocab_words.txt', 'a') as fp:
#     for w in list(words_set):
#         fp.write(w+'\n')
#     fp.close()

# with open('vocab/vocabulary.txt', 'a') as fp:
#     for w in list(complete_vocab):
#         if re.match(r'bn:[0-9]{8}n', w):
#             w = w[3:-1]
#         fp.write(w+'\n')
#     fp.close()

# Create embedding matrix

In [11]:
# NOTE: pickle.load can execute arbitrary code; only load files from a
# trusted source.
# Context managers close the file handles as soon as loading finishes.
with open('sensembert/sensembert_EN.p', 'rb') as sensembert_file:
    sensembert = pickle.load(sensembert_file)

with open('bn_wn_map.p', 'rb') as bn_wn_map_file:
    bn_wn_map = pickle.load(bn_wn_map_file)

### Sense embeddings

In [12]:
# Map each bare synset number to an embedding looked up in `sensembert`
# through the first entry that `bn_wn_map` lists for the synset.
sense_embeddings = {}

for sense in senses_set:
    bn_key = sense[3:]  # drop the 'bn:' prefix, keep the trailing POS letter
    if bn_key in bn_wn_map:
        key = bn_wn_map[bn_key][0]
        # .get() returns None when `key` is absent from sensembert, so the
        # stored value may be None.
        sense_embeddings[sense[3:-1]] = sensembert.get(key)

print(len(sense_embeddings), len (senses_set))
# pickle.dump(sense_embeddings, open('embeddings/sense_embeddings.p', 'wb'))
# The cached pickle replaces the dictionary built above.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('embeddings/sense_embeddings.p', 'rb') as sense_embeddings_file:
    sense_embeddings = pickle.load(sense_embeddings_file)
print(len(sense_embeddings), len (senses_set))

4911 6350
4911 6350


### Word embeddings

In [13]:
# bert_embedding = BertEmbedding(model='bert_24_1024_16', dataset_name='book_corpus_wiki_en_cased')
# print('BERT Large loaded\n')
# print('Starting inference phase...\n')
# word_embeddings = bert_embedding(list(words_set))
# assert len(word_embeddings) == len(words_set)
# print('End of inference phase and dump of the generated embeddings\n')
# pickle.dump(word_embeddings, open('embeddings/word_embeddings_1024.p', 'wb'))
# word_embeddings = pickle.load(open('embeddings/word_embeddings_1024.p', 'rb'))
# print('Concat phase to get (2048,) embeddings\n')
# x = [arr[0][0] for arr in word_embeddings]
# y = [np.squeeze(np.concatenate((arr[1], arr[1]), axis=1)) for arr in word_embeddings]
# assert len(x) == len(y)
# word_embeddings = {i: j for i,j in zip(x,y)}
# pickle.dump(word_embeddings, open('embeddings/word_embeddings_2048.p', 'wb'))
# Load the cached word embeddings; the context manager closes the file handle.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('embeddings/word_embeddings_2048.p', 'rb') as word_embeddings_file:
    word_embeddings = pickle.load(word_embeddings_file)

In [17]:
# complete_embed = {}
# complete_embed.update(word_embeddings)
# complete_embed.update(sense_embeddings)
# pickle.dump(complete_embed, open('embeddings/complete_2048.p', 'wb'))
# NOTE(review): the commented-out dump above writes 'embeddings/complete_2048.p'
# while this load reads 'embeddings/complete_embeddings_2048.p' — confirm which
# file name is the intended one.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('embeddings/complete_embeddings_2048.p', 'rb') as complete_embed_file:
    complete_embed = pickle.load(complete_embed_file)
len(complete_embed), len(complete_vocab)

(9667, 11106)

# Add frame to raw dataset