In [42]:
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn as nn

In [30]:
np.__version__, torch.__version__

('1.26.4', '2.5.0+cu121')

#**Load Data**

In [31]:
corpus = ["apple banana fruit","banana apple fruit","banana fruit apple",
    "cat dog animal","cat animal dog","dog cat animal"]

In [32]:
corpus

['apple banana fruit',
 'banana apple fruit',
 'banana fruit apple',
 'cat dog animal',
 'cat animal dog',
 'dog cat animal']

In [33]:
# 1. tokenization
corpus = [sent.split(" ") for sent in corpus]
corpus

[['apple', 'banana', 'fruit'],
 ['banana', 'apple', 'fruit'],
 ['banana', 'fruit', 'apple'],
 ['cat', 'dog', 'animal'],
 ['cat', 'animal', 'dog'],
 ['dog', 'cat', 'animal']]

In [34]:
# 2. numericalization
# find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
# assign unique integer
vocabs = list(set(flatten(corpus))) # all the words in system - <UNK>


In [35]:
# handy mapping between words and integers
word2index = {v:idx for idx,v in enumerate(vocabs)}
word2index

{'dog': 0, 'fruit': 1, 'apple': 2, 'banana': 3, 'cat': 4, 'animal': 5}

In [36]:
vocabs.append('<UNK>')
word2index['<UNK>'] = 6


In [37]:
index2word = {v:k for k,v in word2index.items()}
index2word

{0: 'dog',
 1: 'fruit',
 2: 'apple',
 3: 'banana',
 4: 'cat',
 5: 'animal',
 6: '<UNK>'}

#**prepare training data**

In [38]:
# create pairs of center words, and outside words
def random_batch(batch_size, corpus):
  skipgram = []
  # loop each corpus
  for doc in corpus:

    # look for second word until the 2nd last word
    for i in range(1,len(doc)-1):
      #center word
      center = word2index[doc[i]]
      #outside words
      outside = (word2index[doc[i-1]], word2index[doc[i+1]])
      #for each of these two outside words, we gonna append to a list
      for each_out in outside:
        #center, outside1; center outside 2
        skipgram.append([center,each_out])
  random_index = np.random.choice(range(len(skipgram)), batch_size, replace=False)
  inputs, labels = [], []
  for index in random_index:
    inputs.append([skipgram[index][0]])
    labels.append([skipgram[index][1]])
  return np.array(inputs), np.array(labels)
x,y = random_batch(2,corpus)

In [39]:
vocabs

['dog', 'fruit', 'apple', 'banana', 'cat', 'animal', '<UNK>']

In [40]:
print(x.shape, y.shape)

(2, 1) (2, 1)


#**Model**

In [44]:
class Skipgram(nn.Module):

  def __init__(self, voc_size, emb_size):
    super(Skipgram, self).__init__()
    self.embedding_center = nn.Embedding(voc_size, emb_size)
    self.embedding_outside = nn.Embedding(voc_size, emb_size)


  def forward(self, center, outside, all_vocabs):
    center_embedding = self.embedding_center(center)      #batch_size, 1 ,emb_size
    outside_embedding = self.embedding_outside(outside)   #batch_size, 1 ,emb_size
    all_vocabs_embedding = self.embedding_center(all_vocabs)   #batch_size, voc_size ,emb_size

    top_term = torch.exp(outside_embedding.bmm(center_embedding.transpose(1,2)).squeeze(2))
    #batch_size,1,emb_size @ batch_size, emb_size, 1 = batch_size,1,1 = batch_size, 1
    lower_term = all_vocabs_embedding.bmm(center_embedding.transpose(1,2)).squeeze(2)
    #batch_size, voc_size ,emb_size @ batch_size, emb_size, 1 = batch_size, voc_size ,1 = batch_size, voc_size
    lower_term_sum = torch.sum(torch.exp(lower_term), 1)  # batch_size, 1

    loss = -torch.mean(torch.log(top_term/ lower_term_sum))

    return loss



In [47]:
vocabs

['dog', 'fruit', 'apple', 'banana', 'cat', 'animal', '<UNK>']

In [48]:
# prepare all vocabs

batch_size =2
voc_size = len(vocabs
               )
def prepare_sequence(seq, word2index):
  idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index["<UNK>"], seq))
  return torch.LongTensor(idxs)

all_vocabs = prepare_sequence(list(vocabs), word2index).expand(batch_size, voc_size)
all_vocabs

tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]])

#**Training**

In [49]:
model = Skipgram(voc_size, 2)

In [54]:
batch_size = 2
emb_size = 2

model = Skipgram(voc_size, emb_size)

optimizer = optim.Adam(model.parameters(), lr = 0.001)


In [55]:
num_epochs = 5000

for epoch in range(num_epochs):

  # get batch
  input_batch, label_batch = random_batch(batch_size, corpus)
  input_tensor = torch.LongTensor(input_batch)
  label_tensor = torch.LongTensor(label_batch)


  # get prediction

  loss = model(input_tensor, label_tensor, all_vocabs)

  # back propogate
  optimizer.zero_grad()
  loss.backward()


  #update alpha
  optimizer.step()

  #print loss

  if(epoch+1)%100==0:
    print(f'epoch: {epoch}, loss: {loss:2.4f}')

epoch: 99, loss: 2.6930
epoch: 199, loss: 1.5886
epoch: 299, loss: 2.1233
epoch: 399, loss: 2.0476
epoch: 499, loss: 1.8233
epoch: 599, loss: 2.7677
epoch: 699, loss: 2.4858
epoch: 799, loss: 1.5661
epoch: 899, loss: 2.4866
epoch: 999, loss: 1.4711
epoch: 1099, loss: 1.4214
epoch: 1199, loss: 2.9898
epoch: 1299, loss: 1.8356
epoch: 1399, loss: 1.7937
epoch: 1499, loss: 1.2112
epoch: 1599, loss: 0.7381
epoch: 1699, loss: 1.2000
epoch: 1799, loss: 1.0255
epoch: 1899, loss: 2.1720
epoch: 1999, loss: 2.0708
epoch: 2099, loss: 1.5333
epoch: 2199, loss: 0.7828
epoch: 2299, loss: 1.2405
epoch: 2399, loss: 1.7811
epoch: 2499, loss: 1.4468
epoch: 2599, loss: 1.2503
epoch: 2699, loss: -0.4050
epoch: 2799, loss: -0.5252
epoch: 2899, loss: 0.7758
epoch: 2999, loss: 1.7923
epoch: 3099, loss: 1.9363
epoch: 3199, loss: -0.5728
epoch: 3299, loss: 1.2086
epoch: 3399, loss: -0.0745
epoch: 3499, loss: 1.3659
epoch: 3599, loss: 0.7054
epoch: 3699, loss: -0.1793
epoch: 3799, loss: -0.8696
epoch: 3899, loss