In [15]:
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt 

# Setup device agnostic code
device = "mps" if torch.backends.mps.is_available() else "cpu"
device

# Setup random seed
RANDOM_SEED = 42

In [16]:
df = pd.read_csv('../HN Score, Title 10k.csv')
df.title = df.title.str.lower()
df.dropna(inplace= True)
df.head()

Unnamed: 0,score,title
0,8.0,nasa's 3d-printed rotating detonation rocket e...
1,62.0,heat pumps of the 1800s are becoming the techn...
2,1.0,why you should develop local-first web apps
3,1.0,tool to make twitter archive publishable
4,2.0,fedora packages versus upstream flatpaks


In [33]:
words = []
num_lines = 10
lines = df.title.tolist()[:num_lines]

for i in df.title[:num_lines]:
    for j in str(i).split():
        if j not in words and j != "nan":
            words.append(j)

vocab_size = len(words)

vocab_size, len(lines)

(71, 10)

In [18]:
itos = {num:word for num, word in zip(range(len(words)),words)}
stoi = {word:num for num,word in itos.items()}

In [19]:
def create_skipgram_pairs(input_lines, context_len):
    """
    Generate Skip-Gram pairs (target, context) from a list of input lines.

    Args:
        input_lines (list of str): The input lines of text.
        context_len (int): Context window size (number of words on each side).

    Returns:
        targets (list): List of target words.
        context_words (list of lists): List of context word lists for each target.
    """
    context_words = []
    targets = []

    for line in input_lines:
        # Split the line into words (not letters)
        words = line.split()
        for i, target in enumerate(words):
            # Define context window boundaries
            start = max(0, i - context_len)
            end = min(len(words), i + context_len + 1)

            # Collect context words, excluding the target word itself
            context = [words[j] for j in range(start, end) if j != i]
            targets.append(target)
            context_words.append(context)

            # print(f"Target: {target} ; Context: {context}")

    return targets, context_words

context_len = 1

targets, context_words = create_skipgram_pairs(lines, context_len)

targets_len = len(targets)
targets_len


8216

In [20]:
# Convert targets to indices
target_indices = [stoi[target] for target in targets if target in stoi]

# Convert context words to indices
context_indices = [[stoi[context] for context in contexts if context in stoi] for contexts in context_words]


In [21]:
# Convert target indices to PyTorch tensor
X = torch.tensor(target_indices, dtype=torch.long)

# Create one-hot encoded targets (X)
# X = torch.zeros((targets_len, vocab_size))
# X.scatter_(1, target_indices.unsqueeze(1), 1)  # Scatter 1s into the appropriate indices
# X = X.long()

# Convert context indices to PyTorch tensor and one-hot encode (Y)
Y = torch.zeros((targets_len, vocab_size))
for i, context in enumerate(context_indices):
    Y[i, context] = 1  # Set 1s for all indices in the context for each target

X.shape, Y.shape

(torch.Size([8216]), torch.Size([8216, 3989]))

In [22]:
emb_dims = 20

class Word2Vec(nn.Module):
    def __init__(self, vocab_size, emb_dims):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings= vocab_size, embedding_dim= emb_dims)
        self.layer2 = nn.Linear(in_features = emb_dims, out_features = vocab_size)

    def forward(self, x):
        self.out = self.layer2(self.emb(x))
        return self.out

word2vec = Word2Vec(vocab_size = vocab_size, emb_dims= emb_dims)

# W1 = torch.randn(vocab_size, emb_dims)
# W2 = torch.randn(vocab_size, emb_dims)

# parameters = [W1, W2]

In [31]:
# Setup loss function
loss_fn = nn.BCEWithLogitsLoss()

# Setup optimizer to optimize model's parameters
optimiser = torch.optim.SGD(params= word2vec.parameters(), lr = 0.5)

In [24]:
# Let's calculuate the accuracy using accuracy from TorchMetrics
# !pip -q install torchmetrics # Colab doesn't come with torchmetrics
from torchmetrics import Accuracy

## TODO: Uncomment this code to use the Accuracy function
acc_fn = Accuracy(task="multiclass", num_classes=vocab_size) # send accuracy function to device

In [32]:
torch.manual_seed(RANDOM_SEED)

# Setup epochs
epochs = 1000

# Send data to the device
# X_train, y_train = X_train.to(device), y_train.to(device)
# X_test, y_test = X_test.to(device), y_test.to(device)

# Loop through the data
for epoch in range(epochs):
  ### Training

  # 1. Forward pass (logits output)
  y_logits = word2vec(X)
  # Turn logits into prediction probabilities
  y_pred_probs = torch.sigmoid(y_logits)

  # Turn prediction probabilities into prediction labels
  y_pred_labels = torch.round(y_pred_probs)

  # 2. Calculaute the loss
  loss = loss_fn(y_logits, Y) # loss = compare model raw outputs to desired model outputs

  # Calculate the accuracy
  acc = acc_fn(y_pred_labels, Y) # the accuracy function needs to compare pred labels (not logits) with actual labels

  # 3. Zero the gradients
  optimiser.zero_grad()

  # 4. Loss backwards
  loss.backward()

  # 5. Step the optimiser
  optimiser.step()

  # Print out what's happening every 100 epochs
  if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {acc:.2%}")


Epoch: 0 | Loss: 0.71978, Accuracy: 51.90%
Epoch: 100 | Loss: 0.71629, Accuracy: 52.36%
Epoch: 200 | Loss: 0.71284, Accuracy: 52.83%
Epoch: 300 | Loss: 0.70941, Accuracy: 53.30%
Epoch: 400 | Loss: 0.70600, Accuracy: 53.76%
Epoch: 500 | Loss: 0.70261, Accuracy: 54.23%
Epoch: 600 | Loss: 0.69924, Accuracy: 54.69%
Epoch: 700 | Loss: 0.69590, Accuracy: 55.17%
Epoch: 800 | Loss: 0.69258, Accuracy: 55.64%
Epoch: 900 | Loss: 0.68928, Accuracy: 56.12%


In [28]:
# for p in parameters:
#     p.requires_grad = True

In [29]:
# epochs = 30000
# lossi = []

# for i in range(epochs):
#     h = X @ W1
#     logits = h @ W2.T

#     # Compute the loss
#     loss = F.binary_cross_entropy_with_logits(logits, Y)

#     # backward pass
#     for p in parameters:
#         p.grad = None
#     loss.backward()
#     # Update
#     lr = 0.1 if i <= 30000 else 0.01 if i <= 55000 else 0.001
#     for p in parameters:
#         p.data -= lr * p.grad

#     # track stats
#     if i % 1000 == 0: # print every once in a while
#         print(f'{i:7d}/{epochs:7d}: {loss.item():.4f}')
#     lossi.append(loss.item())


In [30]:
# plt.plot(torch.tensor(lossi).view(-1, 1000).mean(1))

NameError: name 'lossi' is not defined

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def get_nearest_neighbors(word, stoi, embeddings, n=5):
    """
    Find the top-n nearest neighbors of a word in the embedding space.
    
    Args:
        word (str): The target word.
        stoi (dict): Mapping from words to indices.
        embeddings (torch.Tensor): Learned word embeddings (shape: V x d).
        n (int): Number of nearest neighbors to retrieve.
    
    Returns:
        List of tuples (neighbor_word, similarity_score).
    """
    if word not in stoi:
        return f"'{word}' not in vocabulary."
    
    word_idx = stoi[word]
    word_embedding = embeddings[word_idx].unsqueeze(0)  # Shape: 1 x d
    
    # Compute cosine similarity between the target embedding and all embeddings
    similarities = cosine_similarity(word_embedding.detach().numpy(), embeddings.detach().numpy())
    similarities = similarities[0]  # Flatten
    
    # Get top-n similar words (excluding the word itself)
    nearest_indices = similarities.argsort()[-n-1:][::-1][1:]  # Exclude the word itself
    nearest_words = [(list(stoi.keys())[idx], similarities[idx]) for idx in nearest_indices]
    return nearest_words

# Example usage
word = "rocket"
nearest_neighbors = get_nearest_neighbors(word, stoi, W1, n=5)
print(f"Nearest neighbors of '{word}':")
print(nearest_neighbors)


Nearest neighbors of 'rocket':
[('guardsman', 0.6874943), ('spreading', 0.68435574), ('solutions', 0.68200964), ('douglas', 0.66121185), ('cz', 0.61781037)]


In [None]:
words[:20]

["nasa's",
 '3d-printed',
 'rotating',
 'detonation',
 'rocket',
 'engine',
 'test',
 'a',
 'success',
 'heat',
 'pumps',
 'of',
 'the',
 '1800s',
 'are',
 'becoming',
 'technology',
 'future',
 'why',
 'you']