## nb.
### 1. This does not implement negative sampling or hierarchical softmax, this is only intended as a simple illustration.
### 2. This implementation uses Skip-gram

In [1]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
import numpy as np

## Raw Text for training word representation (Word2vec)

In [3]:
corpus_raw = 'He is the king . The king is royal . She is the royal  queen '
corpus_raw = corpus_raw.lower()

## Corpus derived from raw text

In [4]:
words = []
for word in corpus_raw.split():
    if word != '.':
        words.append(word)
        
words = set(words)

word2int = {}
int2word = {}

vocab_size = len(words)

#enumerate adds counter, i, to iterable
for i,word in enumerate(words):
    word2int[word] = i
    int2word[i] = word


In [5]:
word2int,int2word

({'he': 2, 'is': 6, 'king': 4, 'queen': 3, 'royal': 1, 'she': 0, 'the': 5},
 {0: 'she', 1: 'royal', 2: 'he', 3: 'queen', 4: 'king', 5: 'the', 6: 'is'})

In [6]:
len(words)

7

In [7]:
int2word[word2int['queen']]

'queen'

In [8]:
raw_sentences = corpus_raw.split('.')
sentences = [sentence for sentence in raw_sentences]
sentences

['he is the king ', ' the king is royal ', ' she is the royal  queen ']

## Generating n-grams of size 2 (bi-grams)

In [9]:
data = []
WINDOW_SIZE = 2
for sentence in sentences:
    sentence = sentence.split()
    sentence_length = len(sentence)
    for word_index,word in enumerate(sentence):
        #getting n-grams of sizes WINDOW SIZE; nb stands for neighbour
        for nb_word in sentence[max(word_index - WINDOW_SIZE,0):1+min(word_index + WINDOW_SIZE,sentence_length)]:
            if nb_word != word:
                data.append([word,nb_word])

In [10]:
data

[['he', 'is'],
 ['he', 'the'],
 ['is', 'he'],
 ['is', 'the'],
 ['is', 'king'],
 ['the', 'he'],
 ['the', 'is'],
 ['the', 'king'],
 ['king', 'is'],
 ['king', 'the'],
 ['the', 'king'],
 ['the', 'is'],
 ['king', 'the'],
 ['king', 'is'],
 ['king', 'royal'],
 ['is', 'the'],
 ['is', 'king'],
 ['is', 'royal'],
 ['royal', 'king'],
 ['royal', 'is'],
 ['she', 'is'],
 ['she', 'the'],
 ['is', 'she'],
 ['is', 'the'],
 ['is', 'royal'],
 ['the', 'she'],
 ['the', 'is'],
 ['the', 'royal'],
 ['the', 'queen'],
 ['royal', 'is'],
 ['royal', 'the'],
 ['royal', 'queen'],
 ['queen', 'the'],
 ['queen', 'royal']]

## Transforming to one-hot encoding

In [11]:
def to_one_hot(data_point_index,vocab_size):
    temp = np.zeros(vocab_size)
    temp[data_point_index] = 1
    return temp

In [12]:
x_train = []
y_train = []

for data_word in data:
    x_train.append(to_one_hot(word2int[data_word[0]],vocab_size))
    y_train.append(to_one_hot(word2int[data_word[1]],vocab_size))
    
x_train = np.asarray(x_train)
y_train = np.asarray(y_train)

In [13]:
x_train[:5]

array([[0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1.]])

In [14]:
y_train[:5]

array([[0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0.]])

In [15]:
#34 training examples, 7 vocabularies
x_train.shape,y_train.shape

((34, 7), (34, 7))

## Input layer and Expected Output

In [16]:
x = tf.placeholder(tf.float32,shape=(None,vocab_size))
y_label = tf.placeholder(tf.float32,shape=(None,vocab_size))

## Hidden Layer (Training 5 dimension word vectors)

In [17]:
dimension = 5
W1 = tf.Variable(tf.random_normal([vocab_size,dimension]))
b1 = tf.Variable(tf.random_normal([dimension]))
hidden_representation = tf.add(tf.matmul(x,W1),b1)

## Output Layer

In [18]:
W2 = tf.Variable(tf.random_normal([dimension,vocab_size]))
b2 = tf.Variable(tf.random_normal([vocab_size]))
prediction = tf.nn.softmax(tf.add(tf.matmul(hidden_representation,W2),b2))

## Instantiation of Session; Initialization of Placeholders, Variables and Constants

In [51]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

## Loss Function

In [52]:
cross_entropy_loss = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(prediction), reduction_indices=[1]))

## Gradient Descent

In [53]:
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy_loss)

In [54]:
n_epoch = 5000

In [55]:
for _ in range(n_epoch):
    sess.run(train_step,feed_dict={x:x_train,y_label:y_train})
    print('loss is : ', sess.run(cross_entropy_loss, feed_dict={x: x_train, y_label: y_train}))

loss is :  4.18109
loss is :  3.9503856
loss is :  3.747363
loss is :  3.5671198
loss is :  3.406732
loss is :  3.2642403
loss is :  3.1380417
loss is :  3.0265267
loss is :  2.9279294
loss is :  2.8403602
loss is :  2.7619572
loss is :  2.6910503
loss is :  2.6262553
loss is :  2.5664961
loss is :  2.5109677
loss is :  2.459076
loss is :  2.4103804
loss is :  2.364547
loss is :  2.32131
loss is :  2.2804508
loss is :  2.2417824
loss is :  2.2051413
loss is :  2.1703804
loss is :  2.137367
loss is :  2.1059825
loss is :  2.0761194
loss is :  2.0476823
loss is :  2.0205858
loss is :  1.9947553
loss is :  1.9701251
loss is :  1.946638
loss is :  1.9242444
loss is :  1.9029009
loss is :  1.8825701
loss is :  1.8632191
loss is :  1.8448179
loss is :  1.8273398
loss is :  1.8107595
loss is :  1.7950528
loss is :  1.7801951
loss is :  1.7661619
loss is :  1.7529273
loss is :  1.740464
loss is :  1.7287433
loss is :  1.7177343
loss is :  1.7074057
loss is :  1.697723
loss is :  1.6886524
loss

loss is :  1.3545225
loss is :  1.3543993
loss is :  1.3542768
loss is :  1.3541547
loss is :  1.3540335
loss is :  1.3539126
loss is :  1.3537925
loss is :  1.353673
loss is :  1.3535539
loss is :  1.3534354
loss is :  1.3533176
loss is :  1.3532004
loss is :  1.3530836
loss is :  1.3529674
loss is :  1.3528521
loss is :  1.3527368
loss is :  1.3526226
loss is :  1.3525088
loss is :  1.3523952
loss is :  1.3522826
loss is :  1.3521703
loss is :  1.3520588
loss is :  1.3519477
loss is :  1.351837
loss is :  1.351727
loss is :  1.3516175
loss is :  1.3515085
loss is :  1.3514
loss is :  1.3512921
loss is :  1.3511847
loss is :  1.3510778
loss is :  1.3509716
loss is :  1.3508655
loss is :  1.3507605
loss is :  1.3506556
loss is :  1.3505511
loss is :  1.3504474
loss is :  1.350344
loss is :  1.3502412
loss is :  1.3501389
loss is :  1.3500372
loss is :  1.3499358
loss is :  1.3498349
loss is :  1.3497344
loss is :  1.3496348
loss is :  1.3495353
loss is :  1.3494364
loss is :  1.3493379

loss is :  1.3318408
loss is :  1.3318186
loss is :  1.3317965
loss is :  1.3317747
loss is :  1.3317527
loss is :  1.3317308
loss is :  1.331709
loss is :  1.3316873
loss is :  1.3316658
loss is :  1.3316442
loss is :  1.3316226
loss is :  1.3316011
loss is :  1.3315799
loss is :  1.3315587
loss is :  1.3315375
loss is :  1.3315163
loss is :  1.3314954
loss is :  1.3314742
loss is :  1.3314533
loss is :  1.3314325
loss is :  1.3314116
loss is :  1.3313911
loss is :  1.3313704
loss is :  1.3313496
loss is :  1.3313291
loss is :  1.3313086
loss is :  1.3312883
loss is :  1.3312681
loss is :  1.3312478
loss is :  1.3312277
loss is :  1.3312075
loss is :  1.3311875
loss is :  1.3311676
loss is :  1.3311474
loss is :  1.3311276
loss is :  1.3311079
loss is :  1.3310881
loss is :  1.3310685
loss is :  1.3310488
loss is :  1.3310292
loss is :  1.3310099
loss is :  1.3309903
loss is :  1.330971
loss is :  1.3309517
loss is :  1.3309325
loss is :  1.3309132
loss is :  1.3308941
loss is :  1.33

loss is :  1.3265197
loss is :  1.3265123
loss is :  1.326505
loss is :  1.3264974
loss is :  1.3264904
loss is :  1.3264831
loss is :  1.3264759
loss is :  1.3264685
loss is :  1.3264613
loss is :  1.3264542
loss is :  1.3264471
loss is :  1.3264399
loss is :  1.3264327
loss is :  1.3264256
loss is :  1.3264184
loss is :  1.3264112
loss is :  1.3264043
loss is :  1.3263972
loss is :  1.3263901
loss is :  1.326383
loss is :  1.326376
loss is :  1.3263689
loss is :  1.3263618
loss is :  1.3263549
loss is :  1.3263482
loss is :  1.3263409
loss is :  1.326334
loss is :  1.3263272
loss is :  1.3263202
loss is :  1.3263133
loss is :  1.3263063
loss is :  1.3262995
loss is :  1.3262926
loss is :  1.3262858
loss is :  1.326279
loss is :  1.326272
loss is :  1.3262655
loss is :  1.3262587
loss is :  1.3262519
loss is :  1.3262451
loss is :  1.3262384
loss is :  1.3262316
loss is :  1.3262249
loss is :  1.3262181
loss is :  1.3262115
loss is :  1.3262049
loss is :  1.3261981
loss is :  1.326191

loss is :  1.3244503
loss is :  1.3244467
loss is :  1.3244431
loss is :  1.3244398
loss is :  1.3244363
loss is :  1.324433
loss is :  1.3244295
loss is :  1.324426
loss is :  1.3244226
loss is :  1.3244191
loss is :  1.3244158
loss is :  1.3244123
loss is :  1.3244089
loss is :  1.3244057
loss is :  1.3244021
loss is :  1.3243989
loss is :  1.3243953
loss is :  1.324392
loss is :  1.3243887
loss is :  1.3243853
loss is :  1.3243818
loss is :  1.3243785
loss is :  1.3243753
loss is :  1.3243719
loss is :  1.3243685
loss is :  1.3243651
loss is :  1.3243618
loss is :  1.3243585
loss is :  1.3243552
loss is :  1.3243518
loss is :  1.3243486
loss is :  1.3243452
loss is :  1.324342
loss is :  1.3243387
loss is :  1.3243353
loss is :  1.3243322
loss is :  1.3243288
loss is :  1.3243257
loss is :  1.3243223
loss is :  1.3243189
loss is :  1.3243158
loss is :  1.3243124
loss is :  1.3243093
loss is :  1.3243059
loss is :  1.3243027
loss is :  1.3242995
loss is :  1.3242962
loss is :  1.3242

loss is :  1.3233731
loss is :  1.323371
loss is :  1.3233693
loss is :  1.3233671
loss is :  1.3233652
loss is :  1.3233633
loss is :  1.3233614
loss is :  1.3233595
loss is :  1.3233576
loss is :  1.3233557
loss is :  1.3233535
loss is :  1.3233516
loss is :  1.3233497
loss is :  1.3233478
loss is :  1.323346
loss is :  1.323344
loss is :  1.3233421
loss is :  1.32334
loss is :  1.3233382
loss is :  1.3233362
loss is :  1.3233343
loss is :  1.3233324
loss is :  1.3233304
loss is :  1.3233285
loss is :  1.3233266
loss is :  1.3233248
loss is :  1.3233228
loss is :  1.323321
loss is :  1.3233191
loss is :  1.3233172
loss is :  1.3233153
loss is :  1.3233135
loss is :  1.3233114
loss is :  1.3233097
loss is :  1.3233078
loss is :  1.3233057
loss is :  1.3233039
loss is :  1.3233021
loss is :  1.3233001
loss is :  1.3232983
loss is :  1.3232964
loss is :  1.3232945
loss is :  1.3232927
loss is :  1.3232908
loss is :  1.3232889
loss is :  1.3232871
loss is :  1.3232851
loss is :  1.323283

loss is :  1.3227322
loss is :  1.3227309
loss is :  1.3227297
loss is :  1.3227284
loss is :  1.3227272
loss is :  1.322726
loss is :  1.3227246
loss is :  1.3227234
loss is :  1.3227223
loss is :  1.322721
loss is :  1.3227198
loss is :  1.3227184
loss is :  1.3227172
loss is :  1.322716
loss is :  1.3227148
loss is :  1.3227135
loss is :  1.3227124
loss is :  1.322711
loss is :  1.3227098
loss is :  1.3227087
loss is :  1.3227073
loss is :  1.3227061
loss is :  1.3227049
loss is :  1.3227037
loss is :  1.3227025
loss is :  1.3227012
loss is :  1.3227
loss is :  1.3226988
loss is :  1.3226976
loss is :  1.3226963
loss is :  1.3226951
loss is :  1.3226938
loss is :  1.3226926
loss is :  1.3226916
loss is :  1.3226902
loss is :  1.322689
loss is :  1.322688
loss is :  1.3226866
loss is :  1.3226855
loss is :  1.3226843
loss is :  1.3226831
loss is :  1.3226818
loss is :  1.3226806
loss is :  1.3226794
loss is :  1.3226782
loss is :  1.322677
loss is :  1.3226758
loss is :  1.3226745
lo

loss is :  1.3223145
loss is :  1.3223138
loss is :  1.3223128
loss is :  1.322312
loss is :  1.3223112
loss is :  1.3223103
loss is :  1.3223093
loss is :  1.3223085
loss is :  1.3223076
loss is :  1.3223068
loss is :  1.322306
loss is :  1.322305
loss is :  1.3223041
loss is :  1.3223033
loss is :  1.3223025
loss is :  1.3223017
loss is :  1.3223008
loss is :  1.3223
loss is :  1.322299
loss is :  1.3222982
loss is :  1.3222973
loss is :  1.3222966
loss is :  1.3222957
loss is :  1.3222948
loss is :  1.3222939
loss is :  1.3222932
loss is :  1.3222922
loss is :  1.3222914
loss is :  1.3222904
loss is :  1.3222896
loss is :  1.3222889
loss is :  1.322288
loss is :  1.3222872
loss is :  1.3222862
loss is :  1.3222854
loss is :  1.3222847
loss is :  1.3222839
loss is :  1.322283
loss is :  1.3222821
loss is :  1.3222812
loss is :  1.3222803
loss is :  1.3222797
loss is :  1.3222787
loss is :  1.322278
loss is :  1.3222771
loss is :  1.3222761
loss is :  1.3222754
loss is :  1.3222744
lo

loss is :  1.3220087
loss is :  1.3220081
loss is :  1.3220074
loss is :  1.322007
loss is :  1.3220063
loss is :  1.3220056
loss is :  1.322005
loss is :  1.3220044
loss is :  1.3220036
loss is :  1.3220031
loss is :  1.3220025
loss is :  1.3220018
loss is :  1.3220013
loss is :  1.3220006
loss is :  1.3220001
loss is :  1.3219994
loss is :  1.321999
loss is :  1.3219982
loss is :  1.3219975
loss is :  1.3219969
loss is :  1.3219963
loss is :  1.3219957
loss is :  1.321995
loss is :  1.3219944
loss is :  1.3219938
loss is :  1.3219932
loss is :  1.3219926
loss is :  1.3219919
loss is :  1.3219914
loss is :  1.3219907
loss is :  1.3219901
loss is :  1.3219895
loss is :  1.321989
loss is :  1.3219882
loss is :  1.3219879
loss is :  1.321987
loss is :  1.3219864
loss is :  1.3219858
loss is :  1.3219852
loss is :  1.3219845
loss is :  1.321984
loss is :  1.3219835
loss is :  1.3219829
loss is :  1.3219824
loss is :  1.3219815
loss is :  1.3219808
loss is :  1.3219805
loss is :  1.3219798

loss is :  1.3217884
loss is :  1.3217878
loss is :  1.3217875
loss is :  1.321787
loss is :  1.3217863
loss is :  1.3217859
loss is :  1.3217856
loss is :  1.3217851
loss is :  1.3217846
loss is :  1.3217841
loss is :  1.3217837
loss is :  1.3217833
loss is :  1.3217827
loss is :  1.3217821
loss is :  1.3217819
loss is :  1.3217813
loss is :  1.3217808
loss is :  1.3217802
loss is :  1.3217798
loss is :  1.3217794
loss is :  1.3217789
loss is :  1.3217783
loss is :  1.3217779
loss is :  1.3217776
loss is :  1.3217771
loss is :  1.3217766
loss is :  1.3217762
loss is :  1.3217757
loss is :  1.3217752
loss is :  1.3217746
loss is :  1.3217744
loss is :  1.3217738
loss is :  1.3217732
loss is :  1.3217728
loss is :  1.3217723
loss is :  1.321772
loss is :  1.3217714
loss is :  1.3217709
loss is :  1.3217705
loss is :  1.3217701
loss is :  1.3217695
loss is :  1.3217691
loss is :  1.3217688
loss is :  1.3217683
loss is :  1.3217677
loss is :  1.3217671
loss is :  1.3217669
loss is :  1.32

loss is :  1.3216133
loss is :  1.321613
loss is :  1.3216125
loss is :  1.3216122
loss is :  1.321612
loss is :  1.3216116
loss is :  1.321611
loss is :  1.3216107
loss is :  1.3216105
loss is :  1.3216101
loss is :  1.3216096
loss is :  1.3216094
loss is :  1.321609
loss is :  1.3216085
loss is :  1.3216082
loss is :  1.3216078
loss is :  1.3216075
loss is :  1.321607
loss is :  1.3216066
loss is :  1.3216064
loss is :  1.3216059
loss is :  1.3216056
loss is :  1.3216053
loss is :  1.321605
loss is :  1.3216045
loss is :  1.3216041
loss is :  1.3216039
loss is :  1.3216035
loss is :  1.3216031
loss is :  1.3216027
loss is :  1.3216025
loss is :  1.321602
loss is :  1.3216015
loss is :  1.3216013
loss is :  1.3216009
loss is :  1.3216004
loss is :  1.3216001
loss is :  1.3215998
loss is :  1.3215995
loss is :  1.321599
loss is :  1.3215988
loss is :  1.3215984
loss is :  1.3215978
loss is :  1.3215976
loss is :  1.3215972
loss is :  1.321597
loss is :  1.3215966
loss is :  1.3215963
l

loss is :  1.3214735
loss is :  1.3214731
loss is :  1.3214731
loss is :  1.3214726
loss is :  1.3214724
loss is :  1.3214722
loss is :  1.3214717
loss is :  1.3214715
loss is :  1.3214712
loss is :  1.321471
loss is :  1.3214707
loss is :  1.3214704
loss is :  1.32147
loss is :  1.3214698
loss is :  1.3214694
loss is :  1.3214692
loss is :  1.321469
loss is :  1.3214686
loss is :  1.3214684
loss is :  1.321468
loss is :  1.3214676
loss is :  1.3214674
loss is :  1.3214672
loss is :  1.3214669
loss is :  1.3214666
loss is :  1.3214663
loss is :  1.321466
loss is :  1.3214657
loss is :  1.3214655
loss is :  1.3214651
loss is :  1.3214649
loss is :  1.3214645
loss is :  1.3214643
loss is :  1.321464
loss is :  1.3214637
loss is :  1.3214635
loss is :  1.3214631
loss is :  1.3214629
loss is :  1.3214625
loss is :  1.3214623
loss is :  1.3214619
loss is :  1.3214617
loss is :  1.3214614
loss is :  1.3214612
loss is :  1.3214608
loss is :  1.3214605
loss is :  1.3214602
loss is :  1.3214598

In [56]:
print('loss is : ', sess.run(cross_entropy_loss, feed_dict={x: x_train, y_label: y_train}))

loss is :  1.3213638


In [57]:
sess.run(W1)[:10]

array([[ 2.2503977 , -0.99648255,  0.90736026,  1.2679911 ,  0.12651838],
       [ 1.0984166 , -1.1143471 ,  1.5686349 , -0.47364044, -0.73964673],
       [ 1.2340933 ,  0.11898658,  0.2903073 ,  1.6453295 ,  0.646955  ],
       [-0.5181205 , -0.69898254, -0.75826323,  0.15305088,  1.1908342 ],
       [-1.1630361 , -0.60098374,  0.77213615,  1.1451236 ,  0.19367844],
       [-0.20994985,  1.9572092 , -0.6236215 , -0.4924199 , -2.121153  ],
       [ 0.9811581 ,  0.6612565 , -1.513299  , -0.83130026,  1.3011856 ]],
      dtype=float32)

In [58]:
sess.run(b1)

array([-0.92033476,  0.43676966,  1.3688638 ,  0.39763185,  0.7653544 ],
      dtype=float32)

## Word Vectors = (Weights + Bias) of hidden layer

In [59]:
word_vectors = sess.run(W1+b1)

In [60]:
word_vectors[word2int['queen']]

array([-1.4384553 , -0.26221287,  0.6106006 ,  0.5506827 ,  1.9561886 ],
      dtype=float32)

In [61]:
word_vectors.shape

(7, 5)

## Measure similarity of vectors through euclidean distance
### e.g. root((y2-y1)**2 + (X2 - X1)**2) == distance between points x and y

In [62]:
def euclidean_dist(vec1, vec2):
    return np.sqrt(np.sum((vec1-vec2)**2))

def find_closest(word_index, vectors):
    min_dist = 10000 # to act like positive infinity
    min_index = -1
    query_vector = vectors[word_index]
    for index, vector in enumerate(vectors):
        if euclidean_dist(vector, query_vector) < min_dist and not np.array_equal(vector, query_vector):
            min_dist = euclidean_dist(vector, query_vector)
            min_index = index
    return min_index

In [63]:
int2word[find_closest(word2int['king'],word_vectors)]

'queen'

In [64]:
int2word[find_closest(word2int['queen'],word_vectors)]

'king'

In [65]:
int2word[find_closest(word2int['royal'],word_vectors)]

'she'

In [66]:
int2word[find_closest(word2int['is'],word_vectors)]

'queen'