# Skip-Gram (Not Optimized)

#### Outline
1. Data Preprocessing
2. Generate Training and Testing Data
3. Define the Skip-Gram Model
4. SGD
5. Evaluation of Model


In [1]:
import numpy as np
import pandas as pd
import tqdm
import jax
import jax.numpy as jnp
import string
import tensorflow as tf
import time
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords

## Data Preprocessing

In [2]:
# Read Data
data = pd.read_csv('./data/raw data/raw_data.csv', header=0, names=['text'], usecols=[1])
print(f'Data Shape: {data.shape}')
data.head()

Data Shape: (13368, 1)


Unnamed: 0,text
0,"Sally Forrest, an actress-dancer who graced th..."
1,A middle-school teacher in China has inked hun...
2,A man convicted of killing the father and sist...
3,Avid rugby fan Prince Harry could barely watch...
4,A Triple M Radio producer has been inundated w...


In [3]:
# remove punctuation
punctuations = string.punctuation
def remove_punctuation(txt):
    for char in punctuations:
        if char in txt:
            txt = txt.replace(char, "")
    return txt

# change to lower caps
data['text'] = data['text'].str.lower()

# remove punctuations
data['text'] = data['text'].apply(remove_punctuation)

In [4]:
# remove stopwords
# read stopwords from data/raw data/stopwords.txt
stop_words = []
with open('./data/raw data/stopwords.txt', 'r') as f:
    for line in f:
        stop_words.append(line.strip())

def remove_stopwords(txt):
    txt = [word for word in txt.split() if word not in stop_words]
    return ' '.join(txt)

data['text'] = data['text'].apply(remove_stopwords)

In [5]:
# split each row into list of words
data_lst = data['text'].apply(lambda txt: txt.split(" "))

# select number of rows to be used as training data
nrows = 200
random_indices = np.random.randint(low=0, high=len(data_lst), size=nrows)
data_lst = data_lst[random_indices].reset_index(drop=True)

print(f'Number of rows of data: {len(data_lst)}')
data_lst[:5]

Number of rows of data: 200


0    [school, children, country, watching, tomorrow...
1    [oh, beating, heart, standing, gravel, outside...
2    [former, guard, rikers, island, convicted, smu...
3    [move, kim, kardashian, ideal, female, body, s...
4    [trevor, noah, host, daily, spoken, defense, f...
Name: text, dtype: object

In [6]:
# vocab dict
vocab, index = {}, 1
vocab['<pad>'] = 0
for line in data_lst:
    for word in line:
        if word not in vocab:
            vocab[word] = index
            index += 1

# inverse_vocab dict
inverse_vocab = {}
for word, index in vocab.items():
    inverse_vocab[index] = word

print(f'Vocab size: {len(vocab)}')

Vocab size: 14800


In [7]:
# sequences
sequences = []
for line in data_lst:
    vectorized_line = [vocab[word] for word in line]
    sequences.append(vectorized_line)

## Generate Training and Testing Data

In [8]:
# split into train and test sets
# choose 20 random sequences
ntest = 20
test_indices = np.random.randint(low=0, high=len(sequences), size=ntest)
test_sequences = [sequences[i] for i in test_indices]
train_sequences = [sequences[i] for i in range(len(sequences)) if i not in test_indices]

In [9]:
#[inverse_vocab[i] for i in train_sequences[0]]

In [10]:
#[inverse_vocab[i] for i in test_sequences[0]]

In [171]:
# function to generate training and testing data
def generate_data(sequences, window_size):
    targets, contexts = [], []
    for sequence in tqdm(sequences):
        # for each sentence
        for center_word_pos in range(len(sequence)):
            # add to targets
            targets.append(sequence[center_word_pos])

            context = []
            # for each window position
            for w in range(-window_size, window_size + 1):
                context_word_pos = center_word_pos + w
                # make sure not jump out sentence
                if context_word_pos < 0 or context_word_pos >= len(sequence) or center_word_pos == context_word_pos:
                    continue
                context_word_idx = sequence[context_word_pos]
                context.append(context_word_idx)

            # if length of context < 2*window_size
            # pad context with None until length is 2 * window_size
            if len(context) < 2 * window_size:
                pad_length = (2 * window_size) - len(context)
                for i in range(pad_length):
                    context.append(-1)

            # add to contexts
            context = np.array(context)#; print(context)
            contexts.append(context)

    targets = np.array(targets)
    contexts = np.array(contexts)

    return targets, contexts

In [172]:
# generate training data
targets, contexts = generate_data(train_sequences, 5)

print(f'targets shape: {targets.shape}')
print(f'contexts shape: {contexts.shape}')


# generate testing data
test_targets, test_contexts = generate_data(test_sequences, 5)

print(f'test_targets shape: {test_targets.shape}')
print(f'test_contexts shape: {test_contexts.shape}')

 27%|██▋       | 49/182 [00:00<00:00, 480.50it/s]

100%|██████████| 182/182 [00:00<00:00, 483.67it/s]


targets shape: (59569,)
contexts shape: (59569, 10)


100%|██████████| 20/20 [00:00<00:00, 433.20it/s]

test_targets shape: (7452,)
test_contexts shape: (7452, 10)





**Print a few examples of training and testing data**

In [173]:
# training data
index = 0
print(f"target_index    : {targets[index]}")
print(f"target_word     : {inverse_vocab[targets[index]]}")
print(f"context_indices : {contexts[index]}")
for c in contexts[index]:
    if c == -1:
        continue
    print(f"context_words   : {[inverse_vocab[c]]}")

print("target  :", targets[index])
print("context :", contexts[index])

target_index    : 1
target_word     : school
context_indices : [ 2  3  4  5  6 -1 -1 -1 -1 -1]
context_words   : ['children']
context_words   : ['country']
context_words   : ['watching']
context_words   : ['tomorrows']
context_words   : ['onceinadecade']
target  : 1
context : [ 2  3  4  5  6 -1 -1 -1 -1 -1]


In [174]:
# testing data
index = 0
print(f"target_index    : {test_targets[index]}")
print(f"target_word     : {inverse_vocab[test_targets[index]]}")
print(f"context_indices : {test_contexts[index]}")
for c in test_contexts[index]:
    if c == -1:
        continue
    print(f"context_words   : {[inverse_vocab[c]]}")

print("target  :", test_targets[index])
print("context :", test_contexts[index])

target_index    : 11403
target_word     : taulupe
context_indices : [11404  5338  1542  4282  4283    -1    -1    -1    -1    -1]
context_words   : ['faletau']
context_words   : ['believes']
context_words   : ['winning']
context_words   : ['rbs']
context_words   : ['6']
target  : 11403
context : [11404  5338  1542  4282  4283    -1    -1    -1    -1    -1]


## Define the Skip-Gram Model

In [184]:
# define a function that takes in as input target and vocab length
# outputs a one-hot vector, x_hot of dimension = (vocab_length,)
def get_x_hot(target_idx, vocab_length):
    x_hot = np.zeros(vocab_length, dtype=float)
    x_hot[target_idx] = 1.0
    return jnp.array(x_hot)

def get_x_hot(target_idx, vocab_length):
    x_hot = jnp.zeros(vocab_length, dtype=jnp.float32)
    x_hot = x_hot.at[target_idx].set(1.0)
    return x_hot

get_x_hot(3, 10)

Array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [185]:
def get_y_true(context_idxs, vocab_length):
    y_true = np.zeros(vocab_length)
    for i in context_idxs:
        y_true[int(i)] = 1.
    return jnp.array(y_true)
get_y_true(np.array([1, 2, 3]), 10)

Array([0., 1., 1., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [186]:
# define softmax function
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=0)

x = softmax(np.array([1.,2.,3.]))
x

Array([0.09003057, 0.24472848, 0.66524094], dtype=float32)

In [187]:
# define a forward pass in the skip gram model 
def net(target_hot, V, U):
    """
    Input
    x_hot is one-hot vector, dimensions = (|v| x 1)
    V: input embedding matrix, dimension = (n x |v|)
    U: output embedding matrix, dimension = (|v| x n)
        |v| = vocab size
        n = no. of embedding dimensions
    Output
    z: score vector, dimension = (|v| x 1)
    """
    return jnp.array(U @ V @ target_hot)

n = 300
v = 10000
V = np.random.normal(0, 1, size=(n, v))
U = np.random.normal(0, 1, size=(v, n))
x_hot = get_x_hot(3, v)
z = net(x_hot, V, U)
type(z)

jaxlib.xla_extension.ArrayImpl

In [192]:
# define local loss func
def local_loss(params, target, context):
    """
    Input
    params = [V, U]
    V = input embedding matrix, dimensions = (n x |v|)
    U = output embedding matrix, dimensions = (|v| x n)
    target: target index, e,g. (98,)
    context: context indices, e.g. (94, 95, 95, 97, 99, 100, 101, None, None), dimensions = (2*window_size)
    
    Output
    loss_value: real number
    """
    V = params[0] #; print(f'V shape: {V.shape} \t type: {type(V)}')
    U = params[1] #; print(f'U shape: {U.shape} \t type: {type(U)}')
    x_hot = get_x_hot(target, len(U)) # ; print(f'x_hot shape: {x_hot.shape} \t type: {type(x_hot)}')
    # y_true = get_y_true(context)
    z = net(x_hot, V, U) #; print(f'z shape: {z.shape} \t type: {type(z)}')
    y_hat = softmax(z) #; print(f'y_hat shape: {y_hat.shape} \t type: {type(y_hat)}')
    loss = 0
    for c in context:
        if jnp.bool_(c == -1):
            continue
        loss += y_hat[c]
    return -1*loss

# n = 300; v = 10000
# V = np.random.normal(0, 1, size=(n, v)) / np.sqrt(v)
# U = np.random.normal(0, 1, size=(v, n)) / np.sqrt(v)
# p = [V, U]
# t = targets[0]
# c = contexts[0]
# local_loss(p, t, c)
# loss_value_and_grad = jax.value_and_grad(local_loss, argnums=0)
# v, g = loss_value_and_grad(p, t, c)


In [193]:
loss_all = jax.vmap(local_loss, in_axes=(None, 0, 0))

@jax.jit
def loss(params, targets, contexts):
    """return average of all the local losses"""
    all_losses = loss_all(params, targets, contexts)
    print(f'all_losses shape: {all_losses.shape}')
    return jnp.mean(all_losses)

In [194]:
# get the loss value and gradient
n = 300; v = 10000
V = np.random.normal(0, 1, size=(n, v)) / np.sqrt(v)
U = np.random.normal(0, 1, size=(v, n)) / np.sqrt(v)
p = [V, U]
ts = targets[:5]
cs = contexts[:5]
loss_value_and_grad = jax.jit( jax.value_and_grad(loss) )
v, g  = loss_value_and_grad(p, ts, cs)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function. 
This BatchTracer with object id 140548154691152 was created on line:
  /var/folders/_3/8z9s_23x6w349w1_9vlqlhzh0000gn/T/ipykernel_1631/3823550736.py:22 (local_loss)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [71]:
jnp.array([1., 2., None])

Array([ 1.,  2., nan], dtype=float32)

## Stochastic Gradient Descent