<a href="https://colab.research.google.com/github/bbandbass/Projects/blob/main/skip_gram_implementation1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Number of context words taken on each side of the center word
# (2 to the left and 2 to the right).
window_size = 2

sentence = """
Regrets, I've had a few.
But then again, too few to mention.
I did what I had to do.
And saw it through without exemption.
I planned each charted course.
Each careful step along the byway.
And more, much more than this, I did it my way.
"""

words = sentence.split()
# A training pair needs window_size words on both sides of the center word.
assert len(words) > 2 * window_size, "corpus too short for the chosen window_size"

vocab = set(words)
vocab_size = len(vocab)

# Sort before numbering: the iteration order of a set of strings depends on the
# per-process hash seed, so enumerating the raw set would assign different
# indices to the same words after every kernel restart.
word_to_idx = {word: idx for idx, word in enumerate(sorted(vocab))}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

data = []

# Build (context words, center word) pairs; only positions with a full window
# on both sides are used.
for i in range(window_size, len(words) - window_size):
  context = words[i - window_size : i] + words[i + 1 : i + window_size + 1]
  center = words[i]
  data.append((context, center))

In [None]:
# Show a few (context words, center word) pairs rather than the whole list.
data[:5]

[(['Regrets,', "I've", 'a', 'few.'], 'had'),
 (["I've", 'had', 'few.', 'But'], 'a'),
 (['had', 'a', 'But', 'then'], 'few.'),
 (['a', 'few.', 'then', 'again,'], 'But'),
 (['few.', 'But', 'again,', 'too'], 'then'),
 (['But', 'then', 'too', 'few'], 'again,'),
 (['then', 'again,', 'few', 'to'], 'too'),
 (['again,', 'too', 'to', 'mention.'], 'few'),
 (['too', 'few', 'mention.', 'I'], 'to'),
 (['few', 'to', 'I', 'did'], 'mention.'),
 (['to', 'mention.', 'did', 'what'], 'I'),
 (['mention.', 'I', 'what', 'I'], 'did'),
 (['I', 'did', 'I', 'had'], 'what'),
 (['did', 'what', 'had', 'to'], 'I'),
 (['what', 'I', 'to', 'do.'], 'had'),
 (['I', 'had', 'do.', 'And'], 'to'),
 (['had', 'to', 'And', 'saw'], 'do.'),
 (['to', 'do.', 'saw', 'it'], 'And'),
 (['do.', 'And', 'it', 'through'], 'saw'),
 (['And', 'saw', 'through', 'without'], 'it'),
 (['saw', 'it', 'without', 'exemption.'], 'through'),
 (['it', 'through', 'exemption.', 'I'], 'without'),
 (['through', 'without', 'I', 'planned'], 'exemption.'),
 (['

In [None]:
def make_context_vector(context, word_to_idx):
    """Encode a list of context words as a 1-D ``torch.long`` tensor of indices.

    Args:
        context: sequence of words, in the order they should appear.
        word_to_idx: mapping from word to integer index.

    Raises:
        KeyError: if a word is not present in ``word_to_idx``.
    """
    indices = list(map(word_to_idx.__getitem__, context))
    return torch.tensor(indices, dtype = torch.long)

In [None]:
def make_center_vector(center, word_to_idx):
  """Encode one center word as a 0-d ``torch.long`` tensor holding its index.

  Raises KeyError if ``center`` is not in ``word_to_idx``.
  """
  index = word_to_idx[center]
  return torch.tensor(index, dtype = torch.long)

# Skip-Gram

In [None]:
class SkipGram(nn.Module):
  """Skip-gram model: predicts every context word from a single center word.

  The center-word index is embedded, then one linear layer produces a score for
  each (context position, vocabulary word) pair.

  Args:
    vocab_size: number of distinct words (rows of the embedding table).
    projection_size: embedding dimension.
    window_size: context words on each side of the center word, so the model
      predicts ``2 * window_size`` context positions.
  """

  def __init__(self, vocab_size, projection_size, window_size):
    super().__init__()
    # Keep the sizes on the instance so forward() does not silently depend on
    # notebook globals that may differ from the constructor arguments.
    self.vocab_size = vocab_size
    self.window_size = window_size
    self.context_size = 2 * window_size
    self.projection = nn.Embedding(vocab_size, projection_size)
    self.linear = nn.Linear(projection_size, self.context_size * vocab_size)
    # Normalise over the vocabulary (last) axis so each context position gets
    # its own distribution over words -- the layout nn.NLLLoss expects for an
    # input of shape (num_positions, vocab_size). dim=0 would instead normalise
    # across context positions.
    self.activation = nn.LogSoftmax(dim = -1)

  def forward(self, input):
    """Return log-probabilities of shape ``(2 * window_size, vocab_size)``.

    Args:
      input: a 0-d ``torch.long`` tensor holding the center-word index.
    """
    projection = self.projection(input)
    output = self.linear(projection).view(self.context_size, self.vocab_size)
    y_hat = self.activation(output)

    return y_hat

In [None]:
# Seed PyTorch so the random weight initialisation is reproducible across runs.
torch.manual_seed(0)

# 500-dimensional embeddings. Pass window_size instead of a hard-coded 2 so the
# number of predicted positions always matches the context length in `data`.
skipgram = SkipGram(vocab_size, 500, window_size)
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(skipgram.parameters(), lr = 0.01)

In [None]:
# The index tensors never change, so encode every (context, center) pair once
# instead of rebuilding them on each of the 5000 epochs.
encoded_pairs = [
    (make_context_vector(ctx_words, word_to_idx), make_center_vector(center_word, word_to_idx))
    for ctx_words, center_word in data
]

# Full-batch training: the loss is summed over all pairs, then one optimizer
# step is taken per epoch.
for epoch in range(5000):

  loss = 0

  for context_vector, center_vector in encoded_pairs:
    # y_hat holds log-probabilities, one row per context position.
    y_hat = skipgram(center_vector)
    loss += criterion(y_hat, context_vector)

  if (epoch + 1) % 10 == 0:
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

In [None]:
# Inspect the last training pair explicitly rather than relying on loop
# variables that leaked out of earlier cells.
context, center = data[-1]
print(context)
print(center)

['I', 'did', 'my', 'way.']
it


In [None]:
# Predict the context words for the same center word. no_grad() avoids building
# an autograd graph for this inference-only call.
with torch.no_grad():
  test = skipgram(make_center_vector(center, word_to_idx))

In [None]:
# Most likely word at each context position (argmax along the vocabulary axis).
predicted_ids = test.argmax(dim = 1).tolist()
print([idx_to_word[word_id] for word_id in predicted_ids])

['I', 'saw', 'my', 'way.']
