In [1]:
import numpy as np
import nltk
import gensim

In [2]:
import gensim.downloader as gdl

# Fetch the pretrained 100-dimensional GloVe vectors through gensim's downloader API.
# The resulting KeyedVectors object is the embedding `model` used by every helper below.
GLOVE_MODEL_NAME = 'glove-wiki-gigaword-100'
pretrained_glove = gdl.load(GLOVE_MODEL_NAME)

In [46]:
from nltk.corpus import brown
from nltk.corpus import wordnet as wn

# get average vector of all synsets of a given word
def get_synset_vec(input_word, model, model_dim = 100, threshold = 3):
    synsets = wn.synsets(input_word)
    syn_words = set([s  for syn in synsets for s in syn.lemma_names()])
    count = 0
    syn_vec = np.zeros(model_dim)
    for word in syn_words:
        if word.lower() in model:
            syn_vec += model.get_vector(word.lower())
            count += 1
    if count < threshold:
        return False
    return (syn_vec / count)

In [82]:
# extract all trigrams that have a vector representation and are in wordnet
def get_trigrams(words, model, doc_id):
    wn_lemmas = set(wn.all_lemma_names())
    trigrams = [None for _ in range(len(words) - 2)]
    for i in range(len(words) - 2):
        # loop through all trigrams in text
        curr_trigram = (words[i], words[i+1], words[i+2], doc_id)
        # include trigram if all words have a vector representation
        if curr_trigram[0] in model and curr_trigram[1] in model and curr_trigram[2] in model:
            if curr_trigram[1] in wn_lemmas:
                trigrams[i] = curr_trigram
    return [tri for tri in trigrams if tri is not None]

# perturb the second word of each trigram
def get_perturbations_by_synonym(trigrams, model, use_cos_sim = True):
    perturbed_trigrams = [None for _ in range(len(trigrams))]
    wn_lemmas = set(wn.all_lemma_names())
    rand_idxs = np.random.randint(0,2,len(trigrams))
    n_syns = 0
    synset_vecs = {}
    for idx, tri in enumerate(trigrams):
        if idx % 1000 == 0:
            print(idx)
            print(n_syns)
        # get 1st-3rd most similar word to middle word of trigram
        if tri[1] in wn_lemmas:
            if tri[1] in synset_vecs.keys():
                perturbed_trigrams[idx] = synset_vecs[tri[1]]
                n_syns += 1
            else:
                v = get_synset_vec(tri[1], model, model.vector_size)
                if v is not False:
                    perturbed_trigrams[idx] = v
                    synset_vecs[tri[1]] = v
                    n_syns += 1
        elif use_cos_sim:
            perturbed_trigrams[idx] = model.most_similar(tri[1], topn=3)[rand_idxs[idx]][0]
    valid_tris = [t for t, p in zip(trigrams, perturbed_trigrams) if p is not None]
    valid_perturbed = [p for p in perturbed_trigrams if p is not None]
    return valid_tris, valid_perturbed

# create trigram and perturbed word dataset
def create_dataset(words, model, doc_id, use_cos_sim=True):
    trigrams = get_trigrams(words, model, doc_id)
    trigrams, perturbed_trigrams = get_perturbations_by_synonym(trigrams, model, use_cos_sim)
    return trigrams, perturbed_trigrams

In [83]:
a,b = create_dataset(brown.words(categories='adventure'), pretrained_glove, 100, False)

0
0
1000
824
2000
1662
3000
2514
4000
3353
5000
4150
6000
4959
7000
5772
8000
6595
9000
7396
10000
8236
11000
9050
12000
9875
13000
10686
14000
11473
15000
12268
16000
13077
17000
13914
18000
14728
19000
15573
20000
16388
21000
17207
22000
18034
23000
18885
24000
19708
