## Word Embeddings

Goal: Implement a simple word embedding in Python (from scratch) and use it to find the most similar words to a given word. Come up with a dataset and evaluation metrics to evaluate the word embeddings.
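
As a rough illustration of the similarity lookup named in the goal, the sketch below ranks words by cosine similarity to a query word. The function name `most_similar`, the `top_k` parameter, and the toy vectors are assumptions made for this example and are not part of the notebook; the vectors are hand-made, not trained embeddings.

```python
import numpy as np

def most_similar(query, embeddings, vocab_idx, top_k=3):
    """Return the top_k words closest to `query` by cosine similarity."""
    idx_to_word = {idx: word for word, idx in vocab_idx.items()}
    # Normalize rows so that a dot product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-12)
    sims = unit @ unit[vocab_idx[query]]
    sims[vocab_idx[query]] = -np.inf  # exclude the query word itself
    best = np.argsort(-sims)[:top_k]
    return [(idx_to_word[int(i)], float(sims[i])) for i in best]

# Toy, hand-made vectors purely for illustration (not trained embeddings)
toy_vocab_idx = {"king": 0, "queen": 1, "apple": 2}
toy_embeddings = np.array([[1.0, 0.9], [0.9, 1.0], [-1.0, 0.2]])
neighbours = most_similar("king", toy_embeddings, toy_vocab_idx, top_k=2)
print(neighbours)
```

Cosine similarity ignores vector length, so frequent words with large norms do not dominate the ranking.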

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

### Downloading NLTK's stopwords

In [33]:
import nltk
# nltk.download("stopwords") # Uncomment this line to download the stopwords
from nltk.corpus import stopwords

stopwords = set(stopwords.words('english'))
print(stopwords)

{"she's", 'herself', 'further', 're', 'whom', 'because', 'didn', 'such', 'his', 'out', 'those', 'doesn', 'm', 'all', 'be', 'each', 'theirs', 'the', 'while', 'and', "it's", 'been', 'them', 'under', 't', 'just', 'him', 'does', 'there', 'they', "hadn't", 'itself', 'o', 'over', 'below', 'than', 'if', 'what', 's', 'against', "hasn't", 'during', 'mightn', 'can', 'have', 'hadn', 'ma', 'don', 'through', 'once', 'in', 'only', "couldn't", 'why', 'yourself', 'between', "won't", 'himself', 'more', 'hasn', 'into', 'with', 'shan', 'when', 'up', 'where', "needn't", 'who', 'about', 'should', 'doing', 'to', 'both', 'he', 'off', 'some', "mustn't", 'needn', 'aren', 'so', 'an', 'had', 'y', 'few', "haven't", 'down', "shouldn't", 'ours', 'own', 'i', 'me', "didn't", 'now', 'too', 'then', 'from', 'its', 'again', "weren't", 'shouldn', 'd', 'which', 'ourselves', 'wasn', 'hers', 'or', 'will', 've', "should've", 'very', 'but', 'her', 'at', 'our', 'it', 'wouldn', 'for', 'ain', 'until', 'their', 'a', 'my', 'that', 

### Downloading Brown Corpus

The Brown Corpus was the first million-word electronic corpus of English, created in 1961 at Brown University. This corpus contains text from 500 sources, and the sources have been categorized by genre, such as news, editorial, and so on.

In [6]:
# nltk.download("brown") # Uncomment this line to download the brown corpus
from nltk.corpus import brown

In [34]:
# Show the categories of the brown corpus 
brown.categories()

['adventure',
 'belles_lettres',
 'editorial',
 'fiction',
 'government',
 'hobbies',
 'humor',
 'learned',
 'lore',
 'mystery',
 'news',
 'religion',
 'reviews',
 'romance',
 'science_fiction']

In [35]:
# Get the sentences from the "fiction" category
sentences = brown.sents(categories="fiction")

# Number of sentences
print(f"Number of sentences: {len(sentences)}")

# Print the first 5 sentences
for i in range(5):
    print(f"Sentence {i}: {sentences[i]}")

Number of sentences: 4249
Sentence 0: ['Thirty-three']
Sentence 1: ['Scotty', 'did', 'not', 'go', 'back', 'to', 'school', '.']
Sentence 2: ['His', 'parents', 'talked', 'seriously', 'and', 'lengthily', 'to', 'their', 'own', 'doctor', 'and', 'to', 'a', 'specialist', 'at', 'the', 'University', 'Hospital', '--', 'Mr.', 'McKinley', 'was', 'entitled', 'to', 'a', 'discount', 'for', 'members', 'of', 'his', 'family', '--', 'and', 'it', 'was', 'decided', 'it', 'would', 'be', 'best', 'for', 'him', 'to', 'take', 'the', 'remainder', 'of', 'the', 'term', 'off', ',', 'spend', 'a', 'lot', 'of', 'time', 'in', 'bed', 'and', ',', 'for', 'the', 'rest', ',', 'do', 'pretty', 'much', 'as', 'he', 'chose', '--', 'provided', ',', 'of', 'course', ',', 'he', 'chose', 'to', 'do', 'nothing', 'too', 'exciting', 'or', 'too', 'debilitating', '.']
Sentence 3: ['His', 'teacher', 'and', 'his', 'school', 'principal', 'were', 'conferred', 'with', 'and', 'everyone', 'agreed', 'that', ',', 'if', 'he', 'kept', 'up', 'with', '

In [36]:
# Filter out non-alphanumeric words and convert to lowercase
filtered_sentences = []

for sent in sentences:
    filtered_sent = [word.lower() for word in sent if word.isalnum()]
    filtered_sentences.append(filtered_sent)

# Print the first 5 sentences
for i in range(5):
    print(f"Filtered sentence {i}: {filtered_sentences[i]}")

Filtered sentence 0: []
Filtered sentence 1: ['scotty', 'did', 'not', 'go', 'back', 'to', 'school']
Filtered sentence 2: ['his', 'parents', 'talked', 'seriously', 'and', 'lengthily', 'to', 'their', 'own', 'doctor', 'and', 'to', 'a', 'specialist', 'at', 'the', 'university', 'hospital', 'mckinley', 'was', 'entitled', 'to', 'a', 'discount', 'for', 'members', 'of', 'his', 'family', 'and', 'it', 'was', 'decided', 'it', 'would', 'be', 'best', 'for', 'him', 'to', 'take', 'the', 'remainder', 'of', 'the', 'term', 'off', 'spend', 'a', 'lot', 'of', 'time', 'in', 'bed', 'and', 'for', 'the', 'rest', 'do', 'pretty', 'much', 'as', 'he', 'chose', 'provided', 'of', 'course', 'he', 'chose', 'to', 'do', 'nothing', 'too', 'exciting', 'or', 'too', 'debilitating']
Filtered sentence 3: ['his', 'teacher', 'and', 'his', 'school', 'principal', 'were', 'conferred', 'with', 'and', 'everyone', 'agreed', 'that', 'if', 'he', 'kept', 'up', 'with', 'a', 'certain', 'amount', 'of', 'work', 'at', 'home', 'there', 'was', 

In [37]:
# Filter out stopwords (the input is already lowercase and alphanumeric)
filtered_stopwords_sentences = []

for sent in filtered_sentences:
    filtered_sent = [word for word in sent if word not in stopwords]
    filtered_stopwords_sentences.append(filtered_sent)

# Print the first 5 sentences
for i in range(5):
    print(f"Filtered sentence {i}: {filtered_stopwords_sentences[i]}")

Filtered sentence 0: []
Filtered sentence 1: ['scotty', 'go', 'back', 'school']
Filtered sentence 2: ['parents', 'talked', 'seriously', 'lengthily', 'doctor', 'specialist', 'university', 'hospital', 'mckinley', 'entitled', 'discount', 'members', 'family', 'decided', 'would', 'best', 'take', 'remainder', 'term', 'spend', 'lot', 'time', 'bed', 'rest', 'pretty', 'much', 'chose', 'provided', 'course', 'chose', 'nothing', 'exciting', 'debilitating']
Filtered sentence 3: ['teacher', 'school', 'principal', 'conferred', 'everyone', 'agreed', 'kept', 'certain', 'amount', 'work', 'home', 'little', 'danger', 'losing', 'term']
Filtered sentence 4: ['scotty', 'accepted', 'decision', 'indifference', 'enter', 'arguments']


### GloVe: Global Vectors for Word Representation

In [None]:
corpus = brown.sents(categories="news")

In [None]:
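# Lowercase all words; remove stopwords and punctuation-only tokens such as "--"
# (assumes `punctuation` is meant to be string.punctuation)
from string import punctuation

def preprocess_sentence(sentence):
    return [word.lower() for word in sentence if word.lower() not in stopwords and not all(ch in punctuation for ch in word)]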

corpus_preprocessed = [preprocess_sentence(sentence) for sentence in corpus]
corpus_preprocessed[0]

#### Building Co-occurrence Matrix

In [None]:
from collections import defaultdict
import numpy as np

# Build vocabulary index
def build_vocab_idx(corpus_preprocessed):
    vocab_count = defaultdict(int)
    for sentence in corpus_preprocessed:
        for word in sentence:
            vocab_count[word] += 1
    return {word: idx for idx, (word, _) in enumerate(vocab_count.items())}, vocab_count

vocab_idx, vocab_count = build_vocab_idx(corpus_preprocessed)
vocab_size = len(vocab_idx)

# Window size for Co-occurrence matrix - 2 words before, 2 words after
window_size = 2

# Build co-occurrence matrix
def build_cooccurrence_matrix(corpus_preprocessed, vocab_idx, window_size):
    cooccurrence_matrix = defaultdict(lambda: defaultdict(float))
    
    for sentence in corpus_preprocessed:
        sentence_length = len(sentence)
        for i, word in enumerate(sentence):
            word_idx = vocab_idx[word]
            # Context window
            start = max(0, i - window_size)
            end = min(sentence_length, i + window_size + 1)
            
            for j in range(start, end):
                # Skip target word
                if i != j:
                    context_word = sentence[j]
                    context_word_idx = vocab_idx[context_word]
                    # Increment the co-occurrence count with inverse distance weighting
                    cooccurrence_matrix[word_idx][context_word_idx] += 1.0 / abs(i - j)
    return cooccurrence_matrix

cooccurrence_matrix = build_cooccurrence_matrix(corpus_preprocessed, vocab_idx, window_size)
cooccurrence_matrix[0]

In [None]:
# hyperparameters
x_max = 100
alpha = 0.75
embedding_dim = 50

# Initialize word vectors and biases
# Split the key first so that each random draw uses its own subkey
key = jax.random.PRNGKey(0)
key, subkey_v, subkey_v_tilde = jax.random.split(key, 3)
v = jax.random.normal(subkey_v, (vocab_size, embedding_dim))
v_tilde = jax.random.normal(subkey_v_tilde, (vocab_size, embedding_dim))
bias = jnp.zeros(vocab_size)
bias_tilde = jnp.zeros(vocab_size)

In [None]:
# GloVe weight function: down-weights rare pairs and caps the weight at 1 for x >= x_max
def f(x):
    return jnp.minimum(1.0, (x / x_max) ** alpha)

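# GloVe loss: weighted squared error between the model score and log(X_ij),
# summed over all nonzero co-occurrence entries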
def gLove_loss(v, v_tilde, bias, bias_tilde, cooccurrence_matrix):
    loss = 0.0
    for word in cooccurrence_matrix:
        for context_word in cooccurrence_matrix[word]:
            X_ij = cooccurrence_matrix[word][context_word]
            weight = f(X_ij)
            diff = jnp.dot(v[word], v_tilde[context_word]) + bias[word] + bias_tilde[context_word] - jnp.log(X_ij)
            loss += weight * diff**2
    return loss
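
The Python loop in the loss above visits one co-occurrence entry at a time, which may be slow to differentiate for a real vocabulary. A minimal sketch of one alternative, assuming the nested dictionary is first flattened into index and value arrays, is shown below. The names `cooccurrence_to_arrays`, `glove_loss_vectorized`, and `sgd_step`, the learning rate, and the toy counts are all assumptions made for this example, and plain gradient descent is used only for simplicity.

```python
import jax
import jax.numpy as jnp

def cooccurrence_to_arrays(cooccurrence):
    """Flatten a nested {i: {j: X_ij}} dict into three parallel arrays."""
    rows, cols, vals = [], [], []
    for i, contexts in cooccurrence.items():
        for j, x_ij in contexts.items():
            rows.append(i)
            cols.append(j)
            vals.append(x_ij)
    return jnp.array(rows), jnp.array(cols), jnp.array(vals)

def glove_loss_vectorized(params, rows, cols, vals, x_max=100.0, alpha=0.75):
    v, v_tilde, bias, bias_tilde = params
    weight = jnp.minimum(1.0, (vals / x_max) ** alpha)
    # Row-wise dot products between word vectors and context vectors
    dots = jnp.sum(v[rows] * v_tilde[cols], axis=1)
    diff = dots + bias[rows] + bias_tilde[cols] - jnp.log(vals)
    return jnp.sum(weight * diff ** 2)

@jax.jit
def sgd_step(params, rows, cols, vals, learning_rate=0.05):
    # Differentiate with respect to the whole parameter tuple at once
    loss, grads = jax.value_and_grad(glove_loss_vectorized)(params, rows, cols, vals)
    new_params = tuple(p - learning_rate * g for p, g in zip(params, grads))
    return new_params, loss

# Toy co-occurrence counts for a 3-word vocabulary (made-up numbers)
toy_cooccurrence = {0: {1: 2.0, 2: 0.5}, 1: {0: 2.0}, 2: {0: 0.5}}
rows, cols, vals = cooccurrence_to_arrays(toy_cooccurrence)

key_v, key_v_tilde = jax.random.split(jax.random.PRNGKey(0))
params = (
    0.1 * jax.random.normal(key_v, (3, 4)),
    0.1 * jax.random.normal(key_v_tilde, (3, 4)),
    jnp.zeros(3),
    jnp.zeros(3),
)

losses = []
for _ in range(200):
    params, loss = sgd_step(params, rows, cols, vals)
    losses.append(float(loss))
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

After training, a common choice is to use `v + v_tilde` as the final embedding of each word, although either matrix can also be used on its own.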

### Word2Vec Word Embedding

#### Skip-gram model

In Skip-Gram, the goal is to predict the context words $w_c$ given a target word $w_t$. 
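
To make the prediction task concrete, the following sketch lists the (target, context) training pairs that a sentence yields for a given window. The helper name `skipgram_pairs` and the window size are assumptions made for this example.

```python
def skipgram_pairs(sentence, window_size=2):
    """List every (target, context) pair within `window_size` positions."""
    pairs = []
    for i, target in enumerate(sentence):
        start = max(0, i - window_size)
        end = min(len(sentence), i + window_size + 1)
        for j in range(start, end):
            if j != i:  # a word is not its own context
                pairs.append((target, sentence[j]))
    return pairs

pairs = skipgram_pairs(["scotty", "go", "back", "school"], window_size=1)
for target, context in pairs:
    print(f"{target} -> {context}")
```

With a window of 1, interior words contribute two pairs each and the words at the sentence boundaries contribute one.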

##### Skip-Gram with Negative Sampling (SGNS)

A full softmax over the vocabulary requires scoring every word for each training pair. Negative sampling avoids this by replacing the softmax with a set of binary classification problems: 
the model scores only the observed context words (positive samples) and a few randomly selected words from the vocabulary (negative samples) instead of the entire vocabulary.
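
One way the negative words could be drawn is sketched below. It assumes a word-count dictionary like `vocab_count` and samples in proportion to the count raised to the power 0.75, a common choice in word2vec implementations; the helper name `make_negative_sampler`, the exponent, and the toy counts are assumptions made for this example.

```python
import numpy as np

def make_negative_sampler(vocab_count, power=0.75, seed=0):
    """Build a sampler that draws words with probability ~ count ** power."""
    words = list(vocab_count)
    weights = np.array([vocab_count[w] for w in words], dtype=float) ** power
    probs = weights / weights.sum()
    rng = np.random.default_rng(seed)

    def sample(exclude, k):
        # Redraw until k words outside `exclude` have been collected
        negatives = []
        while len(negatives) < k:
            word = words[rng.choice(len(words), p=probs)]
            if word not in exclude:
                negatives.append(word)
        return negatives

    return sample

# Toy counts, made up for illustration
toy_counts = {"school": 10, "doctor": 5, "family": 3, "term": 1}
sample_negatives = make_negative_sampler(toy_counts)
negatives = sample_negatives(exclude={"school", "term"}, k=5)
print(negatives)
```

Raising counts to a power below 1 flattens the distribution, so rare words are drawn somewhat more often than their raw frequency would suggest.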

Loss Function: \
The loss function for Skip-Gram with Negative Sampling (SGNS) is typically a binary cross-entropy loss. For each target word $w_t$ and its context words $w_c$, the loss function tries to:
- Maximize the probability that $w_c$ appears in the context of $w_t$
- Minimize the probability that randomly selected (negative) words appear in the context

The loss function is
$$ L = - \left[ \sum_{w_c \in \text{context}} \log \sigma \left( \langle v_{w_c}, v_{w_t} \rangle \right) 
             +  \sum_{w_n \in \text{negatives}} \log \sigma \left( \langle -v_{w_n}, v_{w_t} \rangle \right) \right] $$

where 
- $v_{w_t}$ is the embedding vector of target word $w_t$
- $v_{w_c}$ is the embedding vector of context word $w_c$
- $v_{w_n}$ is the embedding vector of negative word $w_n$ 
- $\sigma(x)$ is the sigmoid function $\sigma(x) = \frac{1}{1 + e^{-x}}$ 
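
The loss above can be written directly in JAX. The sketch below evaluates it for a single target word; the function name `sgns_loss` and the toy vectors are assumptions made for this example, and, following the formula, one embedding table is assumed for target, context, and negative words.

```python
import jax
import jax.numpy as jnp

def sgns_loss(v_target, v_context, v_negatives):
    """Negative-sampling loss for one target word.

    v_target:    (d,)   embedding of the target word
    v_context:   (c, d) embeddings of the observed context words
    v_negatives: (k, d) embeddings of the sampled negative words
    """
    positive_scores = v_context @ v_target
    negative_scores = v_negatives @ v_target
    # log_sigmoid(x) = log(sigma(x)); negatives enter with a flipped sign
    return -(jnp.sum(jax.nn.log_sigmoid(positive_scores))
             + jnp.sum(jax.nn.log_sigmoid(-negative_scores)))

# Toy vectors chosen by hand for illustration
v_target = jnp.array([1.0, 0.0])
v_context = jnp.array([[2.0, 0.0]])     # aligned with the target
v_negatives = jnp.array([[-2.0, 0.0]])  # opposed to the target

good = sgns_loss(v_target, v_context, v_negatives)
bad = sgns_loss(v_target, v_negatives, v_context)  # roles swapped
print(float(good), float(bad))
```

The loss is small when context vectors align with the target and negative vectors point away from it, and it grows when the roles are swapped.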

##### Skip-Gram with Hierarchical Softmax

Hierarchical Softmax is an alternative to Negative Sampling in which the softmax is replaced by a binary tree whose leaves are the vocabulary words. The probability of a word is the product of the binary decisions made at each inner node along the path from the root to that word's leaf, so the loss is a sum of binary cross-entropy terms, one per node on the path.

The key idea is to reduce the computational cost of the softmax for a large vocabulary: a balanced tree over $|V|$ words has depth of about $\log_2 |V|$, so each prediction costs $O(\log |V|)$ instead of the $O(|V|)$ of a full softmax.
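
As an illustration of the tree traversal, the sketch below computes the probability of each word as a product of sigmoid decisions along its path. The tree layout, the `paths` encoding, and the random vectors are hypothetical choices made for this example; practical implementations often build a Huffman tree from word frequencies instead.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def path_probability(v_input, node_vectors, path):
    """Probability of one leaf word as a product of binary decisions.

    `path` is a list of (inner_node_index, direction) pairs from the root
    to the leaf, with direction +1 for "go left" and -1 for "go right".
    """
    prob = 1.0
    for node, direction in path:
        prob *= sigmoid(direction * (node_vectors[node] @ v_input))
    return prob

# Hypothetical balanced tree over 4 words: node 0 is the root,
# node 1 covers words {a, b}, node 2 covers words {c, d}.
paths = {
    "a": [(0, +1), (1, +1)],
    "b": [(0, +1), (1, -1)],
    "c": [(0, -1), (2, +1)],
    "d": [(0, -1), (2, -1)],
}
rng = np.random.default_rng(0)
node_vectors = rng.normal(size=(3, 5))  # one vector per inner node
v_input = rng.normal(size=5)            # embedding of the input word

probs = {w: path_probability(v_input, node_vectors, p) for w, p in paths.items()}
total = sum(probs.values())
print(probs, total)
```

Because $\sigma(x) + \sigma(-x) = 1$ at every inner node, the leaf probabilities sum to 1 without any normalization over the vocabulary, and each word needs only two sigmoid evaluations in this 4-word tree.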

#### Continuous Bag of Words (CBOW) model

In CBOW, the objective is to predict the target word $w_t$ given the context words $w_c$.

##### CBOW with Negative Sampling


The model takes the average of the context word embeddings and tries to predict the target word. The loss function for CBOW with Negative Sampling is also binary cross-entropy but formulated slightly differently due to the averaging of context word embeddings.

The loss function is
$$ L = - \left[ \log \sigma \left( \langle v_{w_t}, v_C \rangle \right) 
             +  \sum_{w_n \in \text{negatives}} \log \sigma \left( \langle -v_{w_n}, v_C \rangle \right) \right] $$  
where 
$$ v_C = \frac{1}{|C|} \sum_{w_c \in C} v_{w_c} $$
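
A minimal JAX sketch of this loss follows, together with its gradient with respect to the target embedding. The function name `cbow_ns_loss` and the toy vectors are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def cbow_ns_loss(v_target, v_context, v_negatives):
    """CBOW negative-sampling loss for one (context window, target) example."""
    v_c = jnp.mean(v_context, axis=0)  # averaged context vector v_C
    positive_score = v_target @ v_c
    negative_scores = v_negatives @ v_c
    return -(jax.nn.log_sigmoid(positive_score)
             + jnp.sum(jax.nn.log_sigmoid(-negative_scores)))

# Toy vectors chosen by hand for illustration
v_context = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # mean is [0.5, 0.5]
v_target = jnp.array([2.0, 2.0])
v_negatives = jnp.array([[-2.0, -2.0]])

loss = cbow_ns_loss(v_target, v_context, v_negatives)
# jax.grad differentiates with respect to the first argument by default
grad_target = jax.grad(cbow_ns_loss)(v_target, v_context, v_negatives)
print(float(loss), grad_target)
```

The gradient with respect to the target is a negative multiple of $v_C$, so a gradient-descent step moves the target embedding toward the averaged context vector.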

##### CBOW with Hierarchical Softmax

Like in Skip-Gram, CBOW can also use hierarchical softmax as an alternative to Negative Sampling, traversing a binary tree to approximate the softmax function.