In [1]:
import numpy as np
import pandas as pd
import tqdm
import jax
import jax.numpy as jnp
import string
import tensorflow as tf
import time
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords

## Data Preprocessing

In [2]:
data = pd.read_csv('./data/raw data/raw_data.csv', header=0, names=['text'], usecols=[1])
print(f'Data Shape: {data.shape}')
data.head()

Data Shape: (13368, 1)


Unnamed: 0,text
0,"Sally Forrest, an actress-dancer who graced th..."
1,A middle-school teacher in China has inked hun...
2,A man convicted of killing the father and sist...
3,Avid rugby fan Prince Harry could barely watch...
4,A Triple M Radio producer has been inundated w...


In [3]:
# remove punctuation
punctuations = string.punctuation
def remove_punctuation(txt):
    for char in punctuations:
        if char in txt:
            txt = txt.replace(char, "")
    return txt

# change to lower caps
data['text'] = data['text'].str.lower()

# remove punctuations
data['text'] = data['text'].apply(remove_punctuation)

In [4]:
# remove stopwords
# read stopwords from data/raw data/stopwords.txt
stop_words = []
with open('./data/raw data/stopwords.txt', 'r') as f:
    for line in f:
        stop_words.append(line.strip())

def remove_stopwords(txt):
    txt = [word for word in txt.split() if word not in stop_words]
    return ' '.join(txt)

data['text'] = data['text'].apply(remove_stopwords)

In [5]:
# split each row into list of words
data_lst = data['text'].apply(lambda txt: txt.split(" "))

# select number of rows to be used as training data
nrows = 200
random_indices = np.random.randint(low=0, high=len(data_lst), size=nrows)
data_lst = data_lst[random_indices].reset_index(drop=True)

print(f'Number of rows of data: {len(data_lst)}')
data_lst[:5]

Number of rows of data: 200


0    [cnnafter, handful, events, months, hillary, c...
1    [account, manager, quit, job, professional, vl...
2    [missing, philae, space, probe, bumped, onto, ...
3    [cnnigniting, live, cage, severing, heads, doz...
4    [final, scoreline, suggested, otherwise, gordo...
Name: text, dtype: object

In [6]:
# vocab dict
vocab, index = {}, 1
vocab['<pad>'] = 0
for line in data_lst:
    for word in line:
        if word not in vocab:
            vocab[word] = index
            index += 1

# inverse_vocab dict
inverse_vocab = {}
for word, index in vocab.items():
    inverse_vocab[index] = word

print(f'Vocab size: {len(vocab)}')

Vocab size: 14519


In [7]:
# sequences
sequences = []
for line in data_lst:
    vectorized_line = [vocab[word] for word in line]
    sequences.append(vectorized_line)

## Split Train and Test Sets

In [8]:
# split into train and test sets
# choose 20 random sequences
ntest = 20
test_indices = np.random.randint(low=0, high=len(sequences), size=ntest)
test_sequences = [sequences[i] for i in test_indices]
train_sequences = [sequences[i] for i in range(len(sequences)) if i not in test_indices]

In [9]:
# function to generate samples
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0,
          shuffle=True)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.reshape(tf.constant([context_word], dtype="int64"), (1,1))
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [10]:
# function to generate testing data
def generate_testing_data(sequences, vocab_size, window_size):
    targets, contexts, labels = [], [], []
    for sequence in tqdm(sequences):
        positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
            sequence,
            vocabulary_size=vocab_size,
            window_size=window_size,
            negative_samples=0)
    for target_word, context_word in positive_skip_grams:
        targets.append(target_word)
        contexts.append(context_word)
        labels.append(1)
    return targets, contexts, labels

In [11]:
# generate training data
window_size = 5
num_ns = 4
vocab_size = len(vocab)
seed = 4212

targets, contexts, labels = generate_training_data(sequences=train_sequences,
                                                 window_size=window_size,
                                                 num_ns=num_ns,
                                                 vocab_size=vocab_size,
                                                 seed=seed)

targets = np.array(targets)
contexts = np.array(contexts)
labels = np.array(labels)

print(f'targets shape: {targets.shape}')
print(f'contexts shape: {contexts.shape}')
print(f'labels shape: {labels.shape}')

100%|██████████| 183/183 [02:29<00:00,  1.23it/s]


targets shape: (309499,)
contexts shape: (309499, 5)
labels shape: (309499, 5)


In [12]:
# generate testing data
targets_test, contexts_test, labels_test = generate_testing_data(sequences=test_sequences,
                                                                    vocab_size=vocab_size,
                                                                    window_size=window_size)

targets_test = np.array(targets_test)
contexts_test = np.array(contexts_test)
labels_test = np.array(labels_test)

print(f'targets_test shape: {targets_test.shape}')
print(f'contexts_test shape: {contexts_test.shape}')
print(f'labels_test shape: {labels_test.shape}')

100%|██████████| 20/20 [00:00<00:00, 158.04it/s]

targets_test shape: (3970,)
contexts_test shape: (3970,)
labels_test shape: (3970,)





**Sanity Check on quality of training and testing data**

In [13]:
# training data
print(f"target_index    : {targets[0]}")
print(f"target_word     : {inverse_vocab[targets[0]]}")
print(f"context_indices : {contexts[0]}")
print(f"context_words   : {[inverse_vocab[c] for c in contexts[0]]}")
print(f"label           : {labels[0]}")

print("target  :", targets[0])
print("context :", contexts[0])
print("label   :", labels[0])

target_index    : 120
target_word     : irish
context_indices : [ 128   38 1557 1567 1050]
context_words   : ['welsh', 'gala', 'portuguese', 'prodded', 'wee']
label           : [1 0 0 0 0]
target  : 120
context : [ 128   38 1557 1567 1050]
label   : [1 0 0 0 0]


In [14]:
# testing data
print(f"target_index    : {targets_test[0]}")
print(f"target_word     : {inverse_vocab[targets_test[0]]}")
print(f"context_index : {contexts_test[0]}")
print(f"context_word   : {inverse_vocab[contexts_test[0]]}")
print(f"label           : {labels_test[0]}")

print("target  :", targets_test[0])
print("context :", contexts_test[0])
print("label   :", labels_test[0])

target_index    : 709
target_word     : health
context_index : 14466
context_word   : cadwaladr
label           : 1
target  : 709
context : 14466
label   : 1


## Minimizing Objective Function for SGNS

$$

\min_{\theta} = \frac{1}{N} \sum_{i=1}^{N} [log \sigma(u_{ic}^T)  + \sum_{k=1}^{K}log \sigma(-u_{kc}^T v_{iw})]

\\
\\
\theta = [U, V]

$$

In [51]:
# sigmoid function
def sigmoid(x):
    """Inputs a real number, outputs a real number"""
    return 1 / (1 + jnp.exp(-x))


In [52]:
# define function to get embedding vectors
def get_embedding_vectors(target, context, V, U):
    """
    Input (example)
    target = (188,)
    context = (93, 40, 1648, 1659, 1109)
    V: matrix of dim (n x |v|)
    U: matrix of dim (|v| x n)
        n = embedding dimension, |v| = vocab size

    Output
    v_t: target word vector, dimension: (n,)
    u_c: context word vectors, consists of u_pos and u_neg: dimension: (n, len(context))
    """
    target = target.astype(int)
    context = context.astype(int)
    v_t = V.T[target]
    u_c = U[context]
    return v_t, u_c

# t = targets[0]
# c = contexts[0]
# v_test, u_test = get_embedding_vectors(t, c, V, U)
# print(f'v_test shape: {v_test.shape}')
# print(f'u_test shape: {u_test.shape}')

In [53]:
# define local_loss function
@jax.jit
def local_loss(params):
    """
    Input (example)
    params = [v_t, jnp.array([u_pos, u_neg])]
        v_t: target word vector, dimension: (n,)
        u_c: context word vectors, consists of u_pos and u_neg: dimension: (len(context), n)

    Output
    local_loss: real number
    """
    v_t = params[0]
    u_c = params[1]
    return -jnp.log(sigmoid(jnp.dot(u_c[0], v_t))) - jnp.sum(jnp.log(sigmoid(-jnp.dot(u_c[1:], v_t))))

# p = [V[:, 0], U[:5, :]]
# print(f'local_loss: {local_loss(p)}')


In [54]:
# define gradient function
L_grad = jax.grad(local_loss)

# g = L_grad([V[:, 0], U[:5, :]])
# print(f"g[0] shape: {g[0].shape}")
# print(f"g[1] shape: {g[1].shape}")

In [65]:
# define a function to update parameters
# def update_params(target, context, partial_v, partial_u, V, U):
#     t = target; print(f'target: {type(t)}')
#     c = context; print(f'context: {type(c)}')
#     V[:, t] = V[:, t] - lr * partial_v
#     U[c, :] = U[c, :] - lr * partial_u

def update_params(target, context, partial_v, partial_u, V, U):
    # Update V and U
    V_updated = V.at[:, target].add(-lr * partial_v)
    U_updated = U.at[context, :].add(-lr * partial_u)

    return V_updated, U_updated

In [66]:
# define batch functions
get_batch_embedding_vectors = jax.vmap(get_embedding_vectors, in_axes=(0, 0, None, None))

batch_losses = jax.vmap(local_loss, in_axes=(0))

batch_grads = jax.vmap(L_grad, in_axes=(0))

batch_update_params = jax.vmap(update_params, in_axes=(0, 0, 0, 0, None, None))

In [67]:
# set up
n = 300
v = len(vocab)
V = jnp.array(np.random.normal(0, 1, size=(n, v)) / np.sqrt(v))
U = jnp.array(np.random.normal(0, 1, size=(v, n)) / np.sqrt(v))

params = [V, U]
targets_data = jnp.array(targets.astype(float))
contexts_data = jnp.array(contexts.astype(float))
labels_data = jnp.array(labels.astype(float))

print(f'V shape: {V.shape}')
print(f'U shape: {U.shape}')
print(f'targets_data shape: {targets_data.shape} \t type: {type(targets_data).__name__}')
print(f'contexts_data shape: {contexts_data.shape} \t type: {type(contexts_data).__name__}')
print(f'labels_data shape: {labels_data.shape} \t type: {type(labels_data).__name__}')

V shape: (300, 14519)
U shape: (14519, 300)
targets_data shape: (309499,) 	 type: ArrayImpl
contexts_data shape: (309499, 5) 	 type: ArrayImpl
labels_data shape: (309499, 5) 	 type: ArrayImpl


In [68]:
# train using stochastic gradient descent

# set up
N = len(targets_data)
lr = 1.
n_epochs = 5
batch_size = 1000
n_batches = N // batch_size

epoch_losses = []

# gradient descent
for epoch in range(n_epochs):
    start_time = time.time()

    # shuffle data
    perm = np.random.permutation(N)
    targets_epoch = targets_data[perm]
    contexts_epoch = contexts_data[perm]
    labels_epoch = labels_data[perm]

    # decrease learning rate
    if epoch % 2 == 0:
        lr /= 2

    # losses in each epoch
    losses = []
    for batch in tqdm(range(n_batches)):
        targets_batch = targets_epoch[batch*batch_size: (batch+1)*batch_size]
        contexts_batch = contexts_epoch[batch*batch_size: (batch+1)*batch_size]
        labels_batch = labels_epoch[batch*batch_size: (batch+1)*batch_size]

        # get batch embedding vectors
        v_t_batch, u_c_batch = get_batch_embedding_vectors(targets_batch, contexts_batch, V, U)

        # calculate batch_losses and batch_grads
        loss_values = np.mean(batch_losses([v_t_batch, u_c_batch]))
        grads = batch_grads([v_t_batch, u_c_batch])

        # batch update parameters
        partial_V_batch = grads[0]
        partial_U_batch = grads[1]

        batch_update_params(targets_batch, contexts_batch, partial_V_batch, partial_U_batch, V, U)

        # store the loss value
        losses.append(np.mean(loss_values))
    
    # store epoch losses
    epoch_losses.append(np.mean(losses))

    end_time = time.time()
    
    # print epoch loss
    if epoch % 1 == 0:
        print(f"Epoch {epoch+1}/{n_epochs} \t loss = {np.mean(epoch_losses)} \t time = {end_time - start_time:.2f}s")

  0%|          | 0/309 [00:00<?, ?it/s]

  0%|          | 0/309 [00:00<?, ?it/s]


TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 1, indexer value Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([ 6837.,  9358.,  8486., 13757.,  6831.,  3024.,   976.,   648.,
       10064.,  9505.,   296.,  7598., 11652.,  3101.,  9277., 12584.,
       11340.,  1044.,  8537.,  1257., 10175.,  6819.,  6354.,  6369.,
        3074.,   684.,  5253.,  4097., 13407.,  1133., 12121.,  9409.,
       14395.,  3654.,  4420.,  8502.,   846.,  7232.,  2572., 11862.,
        1541.,  6562., 14391.,  9669., 10520.,  2263.,  1339.,  7096.,
        1834.,  4628.,  6641.,  5897.,  1284., 11829.,  6472.,  8953.,
        4005.,  4673.,  7956.,  2556.,  5122.,  3722.,  1313., 11397.,
        4021.,  2469., 13226.,  1468.,   888.,  1902.,  5171.,  1689.,
        6466.,  8192.,  3573.,  2815.,  5617., 10437.,  3110.,  7802.,
       11910.,  2488.,  8364., 11290.,   902., 14353.,  9152.,  1481.,
        4397.,   316.,  7635.,  7067.,  1389.,  5652.,  8946.,   497.,
        9814.,  7318.,  8535.,  3559.,  1159.,  1753.,  7154.,  4229.,
        6099., 12700.,  3093.,  2728.,  2203., 12566.,  3902.,  5963.,
       13437.,  3727.,   694.,  5795.,  2395.,  2285.,  6954.,  1420.,
        2131.,  8668., 13153.,  3513.,  5061.,  7757.,  3281.,  4969.,
       10643., 11136.,  1490.,  1308.,   747.,  3924.,  1050.,  3403.,
       10127., 10369.,  3798.,  6757.,  4713.,  2622., 12094., 13047.,
        5855.,  4643.,  8381.,   747.,  9661.,  7481.,  7799.,  2815.,
        3581.,   487., 10853., 11469.,  5176.,  9779.,  9120.,  7296.,
        8044.,  6663.,  4880.,  2740.,  3619.,  6597.,  3814.,  3770.,
        3524., 13789.,  4468.,  1979., 11546.,  4734.,  8183.,  3564.,
         875.,   112.,  1990.,  2219., 13638.,  1140., 11823.,  5751.,
         839.,  2686., 12791.,  7115., 10779.,  2917., 13825.,  9080.,
       11382.,  9027.,  4425.,   475.,   227., 10979.,  8520.,   875.,
       11828.,  3040.,  8170.,  3186.,  6471.,  2734.,  6883.,  4290.,
         980., 13346., 10108.,  1363.,  5870., 10761.,  7065.,  9807.,
       11012.,  1268., 14273.,  4513.,  9494.,  6508.,  1111., 10595.,
       13850.,  4155.,  3639., 13574., 12770., 12839.,  2297.,  4114.,
       13158.,  2264., 13548.,  8035., 10437.,  7629., 10156., 10445.,
        2621.,  1395.,  4307.,    63.,  1158.,   320.,  3391.,   413.,
        9754.,  5812.,   964.,  2683.,  9357.,  1056., 13790.,  1317.,
        5956.,  1328.,  2762.,  3557.,  9358.,  5161.,  3177., 13540.,
        7623.,  1933., 11315., 13834., 10076.,  3104.,  1223.,  9700.,
        6243.,  4657.,  5229.,  8590.,  5786.,  4416.,   616., 11219.,
       10300.,  4772.,  9236.,  3349.,  3285.,  2602.,  5112., 12790.,
        6082., 11992., 14401., 12591.,  3309.,  9194.,  8821.,  6025.,
        9835.,   517.,  7760., 10109., 10068.,  4420.,  5329., 13333.,
        8603.,    76.,  3398.,  5645.,  3980.,  4148., 14175.,   623.,
        5225.,  4238.,  5248.,  2937., 10826.,  3466.,  5627.,   996.,
         459.,  4021.,  3304., 14011.,  8961.,  5764.,  8513.,  2381.,
        1306.,  5163., 14197., 12969.,  4757.,  9463., 10510.,  6186.,
        8910.,  7928., 10948.,  1937.,  9265.,  1990.,  7217.,  8572.,
        1893.,  1422.,  9398., 11439.,  9630.,  1327.,  3223.,  4716.,
        2864.,  1540.,  3661., 11082.,  7277.,  8908.,  1499.,  4021.,
       13158.,  2283.,  7113.,  3751.,   779.,  5491.,  8104.,  2379.,
       13785., 13851.,  5526.,  5416.,  6523.,  3692.,  7634.,  3987.,
        2902.,  1911.,  5202.,  9215.,  4422.,  4161., 11015., 10983.,
        7068.,  3401., 13431., 12935.,  3108., 14223.,  3702., 11170.,
        1484.,   577.,  2998., 13239.,  1877.,  3009.,  9560.,  4322.,
        9126.,  2637., 11065.,  1647.,  5497.,   847.,  9324.,  7949.,
        1529.,  5957.,  3128.,  4647.,  3661.,  3740.,   515.,  6926.,
        2972.,  4772.,  4905.,  4896.,  8503.,  6959.,  9047.,  2444.,
        7367.,  8846.,  1607., 14028.,  1924., 13403.,  7113.,  7251.,
        4510.,   316.,  7065.,  1476., 10150.,  9670.,  8508.,  5448.,
        5073.,  2445., 13592.,  6279.,  6357.,  6879.,  3872., 10855.,
        1956.,  4801., 13049., 14261.,  3135.,  7966.,  1577.,  2973.,
        2214.,  6900.,  5985.,   511., 13630.,  3692.,  1396.,  6763.,
       11219.,  2263.,   545.,  6153., 10580., 11307.,  1956.,  3171.,
        3550.,  1773.,   845.,  8609., 13437.,  1193.,  7967.,  1245.,
       12861.,  2533.,  4468.,  1007.,  3432.,  5694.,  1777.,  1397.,
        7030., 10638.,  3987.,  6388.,  3090.,  8502.,  3344.,  7106.,
        5485.,   218.,  1012.,   748.,  8736.,  6957.,  5793.,  7154.,
       12502.,  2995.,  8958., 11456., 10149.,  3735.,  1275.,  4778.,
        2516.,  4518., 10580.,  4819.,  3098., 12558.,  4071.,  7975.,
        8332.,  8861.,  8143.,  6829.,  4602.,  1541.,  9633.,  8525.,
        2363.,  4241., 11861.,  1879.,  7715.,  3264., 13231.,  3713.,
        7169.,  4822.,  6016.,  2911.,  2210.,  3573., 11456.,  1138.,
       11127., 10478.,  3321.,  4021.,  4696.,   547., 12258.,  7469.,
        3294., 10608., 12027.,  2525.,  1850.,  5055., 11146.,  8895.,
       10978.,  6341.,  7980.,  1105.,  8219.,  6457.,  6977.,  7996.,
        8537.,  6923.,  8910.,  8599., 14123.,  2565., 11751.,  3236.,
        1228.,  9600.,  5393.,  3345.,  8428., 13677.,  9521.,  1311.,
        5097.,  6925., 11215.,  5103.,  3395.,  9389.,  4672.,  6887.,
       13044.,   152.,  7577.,  2800., 13786.,  6230.,  9180., 12816.,
        3523., 12063.,  1313., 10856.,  1592.,  9970.,  4674.,  8005.,
        3791., 12065., 10437., 12245.,   413.,  1223.,  2063., 14418.,
         262., 12941.,   680.,  7796.,  7487.,  1754.,  6056.,  2356.,
        3372.,  2813.,  6087.,  5002.,  4638.,  6340.,  2656., 10722.,
       11217.,  8402.,  7744.,  3640.,   744.,  5285.,  6165.,  2917.,
        1710.,  9968., 13306.,  2718.,  8482.,  2016.,  6082., 11560.,
       10239.,  5454.,  2753.,  1499., 11861.,  6073.,  1661., 13427.,
         364.,  3066.,  4566.,  8506.,  8124., 10160.,  9220.,  1928.,
        3457.,   638., 13858., 10766.,  3489.,  5354.,  9609.,  2196.,
        8364.,  2034.,  6267., 11397., 10360.,  5983., 13827.,  8656.,
        4473.,  3672.,  6115.,  6716.,  5361.,  2189.,  1308.,  5944.,
        6207.,  7468.,  3694.,  2085.,  2853.,  8462.,  6194.,  4249.,
        3313.,  8377.,  2550.,  4937.,  2283.,  9831., 11788.,  4033.,
        6047.,  6395.,  6153., 12591.,  1774.,   647.,  1020.,  1765.,
        8222.,   284.,  2248., 12365.,  6114.,  6288.,  9663., 12962.,
        8368., 12922.,  1715.,  8848.,  4404., 14103.,  1389.,  2346.,
        8353.,  4397.,  6111.,  5254.,  4607.,  1360., 14255.,  4693.,
        7518., 10891., 11652.,  6509.,  5468., 13067., 13202.,  4783.,
         744., 11534., 14292.,  5890.,  1994.,  5696., 10370.,  3015.,
        4887.,  7834.,  2861.,  7744.,   723., 12746., 13586.,  5088.,
         181.,  3384., 14143.,  4042., 11641.,  5480., 12257.,  5419.,
       11104.,  9916.,   431.,  4995.,  3736.,  1284.,  6817.,  9840.,
        8957.,  8419.,  2508., 13277.,  3810., 11795., 10865.,  1600.,
       10881.,  2343.,  4098.,  1728.,  8848.,  9106., 14002.,  5136.,
        7368.,  2853.,  4452.,  6065., 10718.,  5353., 11512., 10853.,
       13084.,  6024., 13238.,  5136.,  7209.,  8159.,  2587.,   763.,
       10203.,  3733.,  9320., 12771., 12120.,   891.,  2759., 10425.,
        3393.,  9082.,  8953., 11651.,  7208.,  1844.,  3775.,  8165.,
        4307., 11554.,  2742.,  7636.,  3516.,  3907.,  9242.,  7206.,
       13380.,  5657., 10880.,  6586., 11871.,  2808., 10027.,  3221.,
        1813., 10623., 10194.,  6046.,   842.,  3872.,  9496.,  9209.,
        2264., 12970.,  2301.,  5783.,  7798., 10217.,  3987.,  3787.,
        8584.,  2777.,  2219.,  4475.,  2399.,  6161.,  4476.,  3917.,
         687.,  8613.,  8435.,  6845.,  6124.,  1376.,  3565.,  9264.,
        6608., 13998.,  5080.,  3569.,  9628.,  4700.,  3530.,  1255.,
        9940.,   691.,   443., 14086.,  6658.,  6649.,  5114.,  5053.,
        1313.,  4365., 13028.,  3504., 11300.,   449., 11011.,  1392.,
       11475.,  5378.,  2053.,  5939.,  2411.,   145.,  3692.,  5968.,
        3496.,  7809.,  3506.,  1702.,  5999.,  6353.,  3401.,  3802.,
        9559.,  3843.,  1676.,  3823.,  2266.,  9034.,  2968.,  1805.,
        3466.,  9349.,  1300.,  9572.,  8915.,  3434.,  4807.,  5583.,
        3545.,  8079., 11536., 11841., 12337.,   693.,  3235., 10591.,
        3793.,  6051., 13839., 12947.,  3791., 11552.,  5972., 12622.,
       11550.,  1896.,  2562.,  7711.,  1520.,  2947., 11754.,  7920.,
        3877.,  4028.,  7476., 12321., 14374., 10246., 11418.,   680.,
        4905., 10278., 14354.,  7152.,  3174.,  1655., 11161.,   218.,
        8249., 10923., 12375.,  3584.,   218., 11417., 11606.,   535.,
        3431.,  3135., 12113., 13408.,  1798.,  1403.,  3401.,  9905.,
        3561.,  2218., 11699.,  8148.,   931.,  3950.,  3382., 13343.,
       10003.,  1921., 14192.,  6902.,  5739., 10797.,   780.,  2675.],      dtype=float32)
  batch_dim = 0

In [None]:
# train using gradient descent

# set up
N = len(targets_data)
n_epochs = 5

lr = 1.

epoch_losses = []

# gradient descent
for epoch in range(n_epochs):

    start_time = time.time()

    # shuffle data
    perm = np.random.permutation(N)
    targets = targets_data[perm]
    contexts = contexts_data[perm]
    labels = labels_data[perm]

    losses = []
    for i in tqdm(range(N)):
        # stop after 1000 iterations
        if i == 1000:
            break
        target = targets[i].astype(int) # ; print(f'target shape: {target.shape}')
        context = contexts[i].astype(int)# ; print(f'context shape: {context}')
        label = labels[i].astype(int)

        # get the embedding vectors
        v_t, u_c = get_embedding_vectors(target, context, V, U)

        # get the loss value and gradient
        loss_value = local_loss([v_t, u_c])
        grad = L_grad([v_t, u_c])

        # update the parameters
        partial_V = grad[0]
        partial_U = grad[1]
        
        V[:, target] = V[:, target] - lr * partial_V
        U[context, :] = U[context, :] - lr * partial_U

        # store the loss value
        losses.append(loss_value)
    
    # store epoch losses
    epoch_losses.append(np.mean(losses))

    end_time = time.time()
    
    # print epoch loss
    if epoch % 1 == 0:
        print(f"Epoch {epoch+1}/{n_epochs} \t loss = {np.mean(epoch_losses)} \t time = {end_time - start_time:.2f}s")

In [None]:
# # train using stochastic gradient descent

# # number of training examples
# N = len(targets_data)

# # learning rate
# lr = 3.

# # number of epochs
# n_epochs = 100

# # batch size
# batch_size = 10000

# # number of batches per epoch
# n_batches = N // batch_size

# # keep track of losses
# epoch_losses = []

# # training the network
# for epoch in range(n_epochs):
#     start_time = time.time()
#     # shuffle data
#     perm = np.random.permutation(N)
#     targets = targets[perm]
#     contexts = contexts[perm]
#     labels = labels[perm]

#     # half the learning rate every 25 epochs
#     if epoch == 50 or epoch == 75:
#         lr /= 2.

#     # losses in each epoch
#     losses = []
#     for batch in range(n_batches):
#         targets_batch = targets_data[batch*batch_size: (batch+1)*batch_size]
#         contexts_batch = contexts_data[batch*batch_size: (batch+1)*batch_size]
#         labels_batch = labels_data[batch*batch_size: (batch+1)*batch_size]

#         # calculate and save losses
#         loss_value, gradient = loss_value_and_grad(params, targets_batch, contexts_batch)
#         losses.append(loss_value)

#         params = [(V - lr*dV, U - lr*dU) for (V, U), (dV, dU) in zip(params, gradient)]

#     epoch_losses.append(np.mean(losses))

#     end_time = time.time()
#     if epoch % 1 == 0:
#         print(f"Epoch {epoch+1}/{n_epochs} \t loss = {np.mean(epoch_losses)} \t time = {end_time - start_time:.2f}s")

In [None]:
# plot losses
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
plt.xlabel('Iterations')
plt.ylabel('Loss')

In [None]:
# V_ = np.copy(params[0][0])
# U_ = np.copy(params[0][1])

# # check dimensions of U and V, 10 epoch, 1r = 3, 1.5, 0.75, batch size = 500
# print(f'V_ shape: {V_.shape}')
# print(f'U_ shape: {U_.shape}')

In [None]:
# copy U and V
V_trained = np.copy(params[0][0])
U_trained = np.copy(params[0][1])

# check dimensions of U and V, 100 epochs, lr = 1, batch_size = 500
print(f'V_trained shape: {V_trained.shape}')
print(f'U_trained shape: {U_trained.shape}')

## Evaluate against test set

In [None]:
# define function that takes in an index and vocab size and returns the one-hot encoding
def getOneHot(index, vocab_size):
    onehot = np.zeros(vocab_size)
    onehot[index] = 1
    return onehot

# define softmax function
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

# test getOneHot and softmax function
print(getOneHot(1, 10))
print(softmax(np.array([1, 2, 3])))


In [None]:
# check dimensions of U and V
print(f'U shape: {U.shape}')
print(f'V shape: {V.shape}')

In [None]:
# see first 20 words in the vocab
test_words = list(vocab.keys())[20:40]
test_words

In [None]:
# define cosine similarity scores between 2 word vectors
def similarity_score(target_word_embedding, context_word_embedding):
    return np.dot(target_word_embedding, context_word_embedding) / (np.linalg.norm(target_word_embedding) * np.linalg.norm(context_word_embedding))

# define a function that find the most similar words to a given word
def most_similar_words(word, V, n=5):
    scores = []
    target_word_idx = vocab[word]
    for i in range(V.shape[1]):
        if i == target_word_idx or inverse_vocab[i] == '<pad>':
            continue
        scores.append((inverse_vocab[i], similarity_score(V[:, target_word_idx], V[:, i])))
    scores = sorted(scores, key=lambda x: x[1], reverse=True)
    return scores[:n]

In [None]:
# check similarity between words
# word: photos
print(most_similar_words('amazed', V_trained))

In [None]:
# compute a forward pass through the skip-gram model

# define the forward pass function
def net(V, U, target_word_idx):
    target_hot = getOneHot(target_word_idx, len(vocab))
    return softmax( U @ V @ target_hot )

def predict(word, V, U):
    target_word_idx = vocab[word]
    y_hat = net(V, U, target_word_idx)
    # y_hat is the probability distribution over the vocab
    # select the top 5 words with the highest probability
    top_5 = np.argsort(y_hat)[-10:][::-1]
    top_5_words = [inverse_vocab[i] for i in top_5]
    return top_5_words

In [None]:
# randomly select 1 word from vocab
word = np.random.choice(list(vocab.keys()))
y_hat = net(V_trained, U_trained, vocab[word])
print(f'Word: {word}')
predict(word, V_trained, U_trained)