# **word2vec на PyTorch**

Идея, лежащая в основе [word2vec](https://arxiv.org/pdf/1310.4546), достаточно общая. Здесь мы попробуем реализовать его самостоятельно.

Реализация от `gensim` (или аналоги), разумеется, обучается быстрее и работает точнее. Она использует множество доработок и ускорений, а также достаточно эффективный код. А мы хотим лишь добиться промежуточных результатов за разумное время.

__Requirements:__ if you're running locally, in the selected environment run the following command:

```pip install --upgrade nltk bokeh umap-learn```


In [1]:
!pip install --upgrade nltk bokeh umap-learn -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.8/88.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
panel 1.4.5 requires bokeh<3.5.0,>=3.4.0, but you have bokeh 3.6.0 which is incompatible.[0m[31m
[0m

In [2]:
import itertools
import random
import string
from collections import Counter
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import umap
from IPython.display import clear_output
from matplotlib import pyplot as plt
from nltk.tokenize import WordPunctTokenizer
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from tqdm.auto import tqdm as tqdma
from tqdm import tqdm

In [6]:
# download the data:
# !wget https://www.dropbox.com/s/obaitrix9jyu84r/quora.txt?dl=1 -O ./quora.txt -nc
# alternative download link: https://yadi.sk/i/BPQrUu1NaTduEw
# !wget https://yadi.sk/i/BPQrUu1NaTduEw -O ./quora.txt -nc

In [3]:
data = list(open("./quora.txt", encoding="utf-8"))
data[50]

"What TV shows or books help you read people's body language?\n"

Токенизация – первый шаг.
Тексты, с которыми мы работаем, включают в себя пунктуацию, смайлики и прочие нестандартные токены, так что простой `str.split` не подойдет.

Обратимся к `nltk` – популярная библиотека, нашедшая широкое применение в области NLP.

In [4]:
tokenizer = WordPunctTokenizer()

print(tokenizer.tokenize(data[50]))

['What', 'TV', 'shows', 'or', 'books', 'help', 'you', 'read', 'people', "'", 's', 'body', 'language', '?']


In [5]:
data_tok = [
    tokenizer.tokenize(
        line.translate(str.maketrans("", "", string.punctuation)).lower()
    )
    for line in tqdm(data)
]
data_tok = [x for x in data_tok if len(x) >= 3]

100%|██████████| 537272/537272 [00:05<00:00, 90161.98it/s] 


Несколько проверок:

In [6]:
assert all(
    isinstance(row, (list, tuple)) for row in data_tok
), "please convert each line into a list of tokens (strings)"
assert all(
    all(isinstance(tok, str) for tok in row) for row in data_tok
), "please convert each line into a list of tokens (strings)"
is_latin = lambda tok: all("a" <= x.lower() <= "z" for x in tok)
assert all(
    map(lambda l: not is_latin(l) or l.islower(), map(" ".join, data_tok))
), "please make sure to lowercase the data"

Ниже заданы константы ширины окна контекста и проведена предобработка для построения skip-gram модели.

In [7]:
min_count = 5
window_radius = 5

In [8]:
vocabulary_with_counter = Counter(chain.from_iterable(data_tok))

word_count_dict = dict()
for word, counter in vocabulary_with_counter.items():
    if counter >= min_count:
        word_count_dict[word] = counter

vocabulary = set(word_count_dict.keys())
del vocabulary_with_counter

In [9]:
word_to_index = {word: index for index, word in enumerate(vocabulary)}
index_to_word = {index: word for word, index in word_to_index.items()}

Пары `(слово, контекст)` на основе доступного датасета сгенерированы ниже.

In [10]:
context_pairs = []

for text in tqdma(data_tok):
    for i, central_word in enumerate(text):
        context_indices = range(
            max(0, i - window_radius), min(i + window_radius, len(text))
        )
        for j in context_indices:
            if j == i:
                continue
            context_word = text[j]
            if central_word in vocabulary and context_word in vocabulary:
                context_pairs.append(
                    (word_to_index[central_word], word_to_index[context_word])
                )

print(f"Generated {len(context_pairs)} pairs of target and context words.")

  0%|          | 0/537174 [00:00<?, ?it/s]

Generated 40220313 pairs of target and context words.


## subsampling

Для того, чтобы сгладить разницу в частоте встречаемсости слов, необходимо реализовать механизм subsampling'а.<br>
Для этого необходимо реализовать функцию ниже.

Вероятность **исключить** слово из обучения (на фиксированном шаге) вычисляется как
$$
P_\text{drop}(w_i)=1 - \sqrt{\frac{t}{f(w_i)}},
$$
где $f(w_i)$ – нормированная частота встречаемости слова, а $t$ – заданный порог (threshold).

In [11]:
def subsample_frequent_words(word_count_dict, threshold=1e-5):
    """
    Calculates the subsampling probabilities for
    words based on their frequencies.

    This function is used to determine the
    probability of keeping a word in the dataset
    when subsampling frequent words. The method
    used is inspired by the subsampling approach
    in Word2Vec, where each word's frequency
    affects its probability of being kept.

    Parameters:
    - word_count_dict (dict): A dictionary where
      keys are words and values are the counts of those words.
    - threshold (float, optional): A threshold
      parameter used to adjust the frequency of word subsampling.
      Defaults to 1e-5.

    Returns:
    - dict: A dictionary where keys are words
      and values are the probabilities of keeping each word.

    Example:
    >>> word_counts = {'the': 5000, 'is': 1000, 'apple': 50}
    >>> subsample_frequent_words(word_counts)
    {'the': 0.028, 'is': 0.223, 'apple': 1.0}
    """
    total_words = sum(word_count_dict.values())
    probs = {}

    for word, count in word_count_dict.items():
        freq = count / total_words
        # Calculate the probability of keeping the word
        prob = (threshold / freq) ** 0.5 + (threshold / freq)
        # Ensure probability is not negative
        probs[word] = prob

    return probs

In [12]:
test_wcd = {'the': 5000, 'is': 1000, 'apple': 50}
subsample_frequent_words(test_wcd, threshold=1e-5)

{'the': 0.0034906054261852177,
 'is': 0.007838674593052023,
 'apple': 0.03599505426185218}

## negative sampling

Для более эффективного обучения необходимо не только предсказывать высокие вероятности для слов из контекста, но и предсказывать низкие для слов, не встреченных в контексте.<br>Для этого необходимо вычислить вероятност использовать слово в качестве negative sample, реализовав функцию ниже.

В оригинальной статье предлагается оценивать вероятность слов выступать в качестве negative sample согласно распределению $P_n(w)$
$$
P_n(w) = \frac{U(w)^{3/4}}{Z},
$$

где $U(w)$ распределение слов по частоте (или, как его еще называют, по униграммам), а $Z$ – нормировочная константа, чтобы общая мера была равна $1$.

In [13]:
def get_negative_sampling_prob(word_count_dict):
    """
    Calculates the negative sampling probabilities
    for words based on their frequencies.

    This function adjusts the frequency of each word
    raised to the power of 0.75, which is
    commonly used in algorithms like Word2Vec to
    moderate the influence of very frequent words.
    It then normalizes these adjusted frequencies
    to ensure they sum to 1, forming a probability
    distribution used for negative sampling.

    Parameters:
    - word_count_dict (dict): A dictionary where
      keys are words and values are the counts of those words.

    Returns:
    - dict: A dictionary where keys are words and
            values are the probabilities of selecting each word
            for negative sampling.

    Example:
    >>> word_counts = {'the': 5000, 'is': 1000, 'apple': 50}
    >>> get_negative_sampling_prob(word_counts)
    {'the': 0.298, 'is': 0.160, 'apple': 0.042}
    """

    # Step 1: Adjust frequencies by raising to the power of 0.75
    adjusted_frequencies = {
        word: count ** 0.75 for word, count in word_count_dict.items()
        }

    # Step 2: Calculate the total of adjusted frequencies
    total_adjusted = sum(adjusted_frequencies.values())

    # Step 3: Normalize to get probabilities
    probabilities = {
        word: freq / total_adjusted for word, freq in adjusted_frequencies.items()
        }

    return probabilities

In [14]:
get_negative_sampling_prob(test_wcd)

{'the': 0.751488398196177,
 'is': 0.2247474520689081,
 'apple': 0.023764149734914898}

Для удобства, преобразуем полученные словари в массивы (т.к. все слова все равно уже пронумерованы).

In [15]:
keep_prob_dict = subsample_frequent_words(word_count_dict)
assert keep_prob_dict.keys() == word_count_dict.keys()

In [16]:
negative_sampling_prob_dict = get_negative_sampling_prob(word_count_dict)
assert negative_sampling_prob_dict.keys() == negative_sampling_prob_dict.keys()
assert np.allclose(sum(negative_sampling_prob_dict.values()), 1)

In [17]:
keep_prob_array = np.array(
    [keep_prob_dict[index_to_word[idx]] for idx in range(len(word_to_index))]
)
negative_sampling_prob_array = np.array(
    [
        negative_sampling_prob_dict[index_to_word[idx]]
        for idx in range(len(word_to_index))
    ]
)

Если все прошло успешно, функция ниже поможет с генерацией батчей.

In [None]:
def generate_batch_with_neg_samples(
    context_pairs,
    batch_size,
    keep_prob_array,
    word_to_index,
    num_negatives,
    negative_sampling_prob_array,
):
    centers_negs = context_pairs[np.random.randint(len(context_pairs), size=batch_size), :]
    keep_mask = np.random.random() < keep_prob_array[centers_negs[:, 0]]
    valid = centers_negs[keep_mask]

    while len(valid) < batch_size:
        add_centers_negs = context_pairs[np.random.randint(len(context_pairs), size=batch_size), :]
        keep_mask = np.random.random() < keep_prob_array[add_centers_negs[:, 0]]
        valid = np.concatenate((valid, add_centers_negs[keep_mask]))

    batch = valid[:batch_size]

    neg_samples = np.random.choice(
        range(len(negative_sampling_prob_array)),
        size=(batch_size, num_negatives),
        p=negative_sampling_prob_array
    )
    return batch, neg_samples

In [None]:
context_pairs = np.array(context_pairs)
batch_size = 1024
num_negatives = 15
batch, neg_samples = generate_batch_with_neg_samples(
    context_pairs,
    batch_size,
    keep_prob_array,
    word_to_index,
    num_negatives,
    negative_sampling_prob_array,
)

Наконец, время реализовать модель. Использование линейных слоев (`nn.Linear`) далеко не всегда оправданно, поэтому мы обойдемся эмбеддингами.

В случае negative sampling решается задача максимизации следующего функционала:

$$
\mathcal{L} = \log \sigma({\mathbf{v}'_{w_O}}^\top \mathbf{v}_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n(w)} \left[ \log \sigma({-\mathbf{v}'_{w_i}}^\top \mathbf{v}_{w_I}) \right],
$$

где:
- $\mathbf{v}_{w_I}$ – вектор центрального слова $w_I$,
- $\mathbf{v}'_{w_O}$ – вектор слова из контекста $w_O$,
- $k$ – число negative samplesЮ,
- $P_n(w)$ – распределение negative samples, заданное выше,
- $\sigma$ – сигмоида.

In [None]:
class SkipGramModelWithNegSampling(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModelWithNegSampling, self).__init__()
        self.center_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, center_words, pos_context_words, neg_context_words):
        # Получаем эмбеддинги для центральных и положительных контекстных слов
        center_embeds = self.center_embeddings(center_words)
        pos_context_embeds = self.context_embeddings(pos_context_words)
        # Вычисляем положительные оценки (скалярное произведение)
        pos_scores = torch.mul(center_embeds, pos_context_embeds)
        pos_scores = torch.sum(pos_scores, dim=1)
        # Получаем эмбеддинги для отрицательных контекстных слов
        neg_context_embeds = self.context_embeddings(neg_context_words)
        # Вычисляем отрицательные оценки (скалярное произведение)
        neg_scores = torch.bmm(
            neg_context_embeds, center_embeds.unsqueeze(2)
            ).squeeze()

        return pos_scores, neg_scores

In [95]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [96]:
vocab_size = len(word_to_index)
embedding_dim = 32
num_negatives = 15

model = SkipGramModelWithNegSampling(vocab_size, embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.05)
lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=150)
criterion = nn.BCEWithLogitsLoss()

In [97]:
params_counter = 0
for weights in model.parameters():
    params_counter += weights.shape.numel()
assert params_counter == len(word_to_index) * embedding_dim * 2

In [None]:
def train_skipgram_with_neg_sampling(
    model,
    context_pairs,
    keep_prob_array,
    word_to_index,
    batch_size,
    num_negatives,
    negative_sampling_prob_array,
    steps,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
):
    pos_labels = torch.ones(batch_size).to(device)
    neg_labels = torch.zeros(batch_size, num_negatives).to(device)
    loss_history = []
    for step in tqdma(range(steps)):
        batch, neg_samples = generate_batch_with_neg_samples(
            context_pairs,
            batch_size,
            keep_prob_array,
            word_to_index,
            num_negatives,
            negative_sampling_prob_array,
        )
        center_words = torch.tensor([pair[0] for pair in batch], dtype=torch.long).to(
            device
        )
        pos_context_words = torch.tensor(
            [pair[1] for pair in batch], dtype=torch.long
        ).to(device)
        neg_context_words = torch.tensor(neg_samples, dtype=torch.long).to(device)

        optimizer.zero_grad()
        pos_scores, neg_scores = model(
            center_words, pos_context_words, neg_context_words
        )

        loss_pos = criterion(pos_scores, pos_labels)
        loss_neg = criterion(neg_scores, neg_labels)

        loss = loss_pos + loss_neg
        loss.backward()
        optimizer.step()

        loss_history.append(loss.item())
        lr_scheduler.step(loss_history[-1])

        if step % 100 == 0:
            print(
                f"Step {step}, Loss: {np.mean(loss_history[-100:])}, learning rate: {lr_scheduler._last_lr}"
            )

In [98]:
steps = 2500
batch_size = 4096
train_skipgram_with_neg_sampling(
    model,
    context_pairs,
    keep_prob_array,
    word_to_index,
    batch_size,
    num_negatives,
    negative_sampling_prob_array,
    steps, optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
)

  0%|          | 0/2500 [00:00<?, ?it/s]

Step 0, Loss: 4.692890167236328, learning rate: [0.05]
Step 100, Loss: 2.6268701016902924, learning rate: [0.05]
Step 200, Loss: 1.6126753771305085, learning rate: [0.05]
Step 300, Loss: 1.4328165531158448, learning rate: [0.05]
Step 400, Loss: 1.3844779455661773, learning rate: [0.05]
Step 500, Loss: 1.3582147204875945, learning rate: [0.05]
Step 600, Loss: 1.3493667829036713, learning rate: [0.05]
Step 700, Loss: 1.334305877685547, learning rate: [0.05]
Step 800, Loss: 1.3263539361953736, learning rate: [0.05]
Step 900, Loss: 1.326192138195038, learning rate: [0.05]
Step 1000, Loss: 1.3200436007976533, learning rate: [0.05]
Step 1100, Loss: 1.2835122752189636, learning rate: [0.025]
Step 1200, Loss: 1.22152615070343, learning rate: [0.025]
Step 1300, Loss: 1.1950163865089416, learning rate: [0.025]
Step 1400, Loss: 1.1762350702285767, learning rate: [0.025]
Step 1500, Loss: 1.161331182718277, learning rate: [0.025]
Step 1600, Loss: 1.1554627478122712, learning rate: [0.025]
Step 1700

Наконец, воспользуемся полученной матрицей весов в качестве матрицы с векторными представлениями слов.

In [99]:
_model_parameters = model.parameters()
embedding_matrix_center = next(
    _model_parameters
).detach()  # Assuming that first matrix was for central word
embedding_matrix_context = next(
    _model_parameters
).detach()  # Assuming that second matrix was for context word

In [100]:
def get_word_vector(word, embedding_matrix, word_to_index=word_to_index):
    return embedding_matrix[word_to_index[word]]

Простые проверки:

In [101]:
similarity_1 = F.cosine_similarity(
    get_word_vector("iphone", embedding_matrix_context)[None, :],
    get_word_vector("apple", embedding_matrix_context)[None, :],
)
similarity_2 = F.cosine_similarity(
    get_word_vector("iphone", embedding_matrix_context)[None, :],
    get_word_vector("dell", embedding_matrix_context)[None, :],
)
print(f'sim1: {similarity_1}; sim2: {similarity_2}')

try:
    assert similarity_1 > similarity_2
    print('Correct!')
    print('similarity_1 greater than similarity_2')
except AssertionError:
    print('ALARM!')
    print('Assertion failed: similarity_1 is not greater than similarity_2')

sim1: tensor([0.7109], device='cuda:0'); sim2: tensor([0.6677], device='cuda:0')
Correct!
similarity_1 greater than similarity_2


In [102]:
similarity_1 = F.cosine_similarity(
    get_word_vector("windows", embedding_matrix_context)[None, :],
    get_word_vector("laptop", embedding_matrix_context)[None, :],
)
similarity_2 = F.cosine_similarity(
    get_word_vector("windows", embedding_matrix_context)[None, :],
    get_word_vector("macbook", embedding_matrix_context)[None, :],
)
print(f'sim1: {similarity_1}; sim2: {similarity_2}')

try:
    assert similarity_1 > similarity_2
    print('Correct!')
    print('similarity_1 greater than similarity_2')
except AssertionError:
    print('ALARM!')
    print('Assertion failed: similarity_1 is not greater than similarity_2')

sim1: tensor([0.7801], device='cuda:0'); sim2: tensor([0.6731], device='cuda:0')
Correct!
similarity_1 greater than similarity_2


Наконец, взглянем на ближайшие по косинусной мере слова.

In [103]:
def find_nearest(word, embedding_matrix, word_to_index=word_to_index, k=10):
    word_vector = get_word_vector(word, embedding_matrix)[None, :]
    dists = F.cosine_similarity(embedding_matrix, word_vector)
    index_sorted = torch.argsort(dists)
    top_k = index_sorted[-k:]
    return [(index_to_word[x], dists[x].item()) for x in top_k.cpu().numpy()]

In [104]:
find_nearest("python", embedding_matrix_context, k=20)

[('unigraphics', 0.8048782348632812),
 ('expressjs', 0.8067203760147095),
 ('hybris', 0.8082059621810913),
 ('sql', 0.8087780475616455),
 ('ide', 0.8109314441680908),
 ('applet', 0.8153489232063293),
 ('nodejs', 0.8153706789016724),
 ('angularjs', 0.8171805739402771),
 ('scripting', 0.8192536234855652),
 ('unix', 0.8192691802978516),
 ('c', 0.8199998140335083),
 ('gui', 0.8220341801643372),
 ('js', 0.8269432783126831),
 ('learning', 0.833995521068573),
 ('plsql', 0.834445595741272),
 ('programming', 0.8370746374130249),
 ('framework', 0.8375260829925537),
 ('css3', 0.8380035161972046),
 ('java', 0.8711803555488586),
 ('python', 1.0)]

Также можем визуально проверить, как представлены в латентном пространстве часто встречающиеся слова.

In [105]:
top_k = 5000
_top_words = sorted([x for x in word_count_dict.items()], key=lambda x: x[1])[
    -top_k - 100 : -100
]  # ignoring 100 most frequent words
top_words = [x[0] for x in _top_words]
del _top_words

In [106]:
word_embeddings = torch.cat(
    [embedding_matrix_context[word_to_index[x]][None, :] for x in top_words], dim=0
).cpu().numpy()

In [107]:
import bokeh.models as bm
import bokeh.plotting as pl
from bokeh.io import output_notebook

output_notebook()


def draw_vectors(
    x,
    y,
    radius=10,
    alpha=0.25,
    color="blue",
    width=600,
    height=400,
    show=True,
    **kwargs,
):
    """draws an interactive plot for data points with auxilirary info on hover"""
    if isinstance(color, str):
        color = [color] * len(x)
    data_source = bm.ColumnDataSource({"x": x, "y": y, "color": color, **kwargs})

    fig = pl.figure(active_scroll="wheel_zoom", width=width, height=height)
    fig.scatter("x", "y", size=radius, color="color", alpha=alpha, source=data_source)

    fig.add_tools(bm.HoverTool(tooltips=[(key, "@" + key) for key in kwargs.keys()]))
    if show:
        pl.show(fig)
    return fig

In [108]:
embedding = umap.UMAP(n_neighbors=5).fit_transform(word_embeddings)

In [109]:
draw_vectors(embedding[:, 0], embedding[:, 1], token=top_words)