In [1]:
#step1: import libraries
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, Lambda, Dense
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.backend as k

#step2: Text Data
text = "I love playing cricket and watching cricket matches"

# Fit a Keras Tokenizer on the single sentence so every distinct word
# receives an integer id.
corpus = [text]
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

# word -> integer id mapping learned by the tokenizer.
word2idx = tokenizer.word_index

# One more slot than there are words in word2idx.
vocab_size = 1 + len(word2idx)

# The sentence re-expressed as a list of word ids (one list per input text;
# only one text was supplied, so take the first).
encoded_texts = tokenizer.texts_to_sequences(corpus)
seq = encoded_texts[0]

#step3: Create Context-Target Pairs
# Every pair is (context_word_id, target_word_id): for each position i in the
# sequence, each other position j within `window` steps of i contributes
# seq[j] as a context word for the target seq[i].
window = 2
pairs = []
for i, target in enumerate(seq):
    first = max(0, i - window)               # clamp the window at the start
    last = min(len(seq), i + window + 1)     # ...and at the end (exclusive)
    for j in range(first, last):
        if i == j:
            continue                         # a word is not its own context
        pairs.append((seq[j], target))

# Split the pairs into two parallel integer arrays.
contexts = np.array([context for context, _ in pairs])
targets = np.array([target for _, target in pairs])

# NOTE: `targets` is deliberately left as integer ids here. It is one-hot
# encoded exactly once, further down, right next to the model definition.
# An earlier line here stored a one-hot copy in a misspelled, never-used
# variable (`tagets`); it was dead code and has been removed.


#step4: create a CBOW Model
# One-hot encode the integer target ids so each label has vocab_size entries,
# matching the width of the softmax output below.
targets = to_categorical(targets, vocab_size)

# The network receives one word id per sample.
# NOTE(review): with a single context word per sample, the mean over axis 1
# below averages exactly one embedding vector -- confirm whether feeding the
# whole context window at once was intended.
input_layer = Input(shape=(1,))

# Embedding(input_dim, output_dim): map each word id to an 8-dimensional vector.
embedding_Layer = Embedding(
    input_dim=vocab_size,
    output_dim=8,
    name="embedding",
)(input_layer)

# Collapse axis 1 by averaging, leaving one 8-dimensional vector per sample.
x = Lambda(lambda emb: k.mean(emb, axis=1), output_shape=(8,))(embedding_Layer)

# Softmax over vocab_size outputs.
output_layer = Dense(units=vocab_size, activation="softmax")(x)

model = Model(inputs=input_layer, outputs=output_layer)
model.compile(optimizer="adam", loss="categorical_crossentropy")


#step5: Model Training
model.fit(contexts, targets, epochs=10, verbose=1)

# Query the trained model with a single context word.
test_word = "cricket"
test_idx = np.array([[word2idx[test_word]]])   # batch of one sample, one id
pred = model.predict(test_idx)

# The softmax has vocab_size = len(word2idx) + 1 outputs, but word2idx has no
# entry for id 0 (the Keras Tokenizer reserves it). Taking the argmax over all
# outputs could therefore select id 0 and leave no word to report, so only
# columns 1.. are considered and the result is shifted back to a real word id.
predicted_idx = int(np.argmax(pred[0, 1:])) + 1

# Reverse mapping (id -> word) instead of scanning word2idx linearly.
idx2word = {idx: word for word, idx in word2idx.items()}
predicted_word = idx2word[predicted_idx]

print(f"\nContext word: '{test_word}'")
print(f"Predicted target word: '{predicted_word}'")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10

Context word: 'cricket'
Predicted target word: 'cricket'
