In [59]:
import numpy as np
import sys
sys.path.append('..')

class UnigramSampler:
    """Draw negative samples from a smoothed unigram distribution.

    Word id ``i`` is drawn with probability proportional to
    ``count(i) ** power``, where ``count(i)`` is the number of times ``i``
    occurs in ``corpus``.  Ids in ``range(vocab_size)`` that never occur in
    the corpus get probability 0 and are never drawn.

    Args:
        corpus: 1-D sequence of non-negative integer word ids.
        power: Exponent applied to the raw counts before normalisation
            (values between 0 and 1 shift probability mass toward rarer ids).
        sample_size: Number of negative ids drawn per target.

    Attributes:
        sample_size: As passed in.
        vocab_size: ``max(corpus) + 1``, so every id that appears in the
            corpus is a valid index into ``word_p``.
        word_p: Float array of shape ``(vocab_size,)`` summing to 1.

    Raises:
        ValueError: If ``corpus`` is empty (``np.bincount`` also raises
            ValueError for negative ids).
    """

    def __init__(self, corpus, power, sample_size):
        self.sample_size = sample_size

        corpus = np.asarray(corpus)
        if corpus.size == 0:
            raise ValueError("corpus must contain at least one word id")

        # bincount sizes the table by the largest id rather than by the number
        # of distinct ids, so gaps in the id space (e.g. an id 15 with no id 7)
        # can no longer push real words past the end of the table.
        counts = np.bincount(corpus)
        self.vocab_size = len(counts)

        # Raise only the observed counts to `power`.  Unseen ids stay exactly 0
        # even when power <= 0, where 0 ** power would be 1 or inf.  Casting to
        # float first keeps integer or negative exponents well-defined.
        word_p = np.zeros(self.vocab_size)
        seen = counts > 0
        word_p[seen] = counts[seen].astype(np.float64) ** power
        self.word_p = word_p / word_p.sum()

    def get_negative_sample(self, target):
        """Draw ``sample_size`` negative word ids for every target.

        Args:
            target: Array-like of length ``batch_size``.  ``target[i]`` is
                the positive word id (or an array of ids) that must not be
                drawn in row ``i``.

        Returns:
            ``np.int32`` array of shape ``(batch_size, sample_size)``.  Ids
            within a row are distinct (sampled without replacement) and never
            equal to ``target[i]``.

        Raises:
            ValueError: From ``np.random.choice`` when fewer than
                ``sample_size`` ids keep non-zero probability after the
                target is excluded.
            IndexError: If a target id is ``>= vocab_size``.
        """
        target = np.asarray(target)
        batch_size = target.shape[0]
        negative_sample = np.zeros((batch_size, self.sample_size), dtype=np.int32)

        for i in range(batch_size):
            # Start from the corpus distribution for every row, remove this
            # row's positive word(s), and renormalise over what is left.
            p = self.word_p.copy()
            p[target[i]] = 0
            p /= p.sum()
            negative_sample[i, :] = np.random.choice(
                self.vocab_size, size=self.sample_size, replace=False, p=p
            )

        return negative_sample


In [67]:
# Toy corpus of word ids; the largest id is 15 and ids 7, 12, 13, 14 never occur.
corpus = np.array([
    0, 1, 2, 3, 4, 1, 2, 3, 5, 6, 4,
    2, 9, 10, 11, 15, 3, 5, 8, 1, 9, 0,
])
power = 0.75      # exponent applied to the unigram counts
sample_size = 5   # negatives drawn per target

sampler = UnigramSampler(corpus, power=power, sample_size=sample_size)
target = np.array([1, 3, 0])

In [75]:
negative_sample = sampler.get_negative_sample(target)
# One row per target; a blank line separates the full matrix from column 0.
print(negative_sample, end="\n\n")
# Column i holds the i-th negative drawn for each target in the batch.
print(negative_sample[:, 0])

[[11  3  5  2  4]
 [ 2  9  0  8  6]
 [ 2  5 11  1  4]]

[11  2  2]
