In [2]:
import numpy as np
import pandas as pd
import string
import tensorflow as tf

2024-03-23 14:14:32.491746: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [28]:
data = pd.read_csv('./data/raw_data.csv', header=0, names=['text'], usecols=[1])
print(f'Data Shape: {data.shape}')
data.head()

Data Shape: (13368, 1)


Unnamed: 0,text
0,"Sally Forrest, an actress-dancer who graced th..."
1,A middle-school teacher in China has inked hun...
2,A man convicted of killing the father and sist...
3,Avid rugby fan Prince Harry could barely watch...
4,A Triple M Radio producer has been inundated w...


In [29]:
# remove punctuation
punctuations = string.punctuation
def remove_punctuation(txt):
    for char in punctuations:
        if char in txt:
            txt = txt.replace(char, "")
    return txt

In [30]:
remove_punctuation("here, test.!")

'here test'

In [31]:
# data preprocessing

# change to lower caps
data['text'] = data['text'].str.lower()

# remove punctuations
data['text'] = data['text'].apply(remove_punctuation)

In [32]:
data_lst = data['text'].apply(lambda txt: txt.split(" "))
random_indices = np.random.randint(low=0, high=len(data_lst), size=200)
len(random_indices)

200

In [33]:
data_lst = data_lst[random_indices]

In [34]:
data_lst[:5]

8376    [juventus, striker, alvaro, morata, has, slamm...
3703    [cnnanother, kardashian, heard, from, usually,...
1156    [virgin, australia, are, under, fire, for, for...
9851    [a, grandmother, has, been, arrested, followin...
2961    [london, cnnit, might, sound, like, a, really,...
Name: text, dtype: object

In [81]:
example = data_lst[0][:8]
vocab, index = {}, 1
vocab['<pad>'] = 0

for word in example:
    if word not in vocab:
        vocab[word] = index
        index += 1

vocab_size = len(vocab)

inverse_vocab = {}
for word, index in vocab.items():
    inverse_vocab[index] = word
inverse_vocab

# vectorize sentence
example_vectorized = [vocab[word] for word in example]
example_vectorized

In [18]:
# generate skip-grams from example
window_size = 2
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(sequence=example_vectorized,
                                                                   vocabulary_size=vocab_size,
                                                                   window_size=window_size,
                                                                   negative_samples=0)

len(positive_skip_grams)

34

In [20]:
for target, context in positive_skip_grams[:5]:
    print(f'(target, context) : ({inverse_vocab[target]},{inverse_vocab[context]})')

(target, context) : (silver,graced)
(target, context) : (silver,screen)
(target, context) : (an,sally)
(target, context) : (graced,silver)
(target, context) : (the,who)


In [28]:
# set seed for reproducibility
seed = 4212
# for each positive skip gram example, generate 4 negative samples
# use first pair as demonstration
target_word, context_word = positive_skip_grams[0]

# set number of negative samples
num_ns = 4

In [36]:
context_class = tf.reshape(tf.constant(context_word, dtype='int64'), (1,1))

negative_sampling_candidates, _ , _ = tf.random.log_uniform_candidate_sampler(true_classes=context_class,
                                                                              num_true=1,
                                                                              num_sampled=num_ns,
                                                                              unique=True,
                                                                              range_max=vocab_size,
                                                                              seed=seed,
                                                                              name='negative_sampling')

print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])

tf.Tensor([1 5 6 7], shape=(4,), dtype=int64)
['sally', 'who', 'graced', 'the']


In [37]:
# construct 1 training example
squeezed_context_class = tf.squeeze(context_class, 1)
context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)
label = tf.constant([1] + [0] * num_ns)
target = target_word

print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

print("target  :", target)
print("context :", context)
print("label   :", label)


target_index    : 8
target_word     : silver
context_indices : [6 1 5 6 7]
context_words   : ['graced', 'sally', 'who', 'graced', 'the']
label           : [1 0 0 0 0]
target  : 8
context : tf.Tensor([6 1 5 6 7], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int32)


In [41]:
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=5)
sampling_table

array([0.00315225, 0.00315225, 0.00547597, 0.00741556, 0.00912817])

In [22]:
import time
from tqdm import tqdm

# Define a loop
for i in tqdm(range(10)):
    # Simulate some computation
    time.sleep(0.5)


100%|██████████| 10/10 [00:05<00:00,  1.99it/s]


In [82]:
# function to generate samples
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0,
          shuffle=False)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.reshape(tf.constant([context_word], dtype="int64"), (1,1))
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [83]:
# prep data

# vocab dict
vocab, index = {}, 1
vocab['<pad>'] = 0
for line in data_lst:
    for word in line:
        if word not in vocab:
            vocab[word] = index
            index += 1

# inverse_vocab dict
inverse_vocab = {}
for word, index in vocab.items():
    inverse_vocab[index] = word

# sequences
sequences = []
for line in data_lst:
    vectorized_line = [vocab[word] for word in line]
    sequences.append(vectorized_line)

In [84]:
len(vocab)

15641

In [85]:
# generate training data
window_size = 5
num_ns = 4
vocab_size = len(vocab)
seed = 4212

targets, contexts, labels = generate_training_data(sequences=sequences,
                                                 window_size=window_size,
                                                 num_ns=num_ns,
                                                 vocab_size=vocab_size,
                                                 seed=seed)

targets = np.array(targets)
contexts = np.array(contexts)
labels = np.array(labels)

print(f'targets shape: {targets.shape}')
print(f'contexts shape: {contexts.shape}')
print(f'labels shape: {labels.shape}')

100%|██████████| 200/200 [03:34<00:00,  1.07s/it]


targets shape: (458431,)
contexts shape: (458431, 5)
labels shape: (458431, 5)


In [77]:
BATCH_SIZE = 1000
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

# cache and prefetch data for efficient loading
dataset = dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

print(dataset)

<_PrefetchDataset element_spec=((TensorSpec(shape=(1000,), dtype=tf.int64, name=None), TensorSpec(shape=(1000, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1000, 5), dtype=tf.int64, name=None))>


In [78]:
# define model
class Word2Vec(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim):
    super(Word2Vec, self).__init__()
    self.target_embedding = tf.keras.layers.Embedding(vocab_size,
                                      embedding_dim,
                                      name="w2v_embedding")
    self.context_embedding = tf.keras.layers.Embedding(vocab_size,
                                       embedding_dim)

  def call(self, pair):
    target, context = pair
    # target: (batch,)
    # context: (batch, context)
    word_emb = self.target_embedding(target)
    # word_emb: (batch, embed)
    context_emb = self.context_embedding(context)
    # context_emb: (batch, context, embed)
    dots = tf.einsum('be,bce->bc', word_emb, context_emb)
    # dots: (batch, context)
    return dots

In [79]:
embedding_dim = 64
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])

In [80]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

In [81]:
word2vec.fit(dataset, epochs=10, callbacks=[tensorboard_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7f88cb6bcdc0>

In [82]:
word2vec.summary()

Model: "word2_vec"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 w2v_embedding (Embedding)   multiple                  2355456   
                                                                 
 embedding (Embedding)       multiple                  2355456   
                                                                 
Total params: 4710912 (17.97 MB)
Trainable params: 4710912 (17.97 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [107]:
inverse_vocab[targets[4]]

'killed'

In [108]:
inverse_vocab[contexts[4][0]]

'bombing'

## Implementation from scratch

In [27]:
import jax
import jax.numpy as jnp

In [None]:
# prep data

# vocab dict
vocab, index = {}, 1
vocab['<pad>'] = 0
for line in data_lst:
    for word in line:
        if word not in vocab:
            vocab[word] = index
            index += 1

# inverse_vocab dict
inverse_vocab = {}
for word, index in vocab.items():
    inverse_vocab[index] = word

# sequences
sequences = []
for line in data_lst:
    vectorized_line = [vocab[word] for word in line]
    sequences.append(vectorized_line)

In [94]:
[inverse_vocab[i] for i in sequences[0][:10]]

['juventus',
 'striker',
 'alvaro',
 'morata',
 'has',
 'slammed',
 'real',
 'madrid',
 'boss',
 'carlo']

In [103]:
# code to generate positive skip grams
def generate_positive_skip_grams(sequence, window_size):
    positive_skip_grams = []
    # for each sentence
    for center_word_pos in range(len(sequence)):
        # for each window position
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            # make soure not jump out sentence
            if context_word_pos < 0 or context_word_pos >= len(sequence) or center_word_pos == context_word_pos:
                continue
            context_word_idx = sequence[context_word_pos]
            positive_skip_grams.append((sequence[center_word_pos], context_word_idx))

    positive_skip_grams = np.array(positive_skip_grams) # it will be useful to have this as numpy array
    return positive_skip_grams

In [101]:
generate_positive_skip_grams(sequences[0], 5)

array([[  1,   2],
       [  1,   3],
       [  1,   4],
       ...,
       [ 71, 199],
       [ 71,  21],
       [ 71, 200]])

In [104]:
# function to generate samples
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams = generate_positive_skip_grams(sequence, window_size)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.reshape(tf.constant([context_word], dtype="int64"), (1,1))
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [105]:
# generate training data
window_size = 5
num_ns = 4
vocab_size = len(vocab)
seed = 4212

targets, contexts, labels = generate_training_data(sequences=sequences,
                                                 window_size=window_size,
                                                 num_ns=num_ns,
                                                 vocab_size=vocab_size,
                                                 seed=seed)

targets = np.array(targets)
contexts = np.array(contexts)
labels = np.array(labels)

print(f'targets shape: {targets.shape}')
print(f'contexts shape: {contexts.shape}')
print(f'labels shape: {labels.shape}')

100%|██████████| 200/200 [11:13<00:00,  3.37s/it]


targets shape: (1401730,)
contexts shape: (1401730, 5)
labels shape: (1401730, 5)


In [106]:
# see what the data looks like
print(f'Example (1 data point)\nTargets: {targets[0]}')
print(f'Contexts: {contexts[0]}')
print(f'Labels: {labels[0]}')

# size of data
print(f'Total number of data: {len(targets)}')

Example (1 data point)
Targets: 1
Contexts: [   2 6619    4    0  811]
Labels: [1 0 0 0 0]
Total number of data: 1401730


In [87]:
# initialize weights
n = 300
v = len(vocab)
V = np.random.normal(0, 1, size=(n, v)) / np.sqrt(v)
U = np.random.normal(0, 1, size=(v, n))/ np.sqrt(n)

print(f'V shape: {V.shape}')
print(f'U shape: {U.shape}')


V shape: (300, 15641)
U shape: (15641, 300)


In [88]:
# sigmoid function
def sigmoid(x):
    """Inputs a real number, outputs a real number"""
    return 1 / (1 + jnp.exp(-x))

# 
def local_loss(target, context, label, V_embedding, U_embedding):
    """
    Input (example)
    target = (188,)
    context = (93, 40, 1648, 1659, 1109)
    label = (1, 0, 0, 0, 0)
    V_embedding: matrix of dim (n x |v|)
    U_embedding: matrix of dim (|v| x n)
    where n = embedding dimension, |v| = vocab size

    Outputs the local_loss -> real number
    """
    v_t = V_embedding.T[target] # shape (300,)
    u_pos = U_embedding[context[0]] # shape(300,)
    u_neg = U_embedding[context[1:]] # shape(4, 300)

    return -jnp.log(sigmoid(-jnp.dot(u_pos.T, v_t))) - jnp.sum(jnp.log(sigmoid(jnp.dot(u_neg, v_t))))


In [89]:
t = targets[0]
c = contexts[0]
l = labels[0]
local_loss(t, c, l, V, U)


Array(3.486331, dtype=float32)

In [107]:
[inverse_vocab[targets[i]] for i in range(20)]

['juventus',
 'juventus',
 'juventus',
 'juventus',
 'juventus',
 'striker',
 'striker',
 'striker',
 'striker',
 'striker',
 'striker',
 'alvaro',
 'alvaro',
 'alvaro',
 'alvaro',
 'alvaro',
 'alvaro',
 'alvaro',
 'morata',
 'morata']

In [71]:
#  initialize weights
# U, V