In [2]:
from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()
from tensorflow.python.ops import lookup_ops

In [4]:
# Fix TensorFlow's random seed so random ops in this notebook (such as the
# tf.random_uniform embedding initialiser in Encoder) are repeatable.
tf.set_random_seed(42)

In [5]:
# Path to the vocabulary file. It is consumed three times in this notebook:
# to build the TF lookup table, as the input of the TextLineDataset, and by
# the plain-Python loader that maps indices back to words.
vocab_file = '../data/vocab.txt'

In [6]:
def NLI_create_dataset(vocab_file, vocab_table, batch_size):
    """Build a padded, batched dataset of token ids from a text file.

    Every line of ``vocab_file`` is split into tokens, each token is looked up
    in ``vocab_table``, and the resulting ids are paired with the token count
    of that line.

    Args:
        vocab_file: path of the text file to read, one entry per line.
        vocab_table: lookup table exposing ``lookup(tokens)``.
        batch_size: number of lines per batch.

    Returns:
        A dataset yielding ``(ids, lengths)`` batches; ``ids`` is padded along
        the token axis to the longest line in the batch.
    """
    def _tokenize(line):
        # string_split yields a sparse tensor; .values holds the flat tokens.
        return tf.string_split([line]).values

    def _to_ids_and_length(tokens):
        return vocab_table.lookup(tokens), tf.size(tokens)

    lines = tf.data.TextLineDataset(vocab_file)
    encoded = lines.map(_tokenize).map(_to_ids_and_length)
    return encoded.padded_batch(batch_size=batch_size, padded_shapes=([None], []))

In [7]:
class Embedding(tf.keras.Model):
    """Word-embedding lookup backed by a single variable ``W``."""

    def __init__(self, V, d, init):
        """Create the embedding matrix from ``init``.

        Args:
            V: vocabulary size (not read here; the shape comes from ``init``).
            d: embedding width (not read here; the shape comes from ``init``).
            init: initial value of the embedding matrix.
        """
        super(Embedding, self).__init__()
        # NOTE(review): the checkpoint restore presumably matches variables by
        # attribute name, so keep this attribute called ``W`` -- confirm
        # before renaming.
        self.W = tfe.Variable(init)

    def call(self, word_indexes):
        """Return the rows of ``W`` selected by ``word_indexes``."""
        embedded = tf.nn.embedding_lookup(params=self.W, ids=word_indexes)
        return embedded

In [8]:
class StaticRNN(tf.keras.Model):
    """Runs one recurrent cell over a batch via ``tf.nn.static_rnn``."""

    def __init__(self, h, cell):
        """Select the recurrent cell.

        Args:
            h: number of hidden units of the cell.
            cell: 'lstm' or 'gru'; any other value selects a basic RNN cell.
        """
        super(StaticRNN, self).__init__()
        if cell == 'lstm':
            cell_cls = tf.nn.rnn_cell.BasicLSTMCell
        elif cell == 'gru':
            cell_cls = tf.nn.rnn_cell.GRUCell
        else:
            cell_cls = tf.nn.rnn_cell.BasicRNNCell
        self.cell = cell_cls(num_units=h)

    def call(self, word_vectors, num_words, state, init_state):
        """Unroll the cell over the time axis (axis 1) of ``word_vectors``.

        Args:
            word_vectors: input tensor; axis 1 is unstacked into time steps.
            num_words: per-example sequence lengths passed to static_rnn.
            state: when truthy, ``init_state`` is used as the initial state;
                otherwise ``init_state`` is ignored.
            init_state: initial cell state, only read when ``state`` is truthy.

        Returns:
            ``(outputs, final_state)`` exactly as returned by static_rnn.
        """
        time_steps = tf.unstack(word_vectors, axis=1)
        # None is static_rnn's default, i.e. "no explicit initial state".
        start_state = init_state if state else None
        outputs, final_state = tf.nn.static_rnn(
            cell=self.cell,
            inputs=time_steps,
            initial_state=start_state,
            sequence_length=num_words,
            dtype=tf.float32,
        )
        return outputs, final_state

In [9]:
class Encoder(tf.keras.Model):
    """Embeds token ids and encodes them with a recurrent network."""

    def __init__(self, V, d, h, cell):
        """Build the embedding table and the recurrent layer.

        Args:
            V: vocabulary size (rows of the embedding matrix).
            d: embedding width (columns of the embedding matrix).
            h: number of hidden units of the recurrent cell.
            cell: cell type forwarded to StaticRNN ('lstm', 'gru' or other).
        """
        super(Encoder, self).__init__()
        init = tf.random_uniform(minval=-1.0, maxval=1.0, shape=[V, d])
        self.word_embedding = Embedding(V, d, init)
        self.rnn = StaticRNN(h, cell)

    def call(self, datum, lens, state, init_state):
        """Encode a padded batch and return each example's last valid output.

        Args:
            datum: padded token ids for the batch.
            lens: number of valid tokens per example.
            state: forwarded to StaticRNN; truthy means ``init_state`` is used.
            init_state: initial recurrent state (ignored when ``state`` is falsy).

        Returns:
            ``(last_outputs, final_state)`` where ``last_outputs[i]`` is the
            RNN output at time step ``lens[i] - 1`` for example ``i``.
        """
        word_vectors = self.word_embedding(datum)
        step_outputs, final_state = self.rnn(word_vectors, lens, state, init_state)

        # Select, for every example, the output at its last valid time step
        # with a single gather instead of a Python loop over the batch.
        stacked = tf.stack(step_outputs)  # stack the per-step outputs on a new axis 0
        lengths = tf.cast(lens, tf.int32)
        # The old per-example indexing mapped a zero length to index -1 (the
        # final step). gather_nd does not accept negative indices, so clamp to
        # step 0; static_rnn presumably zero-fills outputs past the sequence
        # length, making both choices equal for empty sequences -- confirm.
        last_step = tf.maximum(lengths - 1, 0)
        batch_index = tf.range(tf.size(lengths))
        last_outputs = tf.gather_nd(stacked, tf.stack([last_step, batch_index], axis=1))
        return last_outputs, final_state

In [10]:
# Placeholder handed to the encoder as its init_state argument. The encoder is
# called with state=False in this notebook, so StaticRNN never reads it.
# NOTE(review): np.zeros(512) is a float64 vector of shape (512,); it would
# presumably have to be float32 and batch-shaped to serve as a real initial
# state -- confirm before calling the encoder with state=True.
dummy_state = tf.convert_to_tensor(np.zeros(512))

In [11]:
# Token -> integer id lookup table built from the vocabulary file; tokens that
# are not in the file map to id 0 (default_value).
vocab_table = lookup_ops.index_table_from_file(vocab_file, default_value = 0)
# Optimizer instance; in this notebook it is only handed to the checkpoint
# object so a saved training checkpoint can be restored.
opt = tf.train.AdamOptimizer(learning_rate=0.002)

In [12]:
# The vocabulary file itself is used as the input text, batched 30000 lines at
# a time, so one batch can hold up to 30000 vocabulary entries.
dataset = NLI_create_dataset(vocab_file,vocab_table ,30000)

In [13]:
# Python iterator over the batched dataset (eager execution is enabled above).
iterator = iter(dataset)

In [14]:
# Only the first batch is used: datum[0] holds the padded token ids and
# datum[1] the token count of each line.
datum = next(iterator)

In [15]:
# Positional arguments are V=30000 (vocabulary size), d=256 (embedding width),
# h=512 (hidden units) and a GRU cell, per Encoder.__init__.
encoder = Encoder(30000, 256, 512, 'gru')

In [16]:
# Restore the trained encoder (plus optimizer and step counter) from the most
# recent checkpoint in checkpoint_dir.
checkpoint_dir = '../nmt_encoder'
root = tfe.Checkpoint(optimizer=opt, model=encoder, optimizer_step=tf.train.get_or_create_global_step())
# latest_checkpoint returns None when the directory holds no checkpoint, and
# restoring from None would leave the encoder with its random initial weights.
# Fail loudly instead of silently exporting untrained embeddings.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    raise IOError('No checkpoint found in %r' % checkpoint_dir)
root.restore(latest_checkpoint)

<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus at 0x11ce52e48>

In [17]:
# Encode the whole batch in one call. state=False makes StaticRNN ignore
# dummy_state. `embedding` holds one encoder output vector per input line.
embedding, state = encoder(datum[0], datum[1], False, dummy_state)

In [18]:
# Convert the eager tensor to a NumPy array so scikit-learn can consume it.
word_embeddings = embedding.numpy()

In [19]:
# Sanity check: one row per encoded vocabulary line, one column per hidden unit.
word_embeddings.shape

(30000, 512)

In [20]:
def similarity(r1, r2):
    """Cosine similarity between two vectors.

    Both inputs are reshaped into single-row matrices, so the result is the
    1x1 array returned by ``cosine_similarity``. This is a similarity, not a
    distance: larger values mean the vectors are more alike.
    """
    row_a = r1.reshape(1, -1)
    row_b = r2.reshape(1, -1)
    return cosine_similarity(row_a, row_b)

In [21]:
# NearestNeighbors treats `metric` as a distance (smaller = closer). Passing a
# similarity function would rank the *least* similar words first, so use the
# built-in cosine distance (1 - cosine similarity) instead. The brute-force
# algorithm is requested explicitly because the tree-based indexes do not
# support the cosine metric; it is also vectorised, unlike a Python callable
# that is invoked once per pair of rows.
nbrs = NearestNeighbors(n_neighbors=5, metric='cosine', algorithm='brute').fit(word_embeddings)

In [23]:
# Calling kneighbors without a query matrix returns the neighbours of the
# fitted points themselves and does not count a point as its own neighbour,
# so every row holds the indices of five *other* words. Passing the training
# matrix as the query would let each word occupy one of its own five slots.
nnbrs = nbrs.kneighbors(n_neighbors=5, return_distance=False)

In [28]:
from collections import Counter

In [31]:
# Index -> word lookup for the vocabulary, in file order. A plain list fits
# because the indices are dense (0..N-1); a Counter would quietly return the
# integer 0 for an unknown index instead of raising.
# Note: this rebinds `vocab_table`, which earlier in the notebook referred to
# the TensorFlow lookup table; the name is kept because the export cell below
# reads it.
with open(vocab_file, 'r') as fr:
    vocab_table = [line.strip() for line in fr]

In [34]:
# Write one line per vocabulary entry: the word itself followed by the words
# at its five neighbour indices, all separated by ', '.
with open('nearest_neighbours.txt', 'w') as out_file:
    for word_id in range(len(vocab_table)):
        neighbour_words = [vocab_table[nnbrs[word_id][rank]] for rank in range(5)]
        out_file.write(', '.join([vocab_table[word_id]] + neighbour_words) + '\n')