In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import SGD
torch.manual_seed(1)

<torch._C.Generator at 0x1b2ec709fb0>

In [53]:
CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# vocab set and vocab size
vocab = set(raw_text)
vocab_size = len(vocab)

# construct dictionary to lookup 
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {ix: word for word, ix in word_to_ix.items()}
# construct training data: (context, target) pair
raw_data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    raw_data.append((context, target))
print(raw_data[:5])

[(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study'), (['to', 'study', 'idea', 'of'], 'the'), (['study', 'the', 'of', 'a'], 'idea')]


In [54]:
context, target = raw_data[0]
context
[word_to_ix[word] for word in context]

[39, 26, 34, 47]

In [55]:
class corpus_dataset(Dataset):
    def __init__(self, raw_dataset, transform=None):
        # raw_dataset is a list of (context, target) pair
        self.dataset = raw_dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        context, target = self.dataset[idx]
        return {"context":torch.tensor([word_to_ix[word] for word in context]), "target":torch.tensor(word_to_ix[target])}

In [56]:
dataset.__getitem__(0)

{'context': tensor([39, 26, 34, 47]), 'target': tensor(19)}

In [57]:
dataset = corpus_dataset(raw_data)
dataloader = DataLoader(dataset,batch_size=3)

In [85]:
class CBOW(nn.Module):
    def __init__(self):
        super(CBOW, self).__init__()
        self.embedding = nn.Embedding(vocab_size, 3)
        self.linear = nn.Linear(3, vocab_size, bias=False)
    def forward(self, x):
        # extract embedding of context and sum up termwise
        x = self.embedding(x).sum(1)
        # output will be of shape (1,v)
        x = self.linear(x)
        return x

In [86]:
model = CBOW()

In [87]:
criterion = nn.CrossEntropyLoss()

In [88]:
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

In [89]:
# intialize parameters
for parameter in model.parameters():
    nn.init.normal_(parameter)

# train
for epoch in range(500):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        context = data["context"]
        target = data["target"]
        optimizer.zero_grad()
        outputs = model(context)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 10 == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

[1,     1] loss: 0.003
[1,    11] loss: 0.047
[2,     1] loss: 0.003
[2,    11] loss: 0.030
[3,     1] loss: 0.002
[3,    11] loss: 0.023
[4,     1] loss: 0.002
[4,    11] loss: 0.019
[5,     1] loss: 0.002
[5,    11] loss: 0.017
[6,     1] loss: 0.002
[6,    11] loss: 0.015
[7,     1] loss: 0.001
[7,    11] loss: 0.015
[8,     1] loss: 0.001
[8,    11] loss: 0.014
[9,     1] loss: 0.001
[9,    11] loss: 0.013
[10,     1] loss: 0.001
[10,    11] loss: 0.013
[11,     1] loss: 0.001
[11,    11] loss: 0.012
[12,     1] loss: 0.001
[12,    11] loss: 0.012
[13,     1] loss: 0.001
[13,    11] loss: 0.012
[14,     1] loss: 0.001
[14,    11] loss: 0.011
[15,     1] loss: 0.001
[15,    11] loss: 0.011
[16,     1] loss: 0.001
[16,    11] loss: 0.011
[17,     1] loss: 0.001
[17,    11] loss: 0.010
[18,     1] loss: 0.001
[18,    11] loss: 0.010
[19,     1] loss: 0.001
[19,    11] loss: 0.010
[20,     1] loss: 0.001
[20,    11] loss: 0.010
[21,     1] loss: 0.001
[21,    11] loss: 0.010
[22,     1

[176,    11] loss: 0.002
[177,     1] loss: 0.000
[177,    11] loss: 0.002
[178,     1] loss: 0.000
[178,    11] loss: 0.002
[179,     1] loss: 0.000
[179,    11] loss: 0.002
[180,     1] loss: 0.000
[180,    11] loss: 0.002
[181,     1] loss: 0.000
[181,    11] loss: 0.002
[182,     1] loss: 0.000
[182,    11] loss: 0.002
[183,     1] loss: 0.000
[183,    11] loss: 0.002
[184,     1] loss: 0.000
[184,    11] loss: 0.002
[185,     1] loss: 0.000
[185,    11] loss: 0.002
[186,     1] loss: 0.000
[186,    11] loss: 0.002
[187,     1] loss: 0.000
[187,    11] loss: 0.002
[188,     1] loss: 0.000
[188,    11] loss: 0.002
[189,     1] loss: 0.000
[189,    11] loss: 0.002
[190,     1] loss: 0.000
[190,    11] loss: 0.002
[191,     1] loss: 0.000
[191,    11] loss: 0.002
[192,     1] loss: 0.000
[192,    11] loss: 0.002
[193,     1] loss: 0.000
[193,    11] loss: 0.002
[194,     1] loss: 0.000
[194,    11] loss: 0.002
[195,     1] loss: 0.000
[195,    11] loss: 0.002
[196,     1] loss: 0.000


[345,    11] loss: 0.000
[346,     1] loss: 0.000
[346,    11] loss: 0.000
[347,     1] loss: 0.000
[347,    11] loss: 0.000
[348,     1] loss: 0.000
[348,    11] loss: 0.000
[349,     1] loss: 0.000
[349,    11] loss: 0.000
[350,     1] loss: 0.000
[350,    11] loss: 0.000
[351,     1] loss: 0.000
[351,    11] loss: 0.000
[352,     1] loss: 0.000
[352,    11] loss: 0.000
[353,     1] loss: 0.000
[353,    11] loss: 0.000
[354,     1] loss: 0.000
[354,    11] loss: 0.000
[355,     1] loss: 0.000
[355,    11] loss: 0.000
[356,     1] loss: 0.000
[356,    11] loss: 0.000
[357,     1] loss: 0.000
[357,    11] loss: 0.000
[358,     1] loss: 0.000
[358,    11] loss: 0.000
[359,     1] loss: 0.000
[359,    11] loss: 0.000
[360,     1] loss: 0.000
[360,    11] loss: 0.000
[361,     1] loss: 0.000
[361,    11] loss: 0.000
[362,     1] loss: 0.000
[362,    11] loss: 0.000
[363,     1] loss: 0.000
[363,    11] loss: 0.000
[364,     1] loss: 0.000
[364,    11] loss: 0.000
[365,     1] loss: 0.000


In [111]:
word_embedding = None
for submodule in model.children():
    print(submodule)
    if type(submodule)== nn.Embedding:
        print(submodule.parameters())
        word_embedding = submodule.weight

Embedding(49, 3)
<generator object Module.parameters at 0x000001B2ED465150>
Linear(in_features=3, out_features=49, bias=False)


In [113]:
word_embedding.data

tensor([[ 1.3143, -0.2626, -4.7010],
        [ 3.4720,  1.7375,  4.1654],
        [ 1.7791,  3.3704, -1.2406],
        [ 2.8334,  1.7984, -3.0672],
        [ 0.6506,  0.6817, -1.4361],
        [ 0.2242, -4.8109, -3.8607],
        [ 1.9254,  2.0261,  0.9470],
        [-2.8827,  5.1868,  0.1970],
        [-1.0005,  1.3667, -0.3649],
        [-0.8144, -3.0856, -1.8149],
        [ 0.0070, -0.1565,  4.2368],
        [ 0.9139, -1.2486,  0.0888],
        [ 1.0708, -1.6619,  0.8572],
        [-4.3968,  1.2924, -3.3288],
        [ 3.9714,  1.0081,  3.9927],
        [-1.4090, -3.5790,  4.2528],
        [ 1.8087,  1.0572, -2.0848],
        [ 1.2545, -3.1374, -0.8219],
        [ 1.0460, -1.8361,  3.0460],
        [-0.2487, -4.0448,  0.5642],
        [-1.2793, -2.5642, -3.8320],
        [-1.6259, -3.6516,  0.1740],
        [-0.0438,  2.3326, -0.3893],
        [-0.1790, -3.1275,  3.0443],
        [-3.4830, -2.6021, -1.2537],
        [-4.7553, -2.1045,  0.2164],
        [ 2.9546, -1.9427,  0.2839],
 

In [105]:
print(model.embedding.parameters())

<generator object Module.parameters at 0x000001B2ED4528E0>
