In [1]:
from abc import ABC
from typing import List, Dict, Tuple, Set
import random
import torch.nn.functional as F

import torch
from dataset import SkipgramDataset


#############################################
# Helper functions below. DO NOT MODIFY!    #
#############################################


class Word2Vec(torch.nn.Module, ABC):
    """
    A helper class that wraps your word2vec losses.
    """
    def __init__(self, n_tokens: int, word_dimension: int):
        super().__init__()

        self.center_vectors = torch.nn.Parameter(torch.empty([n_tokens, word_dimension]))
        self.outside_vectors = torch.nn.Parameter(torch.empty([n_tokens, word_dimension]))

        self.init_weights()

    def init_weights(self):
        torch.nn.init.normal_(self.center_vectors.data)
        torch.nn.init.normal_(self.outside_vectors.data)

        
class NegSamplingWord2Vec(Word2Vec):
    def __init__(self, n_tokens: int, word_dimension: int, negative_sampler, K: int=10):
        super().__init__(n_tokens, word_dimension)

        self._negative_sampler = negative_sampler
        self._K = K

    def forward(self, center_word_index: torch.Tensor, outside_word_indices: torch.Tensor):
        return neg_sampling_loss(self.center_vectors, self.outside_vectors, center_word_index, outside_word_indices, self._negative_sampler, self._K)

#############################################
# Testing functions below.                  #
#############################################


def test_neg_sampling_loss():
    print ("======Negative Sampling Loss Test Case======")
    center_word_index = torch.randint(1, 100, [5])
    outside_word_indices = []
    for _ in range(5):
        random_window_size = random.randint(3, 6)
        outside_word_indices.append([random.randint(1, 99) for _ in range(random_window_size)] + [0] * (6 - random_window_size))
    outside_word_indices = torch.Tensor(outside_word_indices).to(torch.long)

    neg_sampling_prob = torch.ones([100])
    neg_sampling_prob[0] = 0.

    dummy_database = type('dummy', (), {'_neg_sample_prob': neg_sampling_prob})

    sampled_negatives = list()

    def negative_sampler_wrapper(outside_word_indices, K):
        result = SkipgramDataset.negative_sampler(dummy_database, outside_word_indices, K)
        sampled_negatives.clear()
        sampled_negatives.append(result)
        return result

    model = NegSamplingWord2Vec(n_tokens=100, word_dimension=3, negative_sampler=negative_sampler_wrapper, K=5)

    loss = model(center_word_index, outside_word_indices).mean()
    loss.backward()

    # first test
    assert (model.center_vectors.grad[0, :] == 0).all() and (model.outside_vectors.grad[0, :] == 0).all(), \
        "<PAD> token should not affect the result."
    print("The first test passed! Howerver, this test dosen't guarantee you that <PAD> tokens really don't affects result.")    

    # Second test
    temp = model.center_vectors.grad.clone().detach()
    temp[center_word_index] = 0.
    assert (temp == 0.).all() and (model.center_vectors.grad[center_word_index] != 0.).all(), \
        "Only batched center words can affect the centerword embedding."
    print("The second test passed!")

    # Third test
    sampled_negatives = sampled_negatives[0]
    sampled_negatives[outside_word_indices.unsqueeze(-1).expand(-1, -1, 5) == 0] = 0
    affected_indices = list((set(sampled_negatives.flatten().tolist()) | set(outside_word_indices.flatten().tolist())) - {0})
    temp = model.outside_vectors.grad.clone().detach()
    temp[affected_indices] = 0.
    assert (temp == 0.).all() and (model.outside_vectors.grad[affected_indices] != 0.).all(), \
        "Only batched outside words and sampled negatives can affect the outside word embedding."
    print("The third test passed!")

    
    # forth test
    print(loss)
    assert loss.detach().allclose(torch.tensor(35.82903290)) or loss.detach().allclose(torch.tensor(24.76907349)), \
        "Loss of negative sampling do not match expected result."
    print("The forth test passed!")


    print("All 4 tests passed!")

![대체 텍스트](./figures/embedding.png)

In [2]:
def neg_sampling_loss(
    center_vectors: torch.Tensor, outside_vectors: torch.Tensor,
    center_word_index: torch.Tensor, outside_word_indices: torch.Tensor,
    negative_sampler, K: int=10
) -> torch.Tensor:
    """ Negative sampling loss function for word2vec models

    Implement the negative sampling loss for each pair of (center_word_index, outside_word_indices) in a batch.
    As same with naive_softmax_loss, all inputs are batched with batch_size.

    Note: Implementing negative sampler is a quite tricky job so we pre-implemented this part. See below comments to check how to use it.
    If you want to know how the sampler works, check SkipgramDataset.negative_sampler code in dataset.py file

    Arguments/Return Specifications: same as naiveSoftmaxLoss

    Additional arguments:
    negative_sampler -- the negative sampler
    K -- the number of negative samples to take
    """
    assert center_word_index.shape[0] == outside_word_indices.shape[0]

    n_tokens, word_dim = center_vectors.shape
    batch_size, outside_word_size = outside_word_indices.shape
    PAD = SkipgramDataset.PAD_TOKEN_IDX

    ##### Sampling negtive indices #####
    # Because each outside word needs K negatives samples,
    # negative_sampler takes a tensor in shape [batch_size, outside_word_size] and gives a tensor in shape [batch_size, outside_word_size, K]
    # where values in last dimension are the indices of sampled negatives for each outside_word.
    negative_samples: torch.Tensor = negative_sampler(outside_word_indices, K)
    assert negative_samples.shape == torch.Size([batch_size, outside_word_size, K])

    ###  YOUR CODE HERE
    
    # positive sample loss
    batch_center_vectors = center_vectors[center_word_index]
    batch_dot_product = torch.einsum('bj,kj->bk', [None, None])

    batch_true_loss = torch.log(torch.sigmoid(torch.gather(None, 1, None)))
    
    # negative sample loss
    batch_neg_dots = batch_dot_product.gather(1, negative_samples.reshape(None, None * None))
    batch_neg_dots = batch_neg_dots.view(None, None, None)    
    batch_neg_loss = torch.sum(torch.log(torch.sigmoid(None)), dim=-1)
    
    loss_matrix = -(None + None)
    losses = torch.sum(loss_matrix * (None != 0).int().float(), dim=-1)
    
    ### END YOUR CODE

    assert losses.shape == torch.Size([batch_size])
    return losses

In [3]:
torch.set_printoptions(precision=8)
torch.manual_seed(4321)
random.seed(4321)

test_neg_sampling_loss()



TypeError: expected Tensor as element 0 in argument 1, but got NoneType