# 2. Skip-gram with negative sampling

I recommend you take a look at these material first.

* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf
* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf

In [1]:
import os
import mindspore
from mindspore import nn, Tensor, ops
import nltk
import random
import numpy as np
from collections import Counter
from mindnlp.modules import Accumulator
flatten = lambda l: [item for sublist in l for item in sublist]
random.seed(1024)

  from tqdm.autonotebook import tqdm


In [2]:
print(mindspore.__version__)
print(nltk.__version__)

2.0.0.20230623
3.7


In [3]:
gpu = '0'
# 设置使用哪些显卡进行训练
os.environ["CUDA_VISIBLE_DEVICES"] = gpu

In [4]:
def getBatch(batch_size, train_data):
    random.shuffle(train_data)
    sindex = 0
    eindex = batch_size
    while eindex < len(train_data):
        batch = train_data[sindex: eindex]
        temp = eindex
        eindex = eindex + batch_size
        sindex = temp
        yield batch

    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

In [5]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w]
                    if word2index.get(w) is not None
                    else word2index["<UNK>"], seq))
    sequence = Tensor(idxs, dtype=mindspore.int64)
    return sequence


def prepare_word(word, word2index):
    return Tensor([word2index[word]], dtype=mindspore.int64) \
        if word2index.get(word) is not None \
        else Tensor([word2index["<UNK>"]], dtype=mindspore.int64)

## Data load and Preprocessing 

In [6]:
corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]
corpus = [[word.lower() for word in sent] for sent in corpus]

### Exclude sparse words 

In [7]:
word_count = Counter(flatten(corpus))

In [8]:
MIN_COUNT = 3
exclude = []

In [9]:
for w, c in word_count.items():
    if c < MIN_COUNT:
        exclude.append(w)

### Prepare train data 

In [10]:
vocab = list(set(flatten(corpus)) - set(exclude))

In [11]:
word2index = {}
for vo in vocab:
    if word2index.get(vo) is None:
        word2index[vo] = len(word2index)

index2word = {v: k for k, v in word2index.items()}

In [12]:
WINDOW_SIZE = 5
windows = flatten([list(nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])

train_data = []

for window in windows:
    for i in range(WINDOW_SIZE * 2 + 1):
        if window[i] in exclude or window[WINDOW_SIZE] in exclude:
            continue  # min_count
        if i == WINDOW_SIZE or window[i] == '<DUMMY>':
            continue
        train_data.append((window[WINDOW_SIZE], window[i]))

X_p = []
y_p = []

for tr in train_data:
    X_p.append(prepare_word(tr[0], word2index).view(1, -1))
    y_p.append(prepare_word(tr[1], word2index).view(1, -1))

train_data = list(zip(X_p, y_p))

In [13]:
len(train_data)

50242

### Build Unigram Distribution**0.75 

$$P(w)=U(w)^{3/4}/Z$$

In [14]:
Z = 0.001

In [15]:
word_count = Counter(flatten(corpus))
num_total_words = sum([c for w, c in word_count.items() if w not in exclude])

In [16]:
unigram_table = []

for vo in vocab:
    unigram_table.extend([vo] * int(((word_count[vo] / num_total_words)**0.75) / Z))

In [17]:
print(len(vocab), len(unigram_table))

478 3500


### Negative Sampling 

In [18]:
def negative_sampling(targets, unigram_table, k):
    batch_size = targets.shape[0]
    neg_samples = []
    for i in range(batch_size):
        nsample = []
        target_index = targets[i].asnumpy().item(0)
        while len(nsample) < k:  # num of sampling
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))

    return ops.cat(neg_samples)

## Modeling 

<img src="../images/02.skipgram-objective.png">
<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf</center>

In [19]:
class SkipgramNegSampling(nn.Cell):

    def __init__(self, vocab_size, projection_dim):
        super(SkipgramNegSampling, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim)  # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim)  # out embedding
        self.logsigmoid = nn.LogSigmoid()

        initrange = (2.0 / (vocab_size + projection_dim))**0.5  # Xavier init
        minval = Tensor(-initrange, mindspore.float32)
        maxval = Tensor(initrange, mindspore.float32)
        self.embedding_v.embedding_table.set_data(ops.uniform(self.embedding_v.embedding_table.shape, minval, maxval))  # init
        self.embedding_u.embedding_table.set_data(ops.zeros(self.embedding_u.embedding_table.shape, mindspore.float32))  # init

    def construct(self, center_words, target_words, negative_words):
        center_embeds = self.embedding_v(center_words)  # B x 1 x D
        target_embeds = self.embedding_u(target_words)  # B x 1 x D

        neg_embeds = -self.embedding_u(negative_words)  # B x K x D

        positive_score = ops.BatchMatMul()(target_embeds, ops.transpose(center_embeds, (0, 2, 1))).squeeze(2)  # Bx1
        negative_score = ops.sum(ops.BatchMatMul()(neg_embeds, ops.transpose(center_embeds, (0, 2, 1))).squeeze(2), 1).view(negs.shape[0], -1)  # BxK -> Bx1

        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)

        return -ops.mean(loss)

    def prediction(self, inputs):
        embeds = self.embedding_v(inputs)

        return embeds

## Train 

In [20]:
EMBEDDING_SIZE = 30
BATCH_SIZE = 256
EPOCH = 101
NEG = 10  # Num of Negative Sampling

In [21]:
losses = []
model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)
optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)

In [22]:
accumulate_step = 2
accumulator = Accumulator(optimizer, accumulate_step)


def forward_fn(inputs, targets, negs):
    loss = model(inputs, targets, negs)
    return loss / accumulate_step


# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())


# Define function of one-step training
@mindspore.jit
def train_step(inputs, targets, negs):
    loss, grads = grad_fn(inputs, targets, negs)
    loss = ops.depend(loss, accumulator(grads))
    return loss

In [23]:
for epoch in range(EPOCH):
    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):

        inputs, targets = zip(*batch)

        inputs = ops.cat(inputs)  # B x 1
        targets = ops.cat(targets)  # B x 1
        negs = negative_sampling(targets, unigram_table, NEG)

        loss = train_step(inputs, targets, negs)

        losses.append(loss.asnumpy().item(0) * accumulate_step)

    if epoch % 10 == 0:
        print("Epoch : %d, mean_loss : %.02f" % (epoch, np.mean(losses)))
        losses = []

Epoch : 0, mean_loss : 1.17
Epoch : 10, mean_loss : 0.87
Epoch : 20, mean_loss : 0.83
Epoch : 30, mean_loss : 0.78
Epoch : 40, mean_loss : 0.74
Epoch : 50, mean_loss : 0.72
Epoch : 60, mean_loss : 0.70
Epoch : 70, mean_loss : 0.68
Epoch : 80, mean_loss : 0.67
Epoch : 90, mean_loss : 0.66
Epoch : 100, mean_loss : 0.65


## Test 

In [24]:
def word_similarity(target, vocab):
    target_V = model.prediction(prepare_word(target, word2index))
    similarities = []
    for i in range(len(vocab)):
        if vocab[i] == target:
            continue

        vector = model.prediction(prepare_word(list(vocab)[i], word2index))

        cosine_sim = ops.cosine_similarity(target_V, vector).asnumpy().tolist()[0]
        similarities.append([vocab[i], cosine_sim])
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]

In [25]:
test = random.choice(list(vocab))
test

'her'

In [26]:
word_similarity(test, vocab)

[['nantucket', 0.7158238291740417],
 ['yet', 0.6215175986289978],
 ['behind', 0.6064300537109375],
 ['man', 0.5916519165039062],
 ['without', 0.5414537191390991],
 ['whale', 0.5358570218086243],
 ['bed', 0.5276554226875305],
 ['craft', 0.5193333625793457],
 ['who', 0.5097337365150452],
 ['being', 0.5067386031150818]]