In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import re
torch.manual_seed(1)

<torch._C.Generator at 0x7f66d37380b0>

In [16]:
def tokenize(text):
    pattern = re.compile(r'[A-Za-z]+[\w^\']*|[\w^\']*[A-Za-z]+[\w^\']*')
    return pattern.findall(text.lower())

In [17]:
text = open("alice.txt", 'r').read()

In [18]:
words = tokenize(text)

In [21]:
len(words)

26683

In [22]:
vocab = set(words)
word_to_idx = {word:i for i, word in enumerate(vocab)}

In [23]:
len(vocab)

2854

In [43]:
class SkipGramWithSoftmax(nn.Module):
    def __init__(self, vocab_size, n_embedding):
        super(SkipGramWithSoftmax, self).__init__()
        self.embedding_center = nn.Embedding(vocab_size, n_embedding)
        self.embedding_context = nn.Embedding(vocab_size, n_embedding)
        nn.init.uniform_(self.embedding_center.weight, -1, 1)
        nn.init.uniform_(self.embedding_context.weight, -1, 1)
    def forward(self, center_word):
        emb_center = self.embedding_center(center_word)
        out = torch.matmul(emb_center, self.embedding_context.weight.T) # logits for every word in the corpus
        return out

In [35]:
def generate_data(words, vocab, context_size):
    x, y = [], []
    for i, word in enumerate(words):
        l = max(0, i - context_size)
        r = min(len(words) - 1, i + context_size)
        for j in range(l, r + 1):
            if i != j:
                x.append(word_to_idx[word])
                y.append(word_to_idx[words[j]])
    return x, y

In [46]:
X, Y = generate_data(words, vocab, context_size=2)

In [48]:
len(X), len(Y)

(106726, 106726)

In [61]:
vocab_size = len(vocab)
n_embedding = 50
model = SkipGramWithSoftmax(vocab_size, n_embedding)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
batch_size = 128

In [69]:
class SkipGramDataset(Dataset):
    def __init__(self, center_words, context_words):
        self.center_words = center_words
        self.context_words = context_words
    def __len__(self):
        return len(self.center_words)
    def __getitem__(self, index):
        return self.center_words[index], self.context_words[index]

In [71]:
dataset = SkipGramDataset(X, Y)

In [78]:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [81]:
for _ in range(2):
    model.train()
    total_loss = 0
    
    for x_batch, y_batch in dataloader:
        optimizer.zero_grad()
        
        out = model(x_batch)
        loss = criterion(out, y_batch)
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    print(total_loss)

8565.318466186523
8473.746932983398
