In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

<torch._C.Generator at 0x1ed70461190>

In [4]:
#TODO: need to construct word_to_ix
word_to_ix = {"hello": 0, "world": 1}
embeds = nn.Embedding(2, 5)  # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor([word_to_ix["hello"]], dtype=torch.long)
print(lookup_tensor)
hello_embed = embeds(lookup_tensor)
print(hello_embed)

tensor([0])
tensor([[-0.8923, -0.0583, -0.1955, -0.9656,  0.4224]],
       grad_fn=<EmbeddingBackward>)


In [9]:
CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# vocab set and vocab size
vocab = set(raw_text)
vocab_size = len(vocab)

# construct dictionary to lookup 
word_to_ix = {word: i for i, word in enumerate(vocab)}
# construct training data: (context, target) pair
raw_data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    raw_data.append((context, target))
print(raw_data[:5])
print(len(raw_data))

[(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study'), (['to', 'study', 'idea', 'of'], 'the'), (['study', 'the', 'of', 'a'], 'idea')]
58


In [13]:
dataset = [(torch.tensor([word_to_ix[word] for word in context]), word_to_ix[target]) for context,target in raw_data]

In [14]:
dataset[0]

(tensor([21, 17, 26, 24]), 10)

In [55]:
embedding = nn.EmbeddingBag(vocab_size, 3, mode='sum')
embedding(torch.tensor([21,17]), offsets=torch.tensor([0]))[0,:]

tensor([ 2.8269,  1.7077, -1.4443], grad_fn=<SliceBackward>)

In [52]:
embedding = nn.Embedding(vocab_size, 3)
embedding(torch.tensor([21,17])).sum(0)

tensor([ 0.4923,  2.4214, -2.0156], grad_fn=<SumBackward2>)

In [59]:
class CBOW(nn.Module):
    def __init__(self):
        super(CBOW, self).__init__()
        self.embeddingbag = nn.EmbeddingBag(vocab_size, 3, mode="sum")
        self.linear = nn.Linear(3, vocab_size)
    def forward(self, x):
        # extract embedding of context and sum up termwise
        x = self.embeddingbag(x, offsets=torch.tensor([0]))[0,:]
        # output will be of shape (1,v)
        x = self.linear(x)
        return x

In [60]:
model = CBOW()
model.forward(dataset[0][0])

tensor([-0.6412,  0.7104,  0.4156,  0.6317,  0.3026,  0.5045,  0.8667, -0.4203,
        -0.2961,  1.3353, -0.9742, -0.0328,  0.0995,  1.7415,  0.3087, -0.3693,
         0.8926, -0.1616, -0.8411, -0.3344,  0.3540, -0.3632, -0.8303,  0.3895,
         0.4715, -0.5619,  0.0681, -0.1810, -0.3815,  0.6601,  0.7997, -1.0243,
        -0.5720,  0.4455, -0.1312,  0.0021, -0.6804, -0.4587,  0.7080,  0.0481,
        -0.0990, -1.0758,  0.5609, -0.8402, -1.3805,  0.0209, -0.0101, -1.0264,
         0.7997], grad_fn=<AddBackward0>)

In [61]:
criterion = nn.CrossEntropyLoss()

In [62]:
dataset[0][1]

10

In [63]:
model.forward(dataset[0][0])

tensor([-0.6412,  0.7104,  0.4156,  0.6317,  0.3026,  0.5045,  0.8667, -0.4203,
        -0.2961,  1.3353, -0.9742, -0.0328,  0.0995,  1.7415,  0.3087, -0.3693,
         0.8926, -0.1616, -0.8411, -0.3344,  0.3540, -0.3632, -0.8303,  0.3895,
         0.4715, -0.5619,  0.0681, -0.1810, -0.3815,  0.6601,  0.7997, -1.0243,
        -0.5720,  0.4455, -0.1312,  0.0021, -0.6804, -0.4587,  0.7080,  0.0481,
        -0.0990, -1.0758,  0.5609, -0.8402, -1.3805,  0.0209, -0.0101, -1.0264,
         0.7997], grad_fn=<AddBackward0>)

In [65]:
loss = criterion(model.forward(dataset[0][0]).unsqueeze(0), torch.tensor(dataset[0][1]).unsqueeze(0))

In [66]:
loss

tensor(5.0851, grad_fn=<NllLossBackward>)