## Word Embeddings

Goal: Implement a simple word embedding in Python (from scratch) and use it to find the most similar words to a given word. Come up with a dataset and evaluation metrics to evaluate the word embeddings.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

Importing NLTK's stopwords

In [2]:
import nltk
import string

nltk.download("stopwords")
from nltk.corpus import stopwords
stopwords = set(stopwords.words('english'))
punctuation = set(string.punctuation)

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\okk15\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Using Brown Corpus

In [3]:
nltk.download("brown")
from nltk.corpus import brown

[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\okk15\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [4]:
brown.categories()

['adventure',
 'belles_lettres',
 'editorial',
 'fiction',
 'government',
 'hobbies',
 'humor',
 'learned',
 'lore',
 'mystery',
 'news',
 'religion',
 'reviews',
 'romance',
 'science_fiction']

In [5]:
corpus = brown.sents(categories="news")

In [6]:
# lower case all words, remove all punctuations and stop words
def preprocess_sentence(sentence):
    return [word.lower() for word in sentence if word.lower() not in stopwords and word not in punctuation]

corpus_preprocessed = [preprocess_sentence(sentence) for sentence in corpus]
corpus_preprocessed[0]

['fulton',
 'county',
 'grand',
 'jury',
 'said',
 'friday',
 'investigation',
 "atlanta's",
 'recent',
 'primary',
 'election',
 'produced',
 '``',
 'evidence',
 "''",
 'irregularities',
 'took',
 'place']

### Building Co-occurrence Matrix

In [None]:
from collections import defaultdict
import numpy as np

# Build vocabulary index
def build_vocab_idx(corpus_preprocessed):
    vocab_count = defaultdict(int)
    for sentence in corpus_preprocessed:
        for word in sentence:
            vocab_count[word] += 1
    return {word: idx for idx, (word, _) in enumerate(vocab_count.items())}, vocab_count

vocab_idx, vocab_count = build_vocab_idx(corpus_preprocessed)
vocab_size = len(vocab_idx)

# Window size for Co-occurrence matrix - 2 words before, 2 words after
window_size = 2

# Build co-occurrence matrix
def build_cooccurrence_matrix(corpus_preprocessed, vocab_idx, window_size):
    cooccurrence_matrix = defaultdict(lambda: defaultdict(float))
    
    for sentence in corpus_preprocessed:
        sentence_length = len(sentence)
        for i, word in enumerate(sentence):
            word_idx = vocab_idx[word]
            # Context window
            start = max(0, i - window_size)
            end = min(sentence_length, i + window_size + 1)
            
            for j in range(start, end):
                # Skip target word
                if i != j:
                    context_word = sentence[j]
                    context_word_idx = vocab_idx[context_word]
                    # Increment the co-occurrence count with inverse distance weighting
                    cooccurrence_matrix[word_idx][context_word_idx] += 1.0 / abs(i - j)
    return cooccurrence_matrix

cooccurrence_matrix = build_cooccurrence_matrix(corpus_preprocessed, vocab_idx, window_size)
cooccurrence_matrix[0]

### GloVe: Global Vectors for Word Representation

In [None]:
# hyperparameters
x_max = 100
alpha = 0.75
embedding_dim = 50

# Initialize word vectors and bias
key = jax.random.PRNGKey(0)
v = jax.random.normal(key, (vocab_size, embedding_dim))
key, subkey = jax.random.split(key)
v_tilde = jax.random.normal(subkey, (vocab_size, embedding_dim))
bias = jnp.zeros(vocab_size)
bias_tilde = jnp.zeros(vocab_size)

In [None]:
# weight function
def f(x):
    return jax.minimum(1.0, (x/x_max)**alpha)

# loss function
def gLove_loss(v, v_tilde, bias, bias_tilde, cooccurrence_matrix):
    loss = 0.0
    for word in cooccurrence_matrix:
        for context_word in cooccurrence_matrix[word]:
            X_ij = cooccurrence_matrix[word][context_word]
            weight = f(X_ij)
            diff = jnp.dot(v[word], v_tilde[context_word]) + bias[word] + bias_tilde[context_word] - jnp.log(X_ij)
            loss += weight * diff**2
    return loss