Implementation of Word Embedding Using Keras

In [2]:
from tensorflow.keras.preprocessing.text import one_hot

In [3]:
sent=[  'the glass of milk',
     'the glass of juice',
     'the cup of tea',
    'I am a good boy',
     'I am a good developer',
     'understand the meaning of words',
     'your videos are good',]

Vocabulary Size

In [4]:
vocab_size = 10000

One Hot Representation

all the words are given index in the range of (0, 10000)

And if the words are repeated they are given the same index

In [5]:
onehot = [one_hot(words, vocab_size) for words in sent]
print(onehot)

[[1472, 8116, 7763, 5558], [1472, 8116, 7763, 3707], [1472, 7608, 7763, 265], [9609, 8334, 2130, 4042, 7382], [9609, 8334, 2130, 4042, 1155], [7758, 1472, 5621, 7763, 8003], [4243, 456, 5331, 4042]]


Word Embedding Representation

In [9]:
from tensorflow.keras.layers import Embedding
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [10]:
import numpy as np 

Since sentences can have differnet number of words so we use 'pad_sequences' for sentences to be of the same length  

In [17]:
sent_length = 8
embedded_docs = pad_sequences(onehot, padding='pre',maxlen=sent_length)
print(embedded_docs)

[[   0    0    0    0 1472 8116 7763 5558]
 [   0    0    0    0 1472 8116 7763 3707]
 [   0    0    0    0 1472 7608 7763  265]
 [   0    0    0 9609 8334 2130 4042 7382]
 [   0    0    0 9609 8334 2130 4042 1155]
 [   0    0    0 7758 1472 5621 7763 8003]
 [   0    0    0    0 4243  456 5331 4042]]


To pass this 'embedded_docs' we have to define dimension 

In [19]:
dim = 15

In [20]:
model = Sequential()
model.add(Embedding(vocab_size,dim,input_length=sent_length))
model.compile('adam', 'mse')

In [21]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 8, 15)             150000    
                                                                 
Total params: 150,000
Trainable params: 150,000
Non-trainable params: 0
_________________________________________________________________


In [22]:
print(model.predict(embedded_docs))

[[[-9.32594389e-03 -6.90107420e-03 -2.73599513e-02  3.90762426e-02
    2.62525789e-02 -7.94441625e-03  1.81548707e-02 -4.78618518e-02
   -1.49968751e-02  8.07269663e-03 -3.18378359e-02  4.05907668e-02
    1.10603943e-02 -1.44078955e-02 -3.20765749e-02]
  [-9.32594389e-03 -6.90107420e-03 -2.73599513e-02  3.90762426e-02
    2.62525789e-02 -7.94441625e-03  1.81548707e-02 -4.78618518e-02
   -1.49968751e-02  8.07269663e-03 -3.18378359e-02  4.05907668e-02
    1.10603943e-02 -1.44078955e-02 -3.20765749e-02]
  [-9.32594389e-03 -6.90107420e-03 -2.73599513e-02  3.90762426e-02
    2.62525789e-02 -7.94441625e-03  1.81548707e-02 -4.78618518e-02
   -1.49968751e-02  8.07269663e-03 -3.18378359e-02  4.05907668e-02
    1.10603943e-02 -1.44078955e-02 -3.20765749e-02]
  [-9.32594389e-03 -6.90107420e-03 -2.73599513e-02  3.90762426e-02
    2.62525789e-02 -7.94441625e-03  1.81548707e-02 -4.78618518e-02
   -1.49968751e-02  8.07269663e-03 -3.18378359e-02  4.05907668e-02
    1.10603943e-02 -1.44078955e-02 -3.20

In [24]:
embedded_docs[0]

array([   0,    0,    0,    0, 1472, 8116, 7763, 5558])

In [23]:
print(model.predict(embedded_docs[0]))

[[-0.00932594 -0.00690107 -0.02735995  0.03907624  0.02625258 -0.00794442
   0.01815487 -0.04786185 -0.01499688  0.0080727  -0.03183784  0.04059077
   0.01106039 -0.0144079  -0.03207657]
 [-0.00932594 -0.00690107 -0.02735995  0.03907624  0.02625258 -0.00794442
   0.01815487 -0.04786185 -0.01499688  0.0080727  -0.03183784  0.04059077
   0.01106039 -0.0144079  -0.03207657]
 [-0.00932594 -0.00690107 -0.02735995  0.03907624  0.02625258 -0.00794442
   0.01815487 -0.04786185 -0.01499688  0.0080727  -0.03183784  0.04059077
   0.01106039 -0.0144079  -0.03207657]
 [-0.00932594 -0.00690107 -0.02735995  0.03907624  0.02625258 -0.00794442
   0.01815487 -0.04786185 -0.01499688  0.0080727  -0.03183784  0.04059077
   0.01106039 -0.0144079  -0.03207657]
 [-0.03508595 -0.04199996  0.0180298  -0.03362857  0.0493958   0.0044314
  -0.01013615  0.0350397   0.04356773  0.01475146  0.00321009  0.04762497
   0.04072428 -0.04269696 -0.01976234]
 [-0.02767514 -0.00328115  0.00631819 -0.02414949 -0.00709484  0.0