In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext import data
from torchtext.datasets import PennTreebank

In [2]:
# Number of words taken on each side of the centre word.
context_size = 2

# Text field: spaCy tokenisation, everything lower-cased.
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True)

# Penn Treebank train/valid/test splits (fetched on first use).
train, valid, test = PennTreebank.splits(TEXT)

# The vocabulary is built from the training split only.
TEXT.build_vocab(train)
vocab_size = len(TEXT.vocab)

# Each BPTT window spans the left context, the centre word and the right context.
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=512,
    bptt_len=context_size * 2 + 1,
    repeat=False,
)

downloading ptb.train.txt
downloading ptb.valid.txt
downloading ptb.test.txt


ptb.train.txt:   0%|          | 0.00/1.70M [00:00<?, ?B/s]ptb.train.txt:   6%|▌         | 94.4k/1.70M [00:00<00:03, 474kB/s]ptb.train.txt:   8%|▊         | 143k/1.70M [00:00<00:03, 408kB/s] ptb.train.txt:  11%|█▏        | 193k/1.70M [00:00<00:04, 356kB/s]ptb.train.txt:  14%|█▍        | 242k/1.70M [00:00<00:04, 307kB/s]ptb.train.txt:  20%|██        | 340k/1.70M [00:01<00:05, 255kB/s]ptb.train.txt:  26%|██▌       | 442k/1.70M [00:01<00:04, 298kB/s]ptb.train.txt:  32%|███▏      | 541k/1.70M [00:01<00:03, 313kB/s]ptb.train.txt:  38%|███▊      | 640k/1.70M [00:01<00:02, 389kB/s]ptb.train.txt:  49%|████▉     | 835k/1.70M [00:02<00:01, 487kB/s]ptb.train.txt:  58%|█████▊    | 982k/1.70M [00:02<00:01, 545kB/s]ptb.train.txt:  67%|██████▋   | 1.13M/1.70M [00:02<00:00, 644kB/s]ptb.train.txt:  72%|███████▏  | 1.23M/1.70M [00:02<00:00, 567kB/s]ptb.train.txt:  84%|████████▍ | 1.42M/1.70M [00:02<00:00, 701kB/s]ptb.train.txt:  98%|█████████▊| 1.67M/1.70M [00:02<00:00, 866kB/s]ptb.train.

In [3]:
class CBOW(nn.Module):
    """Continuous bag-of-words model: predict a word from the sum of its context embeddings.

    Args:
        vocab_size: number of distinct token ids (size of the output distribution).
        embedding_dim: dimensionality of the learned word embeddings.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary.

        Args:
            inputs: integer tensor of token ids whose first dimension indexes
                the context words, i.e. ``(context_len,)`` or
                ``(context_len, batch)``.

        Returns:
            Tensor of shape ``(vocab_size,)`` or ``(batch, vocab_size)`` holding
            log-probabilities; each distribution lies along the last dimension.
        """
        # Collapse the context dimension: bag-of-words ignores word order.
        out = torch.sum(self.embeddings(inputs), dim=0)
        out = self.linear(out)
        # Normalise over the vocabulary (last) axis. Using dim=0 would
        # normalise across the batch whenever the input is batched.
        out = F.log_softmax(out, dim=-1)
        return out

In [4]:
class SkipGram(nn.Module):
    """Skip-gram model: predict a surrounding word from a single input word.

    Args:
        vocab_size: number of distinct token ids (size of the output distribution).
        embedding_dim: dimensionality of the learned word embeddings.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary for each input word.

        Args:
            inputs: integer tensor of token ids of any shape (a scalar id or a
                batch of ids).

        Returns:
            Tensor of shape ``inputs.shape + (vocab_size,)`` holding
            log-probabilities; each distribution lies along the last dimension.
        """
        out = self.embeddings(inputs)
        out = self.linear(out)
        # Normalise over the vocabulary (last) axis. Using dim=0 would
        # normalise across the batch whenever the input is batched.
        out = F.log_softmax(out, dim=-1)
        return out

In [0]:
# Train a CBOW model for one epoch: predict the centre word of every window
# from the words around it.
loss_function = nn.NLLLoss()  # expects log-probabilities from the model
model = CBOW(vocab_size, 100)
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# Left context + centre word + right context.
window_len = context_size * 2 + 1

for epoch in range(1):

    for i, batch in enumerate(train_iter):
        text = batch.text

        # NOTE(review): the final BPTT window of a split is presumably shorter
        # than bptt_len; such a window has no well-defined centre word, so it
        # is skipped instead of indexing out of range or training on a
        # lopsided context. Confirm against the iterator's behaviour.
        if text.size(0) < window_len:
            continue

        optimizer.zero_grad()

        # The centre row is the prediction target ...
        target = text[context_size]
        # ... and must NOT be part of the model input, otherwise the label
        # leaks into the features. Keep only the rows on either side of it.
        context = torch.cat((text[:context_size], text[context_size + 1:]), dim=0)

        output = model(context)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()