In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import modutils
import pickle
import time

# --- Configuration -----------------------------------------------------------
# Vocabulary cap passed to recode_max_dict: ids past this are merged into '<UNK>'.
w2v_size = 9000
# Pickle holding the (full_dict, full_sentences) pair used to build the vocabulary.
w2v_src_file = '../DataSets/Quora/w2v_src_180115.pickle'
# TensorFlow checkpoint restored into the graph later in the notebook.
w2v_model = '../Models-23Quora03-W2V/model-02.ckpt'

In [2]:
def recode_max_dict(sentences, full_dict, dict_size):
    """Cap a vocabulary at `dict_size` entries, folding the tail into '<UNK>'.

    Keeps the first ``dict_size - 1`` entries of `full_dict` unchanged and
    replaces all remaining entries with a single ``('<UNK>', n, f, 1)`` tuple,
    where ``n`` and ``f`` are the sums of the dropped entries' 2nd and 3rd
    fields. Every word id in `sentences` that is ``>= dict_size - 1`` is
    rewritten to ``dict_size - 1`` (the '<UNK>' index whenever `full_dict` has
    at least ``dict_size - 1`` entries).

    NOTE(review): this presumably expects `full_dict` to be ordered by
    descending frequency so that only rare words get folded - TODO confirm.

    Args:
        sentences: iterable of sequences of integer word ids into `full_dict`.
        full_dict: sequence (list or tuple) of tuples whose first field is the
            word and whose 2nd and 3rd fields are summable numbers.
        dict_size: size of the returned vocabulary including '<UNK>'; >= 1.

    Returns:
        ``(new_sentences, new_dict)``: the recoded sentences as lists of lists,
        and the truncated vocabulary as a new list (`full_dict` is untouched).

    Raises:
        ValueError: if `dict_size` < 1. A non-positive size would make the
            slices below negative-index slices and silently drop wrong entries.
    """
    if dict_size < 1:
        raise ValueError('dict_size must be >= 1, got {}'.format(dict_size))
    unk_idx = dict_size - 1
    tail = full_dict[unk_idx:]
    # list() so tuple vocabularies work and the caller's sequence is never appended to.
    new_dict = list(full_dict[:unk_idx])
    new_dict.append(('<UNK>', sum(x[1] for x in tail), sum(x[2] for x in tail), 1))

    new_sentences = [[min(unk_idx, z) for z in sent] for sent in sentences]
    return (new_sentences, new_dict)

In [3]:
%%time
# Load the pickled (full_dict, full_sentences) pair.
# NOTE: pickle.load can execute arbitrary code embedded in the file - only point
# w2v_src_file at pickles produced by a trusted source.
with open(w2v_src_file, 'rb') as f:
    (full_dict, full_sentences) = pickle.load(f)
    
# Cap the vocabulary at w2v_size entries (tail folded into '<UNK>') and recode
# the sentences' word ids to match the truncated dictionary.
(w2v_src, w2v_dict) = recode_max_dict(full_sentences, full_dict, dict_size=w2v_size)

Wall time: 15.2 s


In [19]:
# Word -> vocabulary index lookup for the truncated dictionary
# (if a word occurred twice, its later index would win).
mapper = dict((entry[0], pos) for pos, entry in enumerate(w2v_dict))

def word2idx(w):
    """Return the vocabulary index of word `w`, or the '<UNK>' index when `w` is unknown."""
    return mapper[w] if w in mapper else mapper['<UNK>']
    
def idx2word(i, vocab=None):
    """Map a vocabulary index, or a (nested) list/tuple/ndarray of indices, to words.

    Args:
        i: an integer index, or a list/tuple/np.ndarray of them (nested
            arbitrarily).
        vocab: sequence of entries whose first field is the word; defaults to
            the notebook-global `w2v_dict`, looked up at call time.

    Returns:
        The word for an index, '<ERR>' for an index outside ``[0, len(vocab))``.
        A list for list/tuple input. An np.ndarray for ndarray input.
    """
    if vocab is None:
        vocab = w2v_dict
    if isinstance(i, (list, tuple)):
        return [idx2word(x, vocab) for x in i]
    if isinstance(i, np.ndarray):
        if i.ndim == 0:
            # 0-d arrays cannot be iterated; treat them as a scalar index.
            return idx2word(i.item(), vocab)
        return np.array([idx2word(x, vocab) for x in i])
    # Reject negatives explicitly: Python indexing would otherwise wrap them
    # around to the end of the vocabulary (e.g. -1 -> '<UNK>').
    if i < 0 or i >= len(vocab):
        return '<ERR>'
    return vocab[i][0]

In [5]:
# Model hyper-parameters. The vocabulary size comes from the truncated
# dictionary (which includes the '<UNK>' entry).
DICT_SIZE = len(w2v_dict)
EMBED_SIZE = 200
NCE_NUM_SAMPLED = 100  # negative classes sampled per batch by tf.nn.nce_loss

# Initial values: rows ~ N(0, I) scaled by 1/sqrt(EMBED_SIZE), so each row has an
# expected squared norm of 1 (multivariate_normal with identity covariance is
# equivalent to independent standard normals).
# NOTE(review): no random seed is set, so the initialisation differs per run.
init_embeding = np.random.multivariate_normal(np.zeros(EMBED_SIZE), np.identity(EMBED_SIZE), size=DICT_SIZE)/np.sqrt(EMBED_SIZE)
init_beta = np.random.multivariate_normal(np.zeros(EMBED_SIZE), np.identity(EMBED_SIZE), size=DICT_SIZE)/np.sqrt(EMBED_SIZE)
init_intercept = np.zeros((DICT_SIZE,))

# Start from an empty graph so re-running this cell does not duplicate ops.
tf.reset_default_graph()

with tf.name_scope('Input'):
    # Ids of the input words whose embeddings are looked up.
    tf_in_word = tf.placeholder(tf.int32, shape=(None, ), name='in_word')
    # Context word ids shaped (batch, 1): tf.nn.nce_loss expects labels of shape
    # [batch_size, num_true].
    tf_in_context = tf.placeholder(tf.int32, shape=(None, 1), name='in_context')
    # Weight of the centring penalty below; 0.1 unless fed explicitly.
    tf_in_regularization = tf.placeholder_with_default(0.1, shape=(), name='in_regularization')
    
with tf.name_scope('Embedding'):
    # NOTE(review): the tf.Variables here and in 'Training' are unnamed, so they
    # get auto-generated names that depend on creation order. tf.train.Saver
    # restores checkpoints by variable name, so reordering or adding variables
    # will break restoring existing checkpoints.
    tf_embedding = tf.Variable(init_embeding, dtype=tf.float32)
    tf_embedded_word = tf.nn.embedding_lookup(tf_embedding, tf_in_word, name='out_embedding')
    
with tf.name_scope('Training'):
    # Output-side weights and biases of the sampled (NCE) softmax.
    tf_nce_beta = tf.Variable(init_beta, dtype=tf.float32)
    tf_nce_intercept = tf.Variable(init_intercept, dtype=tf.float32)
    # Mean NCE loss over the batch.
    tf_nce_loss = tf.reduce_mean(
                    tf.nn.nce_loss(weights=tf_nce_beta, biases=tf_nce_intercept,
                                   labels=tf_in_context, inputs=tf_embedded_word,
                                   num_sampled=NCE_NUM_SAMPLED, num_classes=DICT_SIZE))
    # Rejected alternative ("bad loss"): RMS of all embedding entries,
    # tf.sqrt(tf.reduce_mean(tf.square(tf_embedding))).
    # Used instead: RMS of the mean embedding vector, which pulls the centroid of
    # the embedding cloud toward the origin.
    tf_reg_loss = tf.sqrt(tf.reduce_mean(tf.square(tf.reduce_mean(tf_embedding, axis=0)))) #center of embedding is 0
    tf_full_loss = tf_nce_loss + tf_in_regularization * tf_reg_loss
    tf_train = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(tf_full_loss)
    
with tf.name_scope('Validation'):
    # Looking up every id 0..DICT_SIZE-1 yields the full embedding matrix.
    tf_valid_dictionary = tf.constant(np.array(range(DICT_SIZE)))
    tf_valid_embedding = tf.nn.embedding_lookup(tf_embedding, tf_valid_dictionary)
    # L2-normalise rows (the input words and the whole vocabulary) so the matmul
    # below yields cosine similarities of shape (batch, DICT_SIZE).
    # NOTE(review): keep_dims is the older spelling; later TF 1.x releases
    # deprecate it in favour of keepdims.
    tf_valid_in_norm = tf_embedded_word / tf.sqrt(tf.reduce_sum(tf.square(tf_embedded_word), 1, keep_dims=True))
    tf_valid_dic_norm = tf_valid_embedding / tf.sqrt(tf.reduce_sum(tf.square(tf_valid_embedding), 1, keep_dims=True))
    tf_valid_similarity = tf.matmul(tf_valid_in_norm, tf_valid_dic_norm, transpose_b=True)
    
# Write the graph definition for TensorBoard.
# NOTE(review): absolute Windows path - consider moving it to the config cell.
tffw = tf.summary.FileWriter('D:/Jupyter/Logs/00_W2V', tf.get_default_graph())
tffw.close()
print('Graph creation complete.')

Graph creation complete.


In [8]:
tfsSaver = tf.train.Saver()

# Restore the trained weights into the freshly built graph and fetch the
# row-normalised vocabulary embedding matrix as a NumPy array.
with tf.Session() as tfs:
    tfsSaver.restore(tfs, save_path=w2v_model)
    dic_embed = tfs.run(tf_valid_dic_norm)

print('Complete')

INFO:tensorflow:Restoring parameters from ../Models-23Quora03-W2V/model-02.ckpt
Complete


In [112]:
def word2vec(wrd, embed=dic_embed):
    """Look up the embedding row(s) for a word or a collection of words.

    Args:
        wrd: a word (str, including NumPy string scalars such as the elements
            of a string ndarray), or a list/tuple/ndarray of words, nested
            arbitrarily.
        embed: embedding matrix indexed by vocabulary id; defaults to the
            `dic_embed` that existed when this cell was executed.

    Returns:
        ``embed[word2idx(wrd)]`` for a single word. Unknown words get the
        '<UNK>' row via word2idx. A list of results for a collection, and None
        for any other input type.
    """
    # isinstance (not `type(...) is str`) so np.str_ elements of a string array
    # are treated as words instead of falling through to None.
    if isinstance(wrd, str):
        return embed[word2idx(wrd)]
    if isinstance(wrd, (list, tuple, np.ndarray)):
        # Forward `embed` so collections use the same matrix as single words.
        return [word2vec(x, embed) for x in wrd]
    return None

def topNids(vec, embed=dic_embed):
    """Rank every row of `embed` by Euclidean distance to `vec`.

    Returns:
        ``(order, sorted_distances, mean_distance)``: row indices nearest-first,
        the matching distances, and the mean distance over all rows.
    """
    offsets = embed - vec
    distances = np.sqrt(np.square(offsets).sum(axis=1))
    order = distances.argsort()
    return order, distances[order], distances.mean()

In [144]:
# Centre the embedding cloud: subtract the per-dimension mean over the vocabulary.
n_embed = dic_embed - np.mean(dic_embed, axis=0)

In [197]:
# Nearest neighbours of a query word in the centred, unit-normalised embedding.
# n2_embed is computed here so this cell does not depend on a cell that appears
# further down the notebook (Restart & Run All would otherwise raise NameError).
n2_embed = n_embed / np.linalg.norm(n_embed, axis=1, keepdims=True)

QUERY_WORD = 'soviet'
TOP_K = 10

v = word2vec(QUERY_WORD, n2_embed)
print(np.sqrt(np.sum(np.square(v))))  # sanity check: rows are unit length, expect ~1.0
res = topNids(v, n2_embed)
idx2word(res[0][:TOP_K]), res[1][:TOP_K], res[2]

1.0


(array(['soviet', 'european', 'mughal', 'poorest', 'pacific', 'higgs',
        'advising', 'assad', 'territory', 'territories'],
       dtype='<U11'),
 array([ 0.        ,  0.89992267,  1.01261008,  1.05298018,  1.05413282,
         1.05812311,  1.07490003,  1.0761683 ,  1.07674336,  1.080948  ], dtype=float32),
 1.4122971)

In [184]:
# Rescale every centred row to unit L2 length (keepdims gives an (N, 1) divisor).
n2_embed = n_embed / np.linalg.norm(n_embed, axis=1, keepdims=True)

In [157]:
# Spot-check: the first few row norms should all be ~1.
np.linalg.norm(n2_embed, axis=1)[:10]

array([ 1.        ,  1.        ,  1.        ,  1.        ,  1.        ,
        1.        ,  0.99999994,  1.        ,  1.        ,  1.        ], dtype=float32)

In [159]:
# Per-dimension mean of the unit-normalised vectors: near, but not exactly, zero,
# because the rows were centred before being rescaled to unit length.
np.mean(n2_embed, axis=0)

array([  2.37659877e-03,  -4.40228125e-03,  -2.23844755e-03,
         4.76453773e-04,   1.52279297e-03,   3.10367695e-03,
        -2.55712937e-03,   1.09814538e-03,  -2.45685340e-03,
        -3.73511412e-03,   2.26351642e-03,  -2.82515213e-03,
        -3.66856391e-03,  -2.72426102e-03,   1.96540123e-03,
         5.08765597e-03,  -5.08268224e-03,   4.70175175e-03,
        -3.55271739e-03,   4.73660091e-03,  -3.58514371e-03,
        -4.53447457e-03,  -2.23767781e-03,   2.57108849e-03,
         3.45585262e-03,   3.27310129e-03,  -1.96859823e-03,
         6.29883260e-03,  -2.89325556e-03,   2.72309408e-03,
         1.84798578e-03,  -2.38471874e-03,  -2.21907021e-03,
        -1.90203602e-03,  -2.57443637e-03,   3.49301449e-03,
        -3.74454027e-03,   5.36353746e-03,   4.13659308e-03,
        -4.01344988e-03,   3.20694945e-03,  -3.79665382e-03,
         3.02152196e-03,  -7.51617015e-04,   3.42757511e-03,
         9.72645139e-05,  -4.43768129e-03,  -3.93606722e-03,
         3.67681892e-03,