# Skipgram Word Embeddings on the Simpsons Dataset

## Load the Simpsons Dialogues Dataset 

In [1]:
import pandas as pd
df = pd.read_csv("../datasets/simpsons_dataset.csv")
sentences = df["spoken_words"].tolist()
sentences = [ s for s in sentences if type(s) == str]

### Tokenize and Lemmatize Sentences

In [2]:
import nltk

nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')

from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

from nltk.corpus import wordnet
from nltk.corpus import stopwords
import string

lemmatizer = WordNetLemmatizer()

# Generate a set of stopwords to remove
stop_words = set(stopwords.words('english'))
stop_words = stop_words.union(set(["'s", "n't", "'m", "'re", "'ll", "'d"]))
punctuation = set(string.punctuation)
stop_words = stop_words.union(punctuation)
stop_words = stop_words.union(set(["--", "..", "''", "...", "``"]))
digits = set([str(n) for n in range(10)])
stop_words = stop_words.union(digits)

# Map the NLTK/Treebank pos tags onto Wordnet tags
def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return wordnet.NOUN  # Default to noun if the tag is unknown

# Generate cleaned sentences
clean_sentences = []
clean_tokens = []

for s in sentences:
  tokens = word_tokenize(s)
  tokens = [t for t in tokens if t.lower() not in stop_words]
  tagged_tokens = nltk.pos_tag(tokens)
  tagged_tokens = [ (w, get_wordnet_pos(t)) for w,t in tagged_tokens ]
  lemma_words = [lemmatizer.lemmatize(w.lower(), get_wordnet_pos(t)) for w,t in tagged_tokens]
  clean_sentences.append(lemma_words)
  clean_tokens += lemma_words

[nltk_data] Downloading package wordnet to /Users/vikram/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /Users/vikram/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package punkt to /Users/vikram/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /Users/vikram/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/vikram/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Generate the Vocabulary

In [3]:
from collections import Counter

VOCAB_SIZE=6000

counts = Counter(clean_tokens).items()
sorted_counts = sorted(counts, key=lambda k: k[1], reverse=True)
vocab = sorted_counts[:VOCAB_SIZE]

In [4]:
vocab[:10]

[('oh', 8125),
 ('like', 6131),
 ('well', 6037),
 ('get', 5511),
 ('one', 4947),
 ('know', 4914),
 ("'ve", 4664),
 ('got', 4612),
 ('hey', 4277),
 ('homer', 4232)]

In [5]:
idx_to_tok = dict()
tok_to_idx = dict()
for idx,(tok, count) in enumerate(vocab):
  idx_to_tok[idx] = tok
  tok_to_idx[tok] = idx

### Generate Positive Examples

In [6]:
import jax.numpy as jnp

WINDOW_SIZE=6

len_buffer = WINDOW_SIZE//2


def gen_dataset():
  dataset = []
  for s in clean_sentences:
    for i in range(len(s)):
      for j in range(max(0,i-len_buffer), min(len(s),i+len_buffer)):
        if i != j:
            if s[i] in tok_to_idx and s[j] in tok_to_idx:
              idx_i = tok_to_idx[s[i]]
              idx_j = tok_to_idx[s[j]]
              dataset.append((idx_i,idx_j))
  return dataset


dataset = jnp.array(gen_dataset())
print(dataset.shape)

(1881262, 2)


### Train the Model

In [9]:
import logging

import jax
import xjax
from xjax.signals import train_epoch_completed

# Module logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Generate a random seed
seed = 42
rng = jax.random.PRNGKey(seed)

# Model Params
EMBEDDING_SIZE=100
K=50
NEG_PER_POS = 20

# Train Params
BATCH_SIZE = 100000
NUM_ITER = len(dataset)//BATCH_SIZE
NUM_EPOCHS = 50
LR = 0.03

# Create the model and initialize params
model, params = xjax.models.sgns.sgns(rng=rng, vocab_size=VOCAB_SIZE, 
                             embedding_size=EMBEDDING_SIZE)


#Log events
@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, loss, elapsed, **__):
    logger.info(f"epoch={epoch}, loss={loss:0.4f}, elapsed={elapsed:0.4f}")

# Train
params = xjax.models.sgns.train(model, rng=rng, params=params,
                               X=dataset,
                               neg_per_pos=NEG_PER_POS,
                               K=K,
                               epochs=NUM_EPOCHS,
                               batch_size=BATCH_SIZE,
                               learning_rate=LR)


INFO:__main__:epoch=0, loss=92.3090, elapsed=7.0175
INFO:__main__:epoch=1, loss=6.4165, elapsed=13.6926
INFO:__main__:epoch=2, loss=5.2345, elapsed=20.4883
INFO:__main__:epoch=3, loss=4.2222, elapsed=27.1027
INFO:__main__:epoch=4, loss=3.7006, elapsed=33.8759
INFO:__main__:epoch=5, loss=3.3958, elapsed=40.5573
INFO:__main__:epoch=6, loss=3.1780, elapsed=47.8143
INFO:__main__:epoch=7, loss=3.0147, elapsed=54.4782
INFO:__main__:epoch=8, loss=2.9175, elapsed=61.1838
INFO:__main__:epoch=9, loss=2.8482, elapsed=67.8347
INFO:__main__:epoch=10, loss=2.7969, elapsed=74.5253
INFO:__main__:epoch=11, loss=2.7111, elapsed=81.1747
INFO:__main__:epoch=12, loss=2.6632, elapsed=87.9641
INFO:__main__:epoch=13, loss=2.6241, elapsed=94.5831
INFO:__main__:epoch=14, loss=2.5984, elapsed=101.1758
INFO:__main__:epoch=15, loss=2.5634, elapsed=107.8927
INFO:__main__:epoch=16, loss=2.5412, elapsed=114.5396
INFO:__main__:epoch=17, loss=2.5138, elapsed=121.3133
INFO:__main__:epoch=18, loss=2.4846, elapsed=127.981

### Inspect the trained embeddings

In [10]:
def find_most_similar(params, word, n=10):
    if word not in tok_to_idx:
        raise ValueError(f"Word '{word}' not found in word vectors dictionary.")

    # Get the vector for the selected word
    emb1 = params[tok_to_idx[word], :, 0]

    # Calculate cosine similarities with all other words
    similarities = {}
    for other_word, other_idx in tok_to_idx.items():
        if other_word != word:
            emb2 = params[other_idx, :, 0]
            similarity = jnp.dot(emb1,emb2)/(jnp.linalg.norm(emb1)*jnp.linalg.norm(emb2))
            similarities[other_word] = similarity

    # Sort by similarity and return the top n words
    most_similar = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:n]
    return most_similar


In [11]:
find_most_similar(params, "springfield")

[('elementary', Array(0.7011735, dtype=float32)),
 ('town', Array(0.62027335, dtype=float32)),
 ('city', Array(0.61156374, dtype=float32)),
 ('school', Array(0.5965495, dtype=float32)),
 ('today', Array(0.59161603, dtype=float32)),
 ('first', Array(0.5809512, dtype=float32)),
 ('country', Array(0.5762146, dtype=float32)),
 ('simpson', Array(0.55961263, dtype=float32)),
 ('history', Array(0.5421662, dtype=float32)),
 ('world', Array(0.53902626, dtype=float32))]