In [3]:
import numpy as np
import tensorflow as tf

corpus_raw = 'He is the king . The king is royal . She is the royal queen'
corpus_raw = corpus_raw.lower()

words = []
for word in corpus_raw.split():
    if word != '.': 
        words.append(word)
        
words = set(words)

print(words)

{'she', 'royal', 'is', 'queen', 'he', 'king', 'the'}


In [5]:
word2int = {}
int2word = {}
vocab_size = len(words)

for i,word in enumerate(words):
    word2int[word] = i
    int2word[i] = word

print(word2int['queen'])
print(int2word[2])

3
is


In [8]:
raw_sentences = corpus_raw.split('.')
sentences = []
for sentence in raw_sentences:
    sentences.append(sentence.split())

print(sentences)

[['he', 'is', 'the', 'king'], ['the', 'king', 'is', 'royal'], ['she', 'is', 'the', 'royal', 'queen']]


In [24]:
data  =  []
WINDOW_SIZE  = 2

def to_one_hot(data_point_index, vocab_size):
    temp = np.zeros(vocab_size)
    temp[data_point_index] = 1
    return temp

for sentence in sentences:
    for word_index, word in enumerate(sentence):
        for nb_word in sentence[max(word_index - WINDOW_SIZE, 0) : min(word_index + WINDOW_SIZE, len(sentence)) + 1] : 
            if nb_word != word:
                data.append([word, nb_word])

data

[['he', 'is'],
 ['he', 'the'],
 ['is', 'he'],
 ['is', 'the'],
 ['is', 'king'],
 ['the', 'he'],
 ['the', 'is'],
 ['the', 'king'],
 ['king', 'is'],
 ['king', 'the'],
 ['the', 'king'],
 ['the', 'is'],
 ['king', 'the'],
 ['king', 'is'],
 ['king', 'royal'],
 ['is', 'the'],
 ['is', 'king'],
 ['is', 'royal'],
 ['royal', 'king'],
 ['royal', 'is'],
 ['she', 'is'],
 ['she', 'the'],
 ['is', 'she'],
 ['is', 'the'],
 ['is', 'royal'],
 ['the', 'she'],
 ['the', 'is'],
 ['the', 'royal'],
 ['the', 'queen'],
 ['royal', 'is'],
 ['royal', 'the'],
 ['royal', 'queen'],
 ['queen', 'the'],
 ['queen', 'royal']]

In [25]:
x_train = [] 
y_train = [] 
for data_word in data:
    x_train.append(to_one_hot(word2int[ data_word[0] ], vocab_size))
    y_train.append(to_one_hot(word2int[ data_word[1] ], vocab_size))

x_train = np.asarray(x_train)
y_train = np.asarray(y_train)

print(x_train.shape, y_train.shape)

(34, 7) (34, 7)


In [26]:
x = tf.placeholder(tf.float32, shape=(None, vocab_size))
y_label = tf.placeholder(tf.float32, shape=(None, vocab_size))

EMBEDDING_DIM = 5 
W1 = tf.Variable(tf.random_normal([vocab_size, EMBEDDING_DIM]))
b1 = tf.Variable(tf.random_normal([EMBEDDING_DIM]))
hidden_representation = tf.add(tf.matmul(x,W1), b1)

W2 = tf.Variable(tf.random_normal([EMBEDDING_DIM, vocab_size]))
b2 = tf.Variable(tf.random_normal([vocab_size]))
prediction = tf.nn.softmax(tf.add( tf.matmul(hidden_representation, W2), b2))


In [35]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) 

# define the loss function:
cross_entropy_loss = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(prediction), reduction_indices=[1]))
# define the training step:
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy_loss)
n_iters = 10000
# train for n_iter iterations
for i in range(n_iters):
    sess.run(train_step, feed_dict={x: x_train, y_label: y_train})
    print('Iteration #', i, ' loss is : ', sess.run(cross_entropy_loss, feed_dict={x: x_train, y_label: y_train}))

Iteration # 0  loss is :  2.839133
Iteration # 1  loss is :  2.6726477
Iteration # 2  loss is :  2.544283
Iteration # 3  loss is :  2.4451938
Iteration # 4  loss is :  2.3685036
Iteration # 5  loss is :  2.3088212
Iteration # 6  loss is :  2.2618968
Iteration # 7  loss is :  2.224416
Iteration # 8  loss is :  2.1938503
Iteration # 9  loss is :  2.168321
Iteration # 10  loss is :  2.1464636
Iteration # 11  loss is :  2.127307
Iteration # 12  loss is :  2.110165
Iteration # 13  loss is :  2.094557
Iteration # 14  loss is :  2.0801442
Iteration # 15  loss is :  2.0666862
Iteration # 16  loss is :  2.0540106
Iteration # 17  loss is :  2.041991
Iteration # 18  loss is :  2.0305333
Iteration # 19  loss is :  2.0195653
Iteration # 20  loss is :  2.0090318
Iteration # 21  loss is :  1.9988889
Iteration # 22  loss is :  1.9890996
Iteration # 23  loss is :  1.9796358
Iteration # 24  loss is :  1.970472
Iteration # 25  loss is :  1.9615878
Iteration # 26  loss is :  1.952965
Iteration # 27  loss 

Iteration # 346  loss is :  1.4285235
Iteration # 347  loss is :  1.4280739
Iteration # 348  loss is :  1.4276253
Iteration # 349  loss is :  1.4271783
Iteration # 350  loss is :  1.4267322
Iteration # 351  loss is :  1.4262877
Iteration # 352  loss is :  1.4258443
Iteration # 353  loss is :  1.4254025
Iteration # 354  loss is :  1.4249618
Iteration # 355  loss is :  1.4245224
Iteration # 356  loss is :  1.4240843
Iteration # 357  loss is :  1.4236476
Iteration # 358  loss is :  1.4232123
Iteration # 359  loss is :  1.4227781
Iteration # 360  loss is :  1.4223452
Iteration # 361  loss is :  1.4219139
Iteration # 362  loss is :  1.4214838
Iteration # 363  loss is :  1.4210547
Iteration # 364  loss is :  1.4206272
Iteration # 365  loss is :  1.4202011
Iteration # 366  loss is :  1.4197761
Iteration # 367  loss is :  1.4193525
Iteration # 368  loss is :  1.4189303
Iteration # 369  loss is :  1.4185092
Iteration # 370  loss is :  1.4180897
Iteration # 371  loss is :  1.4176711
Iteration # 

Iteration # 697  loss is :  1.3454612
Iteration # 698  loss is :  1.3453783
Iteration # 699  loss is :  1.3452959
Iteration # 700  loss is :  1.3452139
Iteration # 701  loss is :  1.3451324
Iteration # 702  loss is :  1.3450512
Iteration # 703  loss is :  1.3449703
Iteration # 704  loss is :  1.34489
Iteration # 705  loss is :  1.3448102
Iteration # 706  loss is :  1.3447307
Iteration # 707  loss is :  1.3446516
Iteration # 708  loss is :  1.344573
Iteration # 709  loss is :  1.3444948
Iteration # 710  loss is :  1.3444167
Iteration # 711  loss is :  1.3443395
Iteration # 712  loss is :  1.3442622
Iteration # 713  loss is :  1.3441858
Iteration # 714  loss is :  1.3441095
Iteration # 715  loss is :  1.3440336
Iteration # 716  loss is :  1.3439581
Iteration # 717  loss is :  1.343883
Iteration # 718  loss is :  1.3438083
Iteration # 719  loss is :  1.3437339
Iteration # 720  loss is :  1.3436599
Iteration # 721  loss is :  1.3435864
Iteration # 722  loss is :  1.3435134
Iteration # 723 

Iteration # 1047  loss is :  1.3310187
Iteration # 1048  loss is :  1.3309994
Iteration # 1049  loss is :  1.33098
Iteration # 1050  loss is :  1.3309609
Iteration # 1051  loss is :  1.3309417
Iteration # 1052  loss is :  1.3309226
Iteration # 1053  loss is :  1.3309034
Iteration # 1054  loss is :  1.3308846
Iteration # 1055  loss is :  1.3308656
Iteration # 1056  loss is :  1.3308467
Iteration # 1057  loss is :  1.3308278
Iteration # 1058  loss is :  1.3308091
Iteration # 1059  loss is :  1.3307905
Iteration # 1060  loss is :  1.3307718
Iteration # 1061  loss is :  1.3307532
Iteration # 1062  loss is :  1.3307348
Iteration # 1063  loss is :  1.3307163
Iteration # 1064  loss is :  1.3306979
Iteration # 1065  loss is :  1.3306797
Iteration # 1066  loss is :  1.3306612
Iteration # 1067  loss is :  1.3306432
Iteration # 1068  loss is :  1.3306248
Iteration # 1069  loss is :  1.3306068
Iteration # 1070  loss is :  1.3305887
Iteration # 1071  loss is :  1.3305708
Iteration # 1072  loss is :

Iteration # 1397  loss is :  1.3267199
Iteration # 1398  loss is :  1.3267124
Iteration # 1399  loss is :  1.3267046
Iteration # 1400  loss is :  1.3266971
Iteration # 1401  loss is :  1.3266894
Iteration # 1402  loss is :  1.3266817
Iteration # 1403  loss is :  1.3266741
Iteration # 1404  loss is :  1.3266665
Iteration # 1405  loss is :  1.326659
Iteration # 1406  loss is :  1.3266516
Iteration # 1407  loss is :  1.326644
Iteration # 1408  loss is :  1.3266363
Iteration # 1409  loss is :  1.3266288
Iteration # 1410  loss is :  1.3266213
Iteration # 1411  loss is :  1.3266138
Iteration # 1412  loss is :  1.3266065
Iteration # 1413  loss is :  1.326599
Iteration # 1414  loss is :  1.3265916
Iteration # 1415  loss is :  1.3265842
Iteration # 1416  loss is :  1.3265768
Iteration # 1417  loss is :  1.3265694
Iteration # 1418  loss is :  1.3265619
Iteration # 1419  loss is :  1.3265548
Iteration # 1420  loss is :  1.3265474
Iteration # 1421  loss is :  1.3265401
Iteration # 1422  loss is : 

Iteration # 1765  loss is :  1.3247112
Iteration # 1766  loss is :  1.3247072
Iteration # 1767  loss is :  1.3247035
Iteration # 1768  loss is :  1.3246995
Iteration # 1769  loss is :  1.3246958
Iteration # 1770  loss is :  1.3246919
Iteration # 1771  loss is :  1.3246881
Iteration # 1772  loss is :  1.3246841
Iteration # 1773  loss is :  1.3246806
Iteration # 1774  loss is :  1.3246766
Iteration # 1775  loss is :  1.3246728
Iteration # 1776  loss is :  1.324669
Iteration # 1777  loss is :  1.3246654
Iteration # 1778  loss is :  1.3246615
Iteration # 1779  loss is :  1.3246578
Iteration # 1780  loss is :  1.3246539
Iteration # 1781  loss is :  1.3246503
Iteration # 1782  loss is :  1.3246466
Iteration # 1783  loss is :  1.3246427
Iteration # 1784  loss is :  1.3246391
Iteration # 1785  loss is :  1.3246351
Iteration # 1786  loss is :  1.3246315
Iteration # 1787  loss is :  1.3246279
Iteration # 1788  loss is :  1.3246241
Iteration # 1789  loss is :  1.3246202
Iteration # 1790  loss is 

Iteration # 2122  loss is :  1.3236468
Iteration # 2123  loss is :  1.3236444
Iteration # 2124  loss is :  1.3236423
Iteration # 2125  loss is :  1.3236399
Iteration # 2126  loss is :  1.3236376
Iteration # 2127  loss is :  1.3236353
Iteration # 2128  loss is :  1.323633
Iteration # 2129  loss is :  1.3236307
Iteration # 2130  loss is :  1.3236284
Iteration # 2131  loss is :  1.3236262
Iteration # 2132  loss is :  1.3236239
Iteration # 2133  loss is :  1.3236216
Iteration # 2134  loss is :  1.3236194
Iteration # 2135  loss is :  1.3236171
Iteration # 2136  loss is :  1.323615
Iteration # 2137  loss is :  1.3236127
Iteration # 2138  loss is :  1.3236103
Iteration # 2139  loss is :  1.3236082
Iteration # 2140  loss is :  1.3236058
Iteration # 2141  loss is :  1.3236036
Iteration # 2142  loss is :  1.3236015
Iteration # 2143  loss is :  1.3235991
Iteration # 2144  loss is :  1.3235968
Iteration # 2145  loss is :  1.3235946
Iteration # 2146  loss is :  1.3235924
Iteration # 2147  loss is :

Iteration # 2472  loss is :  1.3229914
Iteration # 2473  loss is :  1.3229897
Iteration # 2474  loss is :  1.3229882
Iteration # 2475  loss is :  1.3229868
Iteration # 2476  loss is :  1.3229853
Iteration # 2477  loss is :  1.3229837
Iteration # 2478  loss is :  1.3229822
Iteration # 2479  loss is :  1.3229806
Iteration # 2480  loss is :  1.3229792
Iteration # 2481  loss is :  1.3229778
Iteration # 2482  loss is :  1.3229762
Iteration # 2483  loss is :  1.3229747
Iteration # 2484  loss is :  1.3229731
Iteration # 2485  loss is :  1.3229717
Iteration # 2486  loss is :  1.3229703
Iteration # 2487  loss is :  1.3229686
Iteration # 2488  loss is :  1.3229672
Iteration # 2489  loss is :  1.3229657
Iteration # 2490  loss is :  1.3229643
Iteration # 2491  loss is :  1.3229626
Iteration # 2492  loss is :  1.3229612
Iteration # 2493  loss is :  1.3229597
Iteration # 2494  loss is :  1.3229584
Iteration # 2495  loss is :  1.3229568
Iteration # 2496  loss is :  1.3229551
Iteration # 2497  loss is

Iteration # 2831  loss is :  1.322534
Iteration # 2832  loss is :  1.3225328
Iteration # 2833  loss is :  1.3225318
Iteration # 2834  loss is :  1.3225309
Iteration # 2835  loss is :  1.3225297
Iteration # 2836  loss is :  1.3225286
Iteration # 2837  loss is :  1.3225275
Iteration # 2838  loss is :  1.3225266
Iteration # 2839  loss is :  1.3225255
Iteration # 2840  loss is :  1.3225244
Iteration # 2841  loss is :  1.3225234
Iteration # 2842  loss is :  1.3225223
Iteration # 2843  loss is :  1.3225213
Iteration # 2844  loss is :  1.3225203
Iteration # 2845  loss is :  1.3225192
Iteration # 2846  loss is :  1.3225182
Iteration # 2847  loss is :  1.3225172
Iteration # 2848  loss is :  1.322516
Iteration # 2849  loss is :  1.3225149
Iteration # 2850  loss is :  1.3225139
Iteration # 2851  loss is :  1.3225129
Iteration # 2852  loss is :  1.3225119
Iteration # 2853  loss is :  1.3225107
Iteration # 2854  loss is :  1.3225098
Iteration # 2855  loss is :  1.3225087
Iteration # 2856  loss is :

Iteration # 3189  loss is :  1.3222069
Iteration # 3190  loss is :  1.322206
Iteration # 3191  loss is :  1.3222053
Iteration # 3192  loss is :  1.3222045
Iteration # 3193  loss is :  1.3222036
Iteration # 3194  loss is :  1.3222029
Iteration # 3195  loss is :  1.3222021
Iteration # 3196  loss is :  1.3222013
Iteration # 3197  loss is :  1.3222005
Iteration # 3198  loss is :  1.3221997
Iteration # 3199  loss is :  1.3221989
Iteration # 3200  loss is :  1.3221983
Iteration # 3201  loss is :  1.3221974
Iteration # 3202  loss is :  1.3221966
Iteration # 3203  loss is :  1.322196
Iteration # 3204  loss is :  1.3221952
Iteration # 3205  loss is :  1.3221945
Iteration # 3206  loss is :  1.3221936
Iteration # 3207  loss is :  1.3221928
Iteration # 3208  loss is :  1.3221922
Iteration # 3209  loss is :  1.3221914
Iteration # 3210  loss is :  1.3221906
Iteration # 3211  loss is :  1.3221897
Iteration # 3212  loss is :  1.3221889
Iteration # 3213  loss is :  1.3221883
Iteration # 3214  loss is :

Iteration # 3548  loss is :  1.3219612
Iteration # 3549  loss is :  1.3219604
Iteration # 3550  loss is :  1.3219599
Iteration # 3551  loss is :  1.3219593
Iteration # 3552  loss is :  1.3219588
Iteration # 3553  loss is :  1.3219581
Iteration # 3554  loss is :  1.3219575
Iteration # 3555  loss is :  1.321957
Iteration # 3556  loss is :  1.3219563
Iteration # 3557  loss is :  1.3219557
Iteration # 3558  loss is :  1.3219552
Iteration # 3559  loss is :  1.3219545
Iteration # 3560  loss is :  1.3219539
Iteration # 3561  loss is :  1.3219534
Iteration # 3562  loss is :  1.3219527
Iteration # 3563  loss is :  1.3219521
Iteration # 3564  loss is :  1.3219516
Iteration # 3565  loss is :  1.321951
Iteration # 3566  loss is :  1.3219504
Iteration # 3567  loss is :  1.3219498
Iteration # 3568  loss is :  1.3219491
Iteration # 3569  loss is :  1.3219486
Iteration # 3570  loss is :  1.321948
Iteration # 3571  loss is :  1.3219473
Iteration # 3572  loss is :  1.3219469
Iteration # 3573  loss is : 

Iteration # 3906  loss is :  1.3217714
Iteration # 3907  loss is :  1.3217709
Iteration # 3908  loss is :  1.3217704
Iteration # 3909  loss is :  1.32177
Iteration # 3910  loss is :  1.3217695
Iteration # 3911  loss is :  1.3217689
Iteration # 3912  loss is :  1.3217686
Iteration # 3913  loss is :  1.321768
Iteration # 3914  loss is :  1.3217676
Iteration # 3915  loss is :  1.3217671
Iteration # 3916  loss is :  1.3217667
Iteration # 3917  loss is :  1.3217661
Iteration # 3918  loss is :  1.3217658
Iteration # 3919  loss is :  1.3217652
Iteration # 3920  loss is :  1.321765
Iteration # 3921  loss is :  1.3217642
Iteration # 3922  loss is :  1.3217639
Iteration # 3923  loss is :  1.3217633
Iteration # 3924  loss is :  1.321763
Iteration # 3925  loss is :  1.3217623
Iteration # 3926  loss is :  1.321762
Iteration # 3927  loss is :  1.3217615
Iteration # 3928  loss is :  1.321761
Iteration # 3929  loss is :  1.3217607
Iteration # 3930  loss is :  1.3217602
Iteration # 3931  loss is :  1.3

Iteration # 4213  loss is :  1.3216399
Iteration # 4214  loss is :  1.3216397
Iteration # 4215  loss is :  1.3216391
Iteration # 4216  loss is :  1.321639
Iteration # 4217  loss is :  1.3216383
Iteration # 4218  loss is :  1.321638
Iteration # 4219  loss is :  1.3216376
Iteration # 4220  loss is :  1.3216373
Iteration # 4221  loss is :  1.3216368
Iteration # 4222  loss is :  1.3216364
Iteration # 4223  loss is :  1.3216362
Iteration # 4224  loss is :  1.3216356
Iteration # 4225  loss is :  1.3216354
Iteration # 4226  loss is :  1.3216349
Iteration # 4227  loss is :  1.3216345
Iteration # 4228  loss is :  1.3216342
Iteration # 4229  loss is :  1.3216337
Iteration # 4230  loss is :  1.3216333
Iteration # 4231  loss is :  1.321633
Iteration # 4232  loss is :  1.3216325
Iteration # 4233  loss is :  1.3216321
Iteration # 4234  loss is :  1.3216318
Iteration # 4235  loss is :  1.3216314
Iteration # 4236  loss is :  1.3216312
Iteration # 4237  loss is :  1.3216306
Iteration # 4238  loss is : 

Iteration # 4445  loss is :  1.3215553
Iteration # 4446  loss is :  1.3215551
Iteration # 4447  loss is :  1.3215547
Iteration # 4448  loss is :  1.3215543
Iteration # 4449  loss is :  1.321554
Iteration # 4450  loss is :  1.3215537
Iteration # 4451  loss is :  1.3215532
Iteration # 4452  loss is :  1.321553
Iteration # 4453  loss is :  1.3215528
Iteration # 4454  loss is :  1.3215522
Iteration # 4455  loss is :  1.3215518
Iteration # 4456  loss is :  1.3215516
Iteration # 4457  loss is :  1.3215513
Iteration # 4458  loss is :  1.321551
Iteration # 4459  loss is :  1.3215506
Iteration # 4460  loss is :  1.3215503
Iteration # 4461  loss is :  1.3215498
Iteration # 4462  loss is :  1.3215494
Iteration # 4463  loss is :  1.3215493
Iteration # 4464  loss is :  1.3215488
Iteration # 4465  loss is :  1.3215486
Iteration # 4466  loss is :  1.3215481
Iteration # 4467  loss is :  1.3215479
Iteration # 4468  loss is :  1.3215475
Iteration # 4469  loss is :  1.3215473
Iteration # 4470  loss is : 

Iteration # 4708  loss is :  1.3214715
Iteration # 4709  loss is :  1.3214712
Iteration # 4710  loss is :  1.321471
Iteration # 4711  loss is :  1.3214707
Iteration # 4712  loss is :  1.3214704
Iteration # 4713  loss is :  1.3214701
Iteration # 4714  loss is :  1.3214699
Iteration # 4715  loss is :  1.3214693
Iteration # 4716  loss is :  1.3214693
Iteration # 4717  loss is :  1.321469
Iteration # 4718  loss is :  1.3214686
Iteration # 4719  loss is :  1.3214684
Iteration # 4720  loss is :  1.321468
Iteration # 4721  loss is :  1.3214678
Iteration # 4722  loss is :  1.3214674
Iteration # 4723  loss is :  1.3214672
Iteration # 4724  loss is :  1.3214668
Iteration # 4725  loss is :  1.3214666
Iteration # 4726  loss is :  1.3214663
Iteration # 4727  loss is :  1.321466
Iteration # 4728  loss is :  1.3214656
Iteration # 4729  loss is :  1.3214654
Iteration # 4730  loss is :  1.3214651
Iteration # 4731  loss is :  1.3214648
Iteration # 4732  loss is :  1.3214645
Iteration # 4733  loss is :  

Iteration # 5042  loss is :  1.3213804
Iteration # 5043  loss is :  1.32138
Iteration # 5044  loss is :  1.3213798
Iteration # 5045  loss is :  1.3213797
Iteration # 5046  loss is :  1.3213794
Iteration # 5047  loss is :  1.3213791
Iteration # 5048  loss is :  1.321379
Iteration # 5049  loss is :  1.3213787
Iteration # 5050  loss is :  1.3213785
Iteration # 5051  loss is :  1.321378
Iteration # 5052  loss is :  1.3213778
Iteration # 5053  loss is :  1.3213778
Iteration # 5054  loss is :  1.3213774
Iteration # 5055  loss is :  1.3213772
Iteration # 5056  loss is :  1.3213769
Iteration # 5057  loss is :  1.3213766
Iteration # 5058  loss is :  1.3213763
Iteration # 5059  loss is :  1.3213762
Iteration # 5060  loss is :  1.3213758
Iteration # 5061  loss is :  1.3213757
Iteration # 5062  loss is :  1.3213755
Iteration # 5063  loss is :  1.3213753
Iteration # 5064  loss is :  1.3213748
Iteration # 5065  loss is :  1.3213747
Iteration # 5066  loss is :  1.3213743
Iteration # 5067  loss is :  

Iteration # 5377  loss is :  1.3213025
Iteration # 5378  loss is :  1.3213022
Iteration # 5379  loss is :  1.321302
Iteration # 5380  loss is :  1.3213018
Iteration # 5381  loss is :  1.3213017
Iteration # 5382  loss is :  1.3213015
Iteration # 5383  loss is :  1.3213012
Iteration # 5384  loss is :  1.3213012
Iteration # 5385  loss is :  1.3213007
Iteration # 5386  loss is :  1.3213006
Iteration # 5387  loss is :  1.3213004
Iteration # 5388  loss is :  1.3213
Iteration # 5389  loss is :  1.3212998
Iteration # 5390  loss is :  1.3212997
Iteration # 5391  loss is :  1.3212996
Iteration # 5392  loss is :  1.3212993
Iteration # 5393  loss is :  1.3212991
Iteration # 5394  loss is :  1.3212988
Iteration # 5395  loss is :  1.3212986
Iteration # 5396  loss is :  1.3212985
Iteration # 5397  loss is :  1.3212982
Iteration # 5398  loss is :  1.3212979
Iteration # 5399  loss is :  1.3212978
Iteration # 5400  loss is :  1.3212975
Iteration # 5401  loss is :  1.3212974
Iteration # 5402  loss is :  

Iteration # 5661  loss is :  1.3212451
Iteration # 5662  loss is :  1.3212448
Iteration # 5663  loss is :  1.3212446
Iteration # 5664  loss is :  1.3212444
Iteration # 5665  loss is :  1.3212442
Iteration # 5666  loss is :  1.3212442
Iteration # 5667  loss is :  1.3212439
Iteration # 5668  loss is :  1.3212436
Iteration # 5669  loss is :  1.3212436
Iteration # 5670  loss is :  1.3212433
Iteration # 5671  loss is :  1.3212432
Iteration # 5672  loss is :  1.321243
Iteration # 5673  loss is :  1.3212428
Iteration # 5674  loss is :  1.3212425
Iteration # 5675  loss is :  1.3212423
Iteration # 5676  loss is :  1.3212422
Iteration # 5677  loss is :  1.3212421
Iteration # 5678  loss is :  1.3212417
Iteration # 5679  loss is :  1.3212417
Iteration # 5680  loss is :  1.3212415
Iteration # 5681  loss is :  1.3212413
Iteration # 5682  loss is :  1.3212409
Iteration # 5683  loss is :  1.3212409
Iteration # 5684  loss is :  1.3212407
Iteration # 5685  loss is :  1.3212404
Iteration # 5686  loss is 

Iteration # 5933  loss is :  1.3211962
Iteration # 5934  loss is :  1.321196
Iteration # 5935  loss is :  1.3211958
Iteration # 5936  loss is :  1.3211957
Iteration # 5937  loss is :  1.3211954
Iteration # 5938  loss is :  1.3211954
Iteration # 5939  loss is :  1.3211951
Iteration # 5940  loss is :  1.3211948
Iteration # 5941  loss is :  1.3211946
Iteration # 5942  loss is :  1.3211946
Iteration # 5943  loss is :  1.3211944
Iteration # 5944  loss is :  1.3211942
Iteration # 5945  loss is :  1.321194
Iteration # 5946  loss is :  1.3211938
Iteration # 5947  loss is :  1.3211938
Iteration # 5948  loss is :  1.3211936
Iteration # 5949  loss is :  1.3211933
Iteration # 5950  loss is :  1.3211932
Iteration # 5951  loss is :  1.321193
Iteration # 5952  loss is :  1.3211929
Iteration # 5953  loss is :  1.3211927
Iteration # 5954  loss is :  1.3211926
Iteration # 5955  loss is :  1.3211925
Iteration # 5956  loss is :  1.3211924
Iteration # 5957  loss is :  1.321192
Iteration # 5958  loss is :  

Iteration # 6278  loss is :  1.3211412
Iteration # 6279  loss is :  1.3211412
Iteration # 6280  loss is :  1.321141
Iteration # 6281  loss is :  1.3211408
Iteration # 6282  loss is :  1.3211408
Iteration # 6283  loss is :  1.3211405
Iteration # 6284  loss is :  1.3211404
Iteration # 6285  loss is :  1.3211402
Iteration # 6286  loss is :  1.32114
Iteration # 6287  loss is :  1.3211399
Iteration # 6288  loss is :  1.3211398
Iteration # 6289  loss is :  1.3211396
Iteration # 6290  loss is :  1.3211395
Iteration # 6291  loss is :  1.3211393
Iteration # 6292  loss is :  1.3211392
Iteration # 6293  loss is :  1.321139
Iteration # 6294  loss is :  1.3211389
Iteration # 6295  loss is :  1.3211387
Iteration # 6296  loss is :  1.3211386
Iteration # 6297  loss is :  1.3211386
Iteration # 6298  loss is :  1.3211383
Iteration # 6299  loss is :  1.3211381
Iteration # 6300  loss is :  1.321138
Iteration # 6301  loss is :  1.3211379
Iteration # 6302  loss is :  1.3211378
Iteration # 6303  loss is :  1

Iteration # 6628  loss is :  1.3210925
Iteration # 6629  loss is :  1.3210922
Iteration # 6630  loss is :  1.3210922
Iteration # 6631  loss is :  1.321092
Iteration # 6632  loss is :  1.3210919
Iteration # 6633  loss is :  1.3210918
Iteration # 6634  loss is :  1.3210917
Iteration # 6635  loss is :  1.3210917
Iteration # 6636  loss is :  1.3210914
Iteration # 6637  loss is :  1.3210914
Iteration # 6638  loss is :  1.3210912
Iteration # 6639  loss is :  1.321091
Iteration # 6640  loss is :  1.3210908
Iteration # 6641  loss is :  1.3210908
Iteration # 6642  loss is :  1.3210906
Iteration # 6643  loss is :  1.3210906
Iteration # 6644  loss is :  1.3210903
Iteration # 6645  loss is :  1.3210902
Iteration # 6646  loss is :  1.3210901
Iteration # 6647  loss is :  1.32109
Iteration # 6648  loss is :  1.3210897
Iteration # 6649  loss is :  1.3210897
Iteration # 6650  loss is :  1.3210897
Iteration # 6651  loss is :  1.3210895
Iteration # 6652  loss is :  1.3210894
Iteration # 6653  loss is :  

Iteration # 6921  loss is :  1.321056
Iteration # 6922  loss is :  1.3210559
Iteration # 6923  loss is :  1.3210558
Iteration # 6924  loss is :  1.3210557
Iteration # 6925  loss is :  1.3210555
Iteration # 6926  loss is :  1.3210554
Iteration # 6927  loss is :  1.3210554
Iteration # 6928  loss is :  1.3210553
Iteration # 6929  loss is :  1.3210552
Iteration # 6930  loss is :  1.321055
Iteration # 6931  loss is :  1.3210549
Iteration # 6932  loss is :  1.3210549
Iteration # 6933  loss is :  1.3210547
Iteration # 6934  loss is :  1.3210546
Iteration # 6935  loss is :  1.3210543
Iteration # 6936  loss is :  1.3210542
Iteration # 6937  loss is :  1.3210541
Iteration # 6938  loss is :  1.3210541
Iteration # 6939  loss is :  1.3210541
Iteration # 6940  loss is :  1.3210537
Iteration # 6941  loss is :  1.3210536
Iteration # 6942  loss is :  1.3210535
Iteration # 6943  loss is :  1.3210535
Iteration # 6944  loss is :  1.3210534
Iteration # 6945  loss is :  1.3210534
Iteration # 6946  loss is :

Iteration # 7274  loss is :  1.3210168
Iteration # 7275  loss is :  1.3210167
Iteration # 7276  loss is :  1.3210166
Iteration # 7277  loss is :  1.3210166
Iteration # 7278  loss is :  1.3210166
Iteration # 7279  loss is :  1.3210163
Iteration # 7280  loss is :  1.3210162
Iteration # 7281  loss is :  1.3210161
Iteration # 7282  loss is :  1.3210158
Iteration # 7283  loss is :  1.3210158
Iteration # 7284  loss is :  1.3210157
Iteration # 7285  loss is :  1.3210156
Iteration # 7286  loss is :  1.3210154
Iteration # 7287  loss is :  1.3210154
Iteration # 7288  loss is :  1.3210152
Iteration # 7289  loss is :  1.3210152
Iteration # 7290  loss is :  1.3210152
Iteration # 7291  loss is :  1.3210152
Iteration # 7292  loss is :  1.321015
Iteration # 7293  loss is :  1.3210146
Iteration # 7294  loss is :  1.3210146
Iteration # 7295  loss is :  1.3210146
Iteration # 7296  loss is :  1.3210144
Iteration # 7297  loss is :  1.3210144
Iteration # 7298  loss is :  1.3210144
Iteration # 7299  loss is 

Iteration # 7615  loss is :  1.3209829
Iteration # 7616  loss is :  1.3209828
Iteration # 7617  loss is :  1.3209827
Iteration # 7618  loss is :  1.3209825
Iteration # 7619  loss is :  1.3209825
Iteration # 7620  loss is :  1.3209825
Iteration # 7621  loss is :  1.3209822
Iteration # 7622  loss is :  1.3209822
Iteration # 7623  loss is :  1.3209821
Iteration # 7624  loss is :  1.3209819
Iteration # 7625  loss is :  1.3209817
Iteration # 7626  loss is :  1.3209817
Iteration # 7627  loss is :  1.3209816
Iteration # 7628  loss is :  1.3209816
Iteration # 7629  loss is :  1.3209816
Iteration # 7630  loss is :  1.3209814
Iteration # 7631  loss is :  1.3209813
Iteration # 7632  loss is :  1.3209813
Iteration # 7633  loss is :  1.320981
Iteration # 7634  loss is :  1.320981
Iteration # 7635  loss is :  1.320981
Iteration # 7636  loss is :  1.3209809
Iteration # 7637  loss is :  1.3209808
Iteration # 7638  loss is :  1.3209805
Iteration # 7639  loss is :  1.3209805
Iteration # 7640  loss is : 

Iteration # 7960  loss is :  1.3209518
Iteration # 7961  loss is :  1.3209518
Iteration # 7962  loss is :  1.3209517
Iteration # 7963  loss is :  1.3209517
Iteration # 7964  loss is :  1.3209516
Iteration # 7965  loss is :  1.3209516
Iteration # 7966  loss is :  1.3209513
Iteration # 7967  loss is :  1.3209512
Iteration # 7968  loss is :  1.3209512
Iteration # 7969  loss is :  1.320951
Iteration # 7970  loss is :  1.320951
Iteration # 7971  loss is :  1.320951
Iteration # 7972  loss is :  1.3209509
Iteration # 7973  loss is :  1.3209507
Iteration # 7974  loss is :  1.3209507
Iteration # 7975  loss is :  1.3209505
Iteration # 7976  loss is :  1.3209505
Iteration # 7977  loss is :  1.3209505
Iteration # 7978  loss is :  1.3209504
Iteration # 7979  loss is :  1.3209503
Iteration # 7980  loss is :  1.3209503
Iteration # 7981  loss is :  1.3209502
Iteration # 7982  loss is :  1.32095
Iteration # 7983  loss is :  1.3209499
Iteration # 7984  loss is :  1.3209499
Iteration # 7985  loss is :  1

Iteration # 8300  loss is :  1.3209244
Iteration # 8301  loss is :  1.3209242
Iteration # 8302  loss is :  1.3209242
Iteration # 8303  loss is :  1.3209242
Iteration # 8304  loss is :  1.3209239
Iteration # 8305  loss is :  1.3209238
Iteration # 8306  loss is :  1.3209238
Iteration # 8307  loss is :  1.3209238
Iteration # 8308  loss is :  1.3209238
Iteration # 8309  loss is :  1.3209238
Iteration # 8310  loss is :  1.3209237
Iteration # 8311  loss is :  1.3209234
Iteration # 8312  loss is :  1.3209233
Iteration # 8313  loss is :  1.3209233
Iteration # 8314  loss is :  1.3209232
Iteration # 8315  loss is :  1.3209231
Iteration # 8316  loss is :  1.3209231
Iteration # 8317  loss is :  1.320923
Iteration # 8318  loss is :  1.320923
Iteration # 8319  loss is :  1.320923
Iteration # 8320  loss is :  1.3209229
Iteration # 8321  loss is :  1.3209229
Iteration # 8322  loss is :  1.3209226
Iteration # 8323  loss is :  1.3209226
Iteration # 8324  loss is :  1.3209223
Iteration # 8325  loss is : 

Iteration # 8644  loss is :  1.320899
Iteration # 8645  loss is :  1.320899
Iteration # 8646  loss is :  1.3208989
Iteration # 8647  loss is :  1.3208988
Iteration # 8648  loss is :  1.3208988
Iteration # 8649  loss is :  1.3208987
Iteration # 8650  loss is :  1.3208987
Iteration # 8651  loss is :  1.3208985
Iteration # 8652  loss is :  1.3208984
Iteration # 8653  loss is :  1.3208984
Iteration # 8654  loss is :  1.3208984
Iteration # 8655  loss is :  1.3208982
Iteration # 8656  loss is :  1.320898
Iteration # 8657  loss is :  1.320898
Iteration # 8658  loss is :  1.320898
Iteration # 8659  loss is :  1.3208979
Iteration # 8660  loss is :  1.3208978
Iteration # 8661  loss is :  1.3208978
Iteration # 8662  loss is :  1.3208977
Iteration # 8663  loss is :  1.3208976
Iteration # 8664  loss is :  1.3208976
Iteration # 8665  loss is :  1.3208975
Iteration # 8666  loss is :  1.3208975
Iteration # 8667  loss is :  1.3208972
Iteration # 8668  loss is :  1.3208972
Iteration # 8669  loss is :  1

Iteration # 8993  loss is :  1.3208755
Iteration # 8994  loss is :  1.3208755
Iteration # 8995  loss is :  1.3208755
Iteration # 8996  loss is :  1.3208755
Iteration # 8997  loss is :  1.3208753
Iteration # 8998  loss is :  1.3208753
Iteration # 8999  loss is :  1.3208753
Iteration # 9000  loss is :  1.3208752
Iteration # 9001  loss is :  1.320875
Iteration # 9002  loss is :  1.3208749
Iteration # 9003  loss is :  1.3208749
Iteration # 9004  loss is :  1.3208748
Iteration # 9005  loss is :  1.3208747
Iteration # 9006  loss is :  1.3208747
Iteration # 9007  loss is :  1.3208747
Iteration # 9008  loss is :  1.3208746
Iteration # 9009  loss is :  1.3208746
Iteration # 9010  loss is :  1.3208745
Iteration # 9011  loss is :  1.3208743
Iteration # 9012  loss is :  1.3208742
Iteration # 9013  loss is :  1.3208742
Iteration # 9014  loss is :  1.3208742
Iteration # 9015  loss is :  1.3208741
Iteration # 9016  loss is :  1.3208742
Iteration # 9017  loss is :  1.320874
Iteration # 9018  loss is :

Iteration # 9346  loss is :  1.320854
Iteration # 9347  loss is :  1.3208538
Iteration # 9348  loss is :  1.3208537
Iteration # 9349  loss is :  1.3208538
Iteration # 9350  loss is :  1.3208537
Iteration # 9351  loss is :  1.3208536
Iteration # 9352  loss is :  1.3208536
Iteration # 9353  loss is :  1.3208535
Iteration # 9354  loss is :  1.3208535
Iteration # 9355  loss is :  1.3208534
Iteration # 9356  loss is :  1.3208532
Iteration # 9357  loss is :  1.3208532
Iteration # 9358  loss is :  1.3208532
Iteration # 9359  loss is :  1.3208532
Iteration # 9360  loss is :  1.3208531
Iteration # 9361  loss is :  1.3208531
Iteration # 9362  loss is :  1.320853
Iteration # 9363  loss is :  1.320853
Iteration # 9364  loss is :  1.320853
Iteration # 9365  loss is :  1.3208529
Iteration # 9366  loss is :  1.3208529
Iteration # 9367  loss is :  1.3208526
Iteration # 9368  loss is :  1.3208526
Iteration # 9369  loss is :  1.3208525
Iteration # 9370  loss is :  1.3208525
Iteration # 9371  loss is :  

Iteration # 9629  loss is :  1.320838
Iteration # 9630  loss is :  1.320838
Iteration # 9631  loss is :  1.3208379
Iteration # 9632  loss is :  1.3208377
Iteration # 9633  loss is :  1.3208376
Iteration # 9634  loss is :  1.3208376
Iteration # 9635  loss is :  1.3208376
Iteration # 9636  loss is :  1.3208376
Iteration # 9637  loss is :  1.3208375
Iteration # 9638  loss is :  1.3208374
Iteration # 9639  loss is :  1.3208374
Iteration # 9640  loss is :  1.3208374
Iteration # 9641  loss is :  1.3208374
Iteration # 9642  loss is :  1.3208373
Iteration # 9643  loss is :  1.320837
Iteration # 9644  loss is :  1.320837
Iteration # 9645  loss is :  1.320837
Iteration # 9646  loss is :  1.3208369
Iteration # 9647  loss is :  1.3208368
Iteration # 9648  loss is :  1.3208369
Iteration # 9649  loss is :  1.3208367
Iteration # 9650  loss is :  1.3208368
Iteration # 9651  loss is :  1.3208365
Iteration # 9652  loss is :  1.3208367
Iteration # 9653  loss is :  1.3208365
Iteration # 9654  loss is :  1

Iteration # 9975  loss is :  1.3208196
Iteration # 9976  loss is :  1.3208196
Iteration # 9977  loss is :  1.3208196
Iteration # 9978  loss is :  1.3208196
Iteration # 9979  loss is :  1.3208195
Iteration # 9980  loss is :  1.3208194
Iteration # 9981  loss is :  1.3208193
Iteration # 9982  loss is :  1.3208193
Iteration # 9983  loss is :  1.3208193
Iteration # 9984  loss is :  1.3208191
Iteration # 9985  loss is :  1.3208191
Iteration # 9986  loss is :  1.320819
Iteration # 9987  loss is :  1.320819
Iteration # 9988  loss is :  1.320819
Iteration # 9989  loss is :  1.320819
Iteration # 9990  loss is :  1.3208189
Iteration # 9991  loss is :  1.320819
Iteration # 9992  loss is :  1.3208188
Iteration # 9993  loss is :  1.3208188
Iteration # 9994  loss is :  1.3208188
Iteration # 9995  loss is :  1.3208188
Iteration # 9996  loss is :  1.3208187
Iteration # 9997  loss is :  1.3208185
Iteration # 9998  loss is :  1.3208185
Iteration # 9999  loss is :  1.3208185


In [36]:
print(sess.run(W1))
print(sess.run(b1))

[[ 0.424688   -1.9252938  -0.6388043   0.10906665  0.25840887]
 [-1.9066083  -0.10446653 -0.81790906  2.3199577  -0.54783213]
 [-1.0711946  -0.53016156  1.6015096  -2.514792    1.2271364 ]
 [-0.16368645 -0.42250228  2.1198182   0.35810795 -1.1576419 ]
 [ 0.6936929  -1.723447   -0.33726034  0.49052453  0.11122748]
 [ 1.9086014  -0.91337377  1.142567    1.4991777   0.5279376 ]
 [ 1.0896218   2.026247   -0.2336624  -0.6645414   1.7162611 ]]
[ 0.23853897 -0.89132506 -0.04109282  0.23112632  1.668236  ]


In [37]:
vectors = sess.run(W1 + b1)
print(vectors)

[[ 0.66322696 -2.816619   -0.6798971   0.34019297  1.9266449 ]
 [-1.6680694  -0.99579155 -0.8590019   2.551084    1.1204039 ]
 [-0.83265567 -1.4214866   1.5604167  -2.2836657   2.8953724 ]
 [ 0.07485251 -1.3138273   2.0787253   0.5892343   0.5105941 ]
 [ 0.9322319  -2.614772   -0.37835315  0.72165084  1.7794635 ]
 [ 2.1471403  -1.8046988   1.1014742   1.730304    2.1961737 ]
 [ 1.3281608   1.134922   -0.2747552  -0.4334151   3.3844972 ]]


In [38]:
print(vectors[ word2int['queen'] ])

[ 0.07485251 -1.3138273   2.0787253   0.5892343   0.5105941 ]


In [39]:
def euclidean_dist(vec1, vec2):
    return np.sqrt(np.sum((vec1-vec2)**2))

def find_closest(word_index, vectors):
    min_dist = 10000 # to act like positive infinity
    min_index = -1
    query_vector = vectors[word_index]
    for index, vector in enumerate(vectors):
        if euclidean_dist(vector, query_vector) < min_dist and not np.array_equal(vector, query_vector):
            min_dist = euclidean_dist(vector, query_vector)
            min_index = index
    return min_index

In [40]:
print(int2word[find_closest(word2int['queen'], vectors)])

king
