In [1]:
from multiprocessing import cpu_count

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.optim as optim

from src.model import CBOW
from src.collator import CBOWCollator
from src.utils import preprocess, create_vocab
from src.constants import (
    EMBEDDING_DIMS,
    VOCAB_SIZE,
    MIN_WORD_FREQ,
    CONTEXT_LENGTH,
    SUBSAMPLE_THRESH,
)
from src.dataset import GenericPairDataset

In [2]:
class CBOW(nn.Module):
    def __init__(self, vocab_size, dims):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dims)
        self.linear = nn.Linear(in_features=dims, out_features=vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).sum(dim=1)
        out = self.linear(embeds)
        return out
    
    def debug_forward(self, inputs):
        embeds = self.embeddings(inputs)
        print(embeds.shape)
        print(embeds)
        agg = embeds.sum(dim=1)
        print(agg.shape)
        print(agg)
        out = self.linear(embeds)
        return out


In [3]:
data = load_dataset(
    "deokhk/en_wiki_sentences_100000", split="dev", cache_dir="./data"
)
data = data.map(preprocess, remove_columns="sentence", num_proc=cpu_count() - 1)

In [4]:
vocabulary = create_vocab(
    sentences=data["tokens"], max_size=VOCAB_SIZE, min_freq=MIN_WORD_FREQ
)

In [5]:
model = CBOW(vocab_size=len(vocabulary), dims=5)
collator = CBOWCollator(context_length=CONTEXT_LENGTH, vocab=vocabulary)

In [6]:
contexts, targets = collator.collate(data["tokens"])
dataset = GenericPairDataset(contexts, targets)

In [7]:
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

In [8]:
for batch_contexts, batch_targets in dataloader:
    print(batch_contexts.shape)
    print(batch_contexts)
    pred = model.debug_forward(batch_contexts)
    break

torch.Size([2, 8])
tensor([[  4,   1,   0,   3, 252,  66,   0, 440],
        [  0,   0,  10, 655,   6,   0,   0,   0]])
torch.Size([2, 8, 5])
tensor([[[ 4.2255e-02,  1.8558e+00,  1.0853e-01,  1.2185e+00, -6.7888e-01],
         [ 1.2017e-01, -1.2318e+00, -9.4021e-01, -2.3804e-01,  8.5557e-01],
         [ 5.6088e-01,  4.6863e-02,  7.0363e-01, -4.1738e-01, -1.0881e-01],
         [ 1.1251e+00, -2.3331e+00,  1.1188e-01,  1.5523e+00,  8.0833e-01],
         [-6.2213e-01, -8.3106e-01,  1.9010e+00,  1.0990e+00, -2.0834e+00],
         [-6.1311e-02, -2.0460e-01,  1.0276e+00,  6.2708e-01, -5.1276e-01],
         [ 5.6088e-01,  4.6863e-02,  7.0363e-01, -4.1738e-01, -1.0881e-01],
         [-6.0386e-01,  5.9110e-01,  2.2420e-01, -2.0179e-01,  1.9633e+00]],

        [[ 5.6088e-01,  4.6863e-02,  7.0363e-01, -4.1738e-01, -1.0881e-01],
         [ 5.6088e-01,  4.6863e-02,  7.0363e-01, -4.1738e-01, -1.0881e-01],
         [ 1.0571e+00,  9.1276e-01, -9.8245e-02,  2.1552e+00,  7.7171e-01],
         [-4.9074e-0

In [9]:
model.embeddings(batch_contexts).sum(axis=1)

tensor([[ 1.1220, -2.0600,  3.8402,  3.2222,  0.1345],
        [ 4.2498,  1.9363,  2.6482,  1.2726, -0.9026]], grad_fn=<SumBackward1>)

In [10]:
sum([-0.1966, 0.2234, -0.8973, -0.0871, 0.0592, -0.1966, -0.3767, 0.8763])

-0.5954