In [20]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Embedding, Dot, Input, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import skipgrams
from tensorflow.keras.preprocessing.text import Tokenizer
import random

# ---------------------------------------------------------
# 1. Load a small corpus (IMDB sample)
# ---------------------------------------------------------
# Request a small slice of both IMDB splits; only the training slice is used
# below, the test slice is discarded on unpacking.
imdb_splits = ["train[:2%]", "test[:1%]"]
(ds_train, _), ds_info = tfds.load(
    "imdb_reviews",
    split=imdb_splits,
    as_supervised=True,  # yields (text, label) tuples
    with_info=True,      # additionally returns the dataset metadata object
)

# Collect the raw review strings; each text tensor holds UTF-8 bytes and the
# labels are not needed for embedding training.
texts = []
for review, _label in ds_train:
    texts.append(review.numpy().decode("utf-8"))

# ---------------------------------------------------------
# 2. Tokenize and convert to integer sequences
# ---------------------------------------------------------
# Upper bound on the word ids the tokenizer will emit: when converting text,
# only ids below `num_words` are kept and every other word is replaced by the
# "<OOV>" token.
max_words = 10000
tokenizer = Tokenizer(num_words=max_words, oov_token="<OOV>")
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

# `tokenizer.word_index` lists EVERY word seen while fitting, regardless of
# `num_words`, so sizing the vocabulary from it alone over-counts: ids at or
# above `max_words` never occur in `sequences`. Cap the size so that the
# embedding table and the negative-sampling range only cover ids that can
# actually appear (+1 accounts for the reserved padding id 0).
vocab_size = min(max_words, len(tokenizer.word_index) + 1)
window_size = 2
num_ns = 4  # number of negative samples per positive

# ---------------------------------------------------------
# 3. Generate Skip-Gram Pairs + Negative Samples
# ---------------------------------------------------------
# Accumulate (target, context) id pairs and their 0/1 labels over all reviews.
pairs, labels = [], []
for seq in sequences:
    # Pass an explicit integer seed instead of relying on skipgrams' default
    # seeding (the original note describes this as a fix for a random-seed
    # bug). The seed differs per sequence, so runs are not reproducible.
    sg_pairs, sg_labels = skipgrams(
        seq,
        vocabulary_size=vocab_size,
        window_size=window_size,
        negative_samples=num_ns,
        seed=random.randint(0, int(1e6)),
    )
    # skipgrams already returns aligned lists, so extend directly instead of
    # re-wrapping every pair one at a time.
    pairs.extend(sg_pairs)
    labels.extend(sg_labels)

# Split the pairs into two parallel id arrays. Reshaping to (-1, 2) keeps this
# well-defined even when no pairs were produced (an empty (0, 2) array),
# whereas unpacking `zip(*pairs)` would fail on an empty list.
pair_array = np.asarray(pairs, dtype="int32").reshape(-1, 2)
targets = np.ascontiguousarray(pair_array[:, 0])
contexts = np.ascontiguousarray(pair_array[:, 1])
labels = np.asarray(labels, dtype="int32")

# ---------------------------------------------------------
# 4. Build Simple Skip-Gram Model
# ---------------------------------------------------------
embedding_dim = 128

# One word id per example on each side of the pair.
input_target, input_context = Input(shape=(1,)), Input(shape=(1,))

# A single embedding table is shared by the target and the context word.
embedding = Embedding(
    input_dim=vocab_size,
    output_dim=embedding_dim,
    name="word_embedding",
)
target_emb, context_emb = embedding(input_target), embedding(input_context)

# Reshape each looked-up embedding to a flat vector of length embedding_dim.
flat_shape = (embedding_dim,)
target_vec = Reshape(flat_shape)(target_emb)
context_vec = Reshape(flat_shape)(context_emb)

# Score a pair by the dot product of its two vectors, squashed to (0, 1) so
# it can be trained against the 0/1 pair labels with binary cross-entropy.
dot_product = Dot(axes=1)([target_vec, context_vec])
output = tf.keras.activations.sigmoid(dot_product)

model = Model(inputs=[input_target, input_context], outputs=output)
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

model.summary()

# ---------------------------------------------------------
# 5. Train the Model
# ---------------------------------------------------------
# Fit on the (target, context) id arrays against their 0/1 labels.
model.fit(
    x=[targets, contexts],
    y=labels,
    batch_size=1024,
    epochs=2,
)

# ---------------------------------------------------------
# 6. Extract Embeddings
# ---------------------------------------------------------
# The first (index 0) weight array of the shared embedding layer is the
# lookup table, indexed by word id.
word_embedding_layer = model.get_layer(name="word_embedding")
weights = word_embedding_layer.get_weights()[0]

# ---------------------------------------------------------
# 7. Helper Function to Find Nearest Neighbors
# ---------------------------------------------------------
def find_neighbors(word, weights, tokenizer, top_k=5):
    """Return the `top_k` words whose embeddings are most similar to `word`.

    Similarity is the cosine between rows of `weights`.

    Args:
        word: Query word; looked up in `tokenizer.word_index`.
        weights: 2-D array with one embedding row per word id.
        tokenizer: Object exposing `word_index` (word -> id) and
            `index_word` (id -> word) mappings.
        top_k: Maximum number of neighbours to return.

    Returns:
        A list of `(neighbour_word, cosine_similarity)` tuples, most similar
        first, never containing `word` itself. The list is shorter than
        `top_k` when fewer candidates exist and empty when `top_k <= 0`.
        Returns None (after printing a message) when `word` is unknown or has
        no row in `weights`.
    """
    idx = tokenizer.word_index.get(word)
    # A word can be known to the tokenizer yet have no embedding row when the
    # table was built for a smaller vocabulary; treat that like an unknown
    # word instead of failing with an IndexError.
    if idx is None or idx >= len(weights):
        print(f"'{word}' not in vocabulary.")
        return
    if top_k <= 0:
        return []

    vec = weights[idx]
    # The small epsilon keeps the division defined for all-zero rows.
    sim = np.dot(weights, vec) / (np.linalg.norm(weights, axis=1) * np.linalg.norm(vec) + 1e-9)

    # A stable sort keeps tied rows in id order, so results are deterministic.
    ranked = np.argsort(-sim, kind="stable")

    neighbors = []
    for i in ranked:
        i = int(i)
        # Skip the query word explicitly (it is not guaranteed to sort first
        # when other rows tie with it) and any id without a word, such as the
        # reserved padding row 0, which would otherwise raise a KeyError.
        if i == idx or i not in tokenizer.index_word:
            continue
        neighbors.append((tokenizer.index_word[i], float(sim[i])))
        if len(neighbors) == top_k:
            break
    return neighbors

# ---------------------------------------------------------
# 8. Test Nearest Neighbors
# ---------------------------------------------------------
# Show the nearest neighbours for a pair of sentiment words; each lookup runs
# immediately before its line is printed.
for heading, query in (
    ("Neighbors of 'good':", "good"),
    ("Neighbors of 'bad':", "bad"),
):
    print(heading, find_neighbors(query, weights, tokenizer))



Epoch 1/2
2164/2164 ━━━━━━━━━━━━━━━━━━━━ 18s 8ms/step - accuracy: 0.7323 - loss: 0.5115
Epoch 2/2
2164/2164 ━━━━━━━━━━━━━━━━━━━━ 17s 8ms/step - accuracy: 0.7584 - loss: 0.4659
Neighbors of 'good': [('better', 0.8350237607955933), ('great', 0.8174194097518921), ('really', 0.8157484531402588), ('watch', 0.8064336776733398), ('there', 0.7885797023773193)]
Neighbors of 'bad': [('great', 0.8414618968963623), ('more', 0.8218198418617249), ('thing', 0.8146700263023376), ('look', 0.814031183719635), ('much', 0.8131774067878723)]
