In [None]:
"""A tiny word2vec model for learning."""

In [1]:
import collections 
import math 
import os 
import random 
import torch
from nlp import Vocab

In [10]:
# Read the whole PTB training split. The encoding is pinned so the decoded
# text does not depend on the platform's default locale encoding.
with open("../data/ptb/ptb.train.txt", encoding="utf-8") as f:
    raw_text = f.read()
# One whitespace-split token list per line, then a single flat token list.
words_in_sentences = [line.split() for line in raw_text.split('\n')]
words = [token for line in words_in_sentences for token in line]

In [11]:
# Build the vocabulary from the flat token list. min_freq=10 is forwarded to
# the project's Vocab class — presumably a minimum-count cutoff; confirm in nlp.Vocab.
vocab = Vocab(words, min_freq=10)
# Last expression of the cell, so the vocabulary size is shown as the output.
len(vocab)

6719

## Subsampling

In [12]:
# Subsampling threshold; subsample() passes it as `t` to discard_probability().
T = 1e-4

In [21]:
def discard_probability(t: float, freq: int, num_tokens: int) -> float:
    """
    Calculates the probability for this word to be discarded.

    The value is ``max(1 - sqrt(t / f), 0)`` where ``f = freq / num_tokens``
    is the relative frequency of the word in the corpus.

    Parameters
    ----------
    t : float
        Hyperparameter to adjust for subsampling.
    freq : int
        Frequency of the word in the corpus.
    num_tokens : int
        Total number of tokens in the corpus.

    Returns
    -------
    float
        The probability for discarding this word. Never negative; ``0.0``
        for a word that does not occur at all (``freq == 0``).

    Raises
    ------
    ZeroDivisionError
        If ``num_tokens`` is 0 while ``freq`` is non-zero.
    """
    if freq == 0:
        # Relative frequency 0: the formula tends to "never discard" here, so
        # return that limit instead of dividing by zero.
        return 0.0
    relative_freq = freq / num_tokens
    # Clip with a float literal so the result is always a float, as annotated.
    return max(1.0 - math.sqrt(t / relative_freq), 0.0)

def keep(prob: float) -> bool:
    """
    Randomly decides whether a word survives, given its keep probability.

    Parameters
    ----------
    prob : float
        Probability for keeping

    Returns
    -------
    bool
        True when a uniform draw from [0, 1] is strictly below ``prob``,
        False otherwise.
    """
    # A single draw per call; the comparison is strict.
    return random.uniform(0, 1) < prob

def subsample(words: list[str], unk: str) -> tuple[list[str], collections.Counter]:
    """
    Subsamples the words in the corpus according to their frequencies.

    Tokens equal to ``unk`` are removed first. Every remaining token is then
    kept with probability ``1 - discard_probability(T, freq, num_tokens)``,
    where ``freq`` is the token's count and ``num_tokens`` the total count
    after removing ``unk``.

    Parameters
    ----------
    words : list[str]
        All the words in the corpus
    unk : str
        The <unk> token in this case

    Returns
    -------
    tuple[list[str], collections.Counter]
        The subsampled words (original order preserved) and the counter of
        the words before subsampling, with ``unk`` excluded.
    """
    words_not_unk = [word for word in words if word != unk]
    counter = collections.Counter(words_not_unk)
    # The corpus size is loop-invariant: compute it once, not once per token.
    num_tokens = sum(counter.values())
    # keep() expects the probability of KEEPING a word, so the discard
    # probability has to be inverted before it is passed on.
    keep_probability = {
        word: 1 - discard_probability(T, freq, num_tokens)
        for word, freq in counter.items()
    }
    subsampled = [word for word in words_not_unk if keep(keep_probability[word])]
    return subsampled, counter

In [22]:
# subsample() compares every word against `unk`, so it needs the unknown-token
# string; passing the Vocab object would (presumably) never match any word.
# NOTE(review): "<unk>" is the marker used in the PTB text — confirm it is the
# same token the Vocab class uses.
subsampled, counter = subsample(words, "<unk>")