In [1]:
from collections import defaultdict
import numpy as np

import torch
from torch import nn
from torch.optim import Adam

In [2]:
device = "mps" if torch.backends.mps.is_available() else "cpu"

# CBOW Model

In [3]:
class CBoW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_classes):
        super(CBoW, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        # One can use `EmbeddingBag` directly instead of `Embedding`
        # and summing over the embeddings of the words
        # Unfortunatly, `EmbeddingBag` is still not available for mps
        self.embedding = nn.Embedding(
                                vocab_size,
                                embedding_dim,
                                device=device,
                            )
        self.linear = nn.Linear(embedding_dim, n_classes, device=device)
        
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.linear.weight)
    
    def forward(self, words):
        out = self.embedding(words) # size: len(words) * embedding_dim
        out = out.sum(dim=0, keepdims=True) # size : embedding_dim
        out = self.linear(out) # size: n_classes
        return out

# Applying the model

In [4]:
w2i = defaultdict(lambda: len(w2i))
t2i = defaultdict(lambda: len(t2i))
UNK = w2i["<unk>"]

In [5]:
def read_dataset(path: str):
    with open(path, "r") as f:
        for line in f:
            try:
                line = f.readline().lower().strip().split(" ||| ")
                text_class, text = line[0], line[1]
                yield ([w2i[word] for word in text.split(" ")], t2i[text_class])
            except:
                pass

In [6]:
train = list(read_dataset("../data/classes/train.txt"))
vocab_size = len(w2i)
n_classes = len(t2i)

In [7]:
vocab_size

11402

In [8]:
n_classes

5

In [9]:
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/classes/dev.txt"))

In [10]:
cbow_model = CBoW(vocab_size, 128, n_classes)

In [11]:
loss_criterion = nn.CrossEntropyLoss()
optimizer = Adam(cbow_model.parameters())

In [12]:
# Just 10 epochs as the goal is not to train a real model
# but just to see if the implementation is working
for i in range(10):
    train_loss = 0
    test_accuracy = 0
    for words, sentence_class in train:
        words = torch.tensor(words, device=device)
        sentence_class = torch.tensor([sentence_class], device=device)
        predictions = cbow_model(words)
        loss = loss_criterion(predictions, sentence_class)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Iteration {i} - Train loss: {train_loss/len(train)}")
    with torch.no_grad():
        for words, sentence_class in dev:
            words = torch.tensor(words, device=device)
            predictions = cbow_model(words)
            predicted_class = np.argmax(predictions.detach().cpu().numpy())
            if predicted_class == sentence_class:
                test_accuracy += 1
    print(f"Iteration {i} - Test accuracy: {test_accuracy/len(dev)}")

Iteration 0 - Train loss: 1.2553884235493253
Iteration 0 - Test accuracy: 0.3527272727272727
Iteration 1 - Train loss: 0.5974311805731348
Iteration 1 - Test accuracy: 0.33636363636363636
Iteration 2 - Train loss: 0.24939595125197025
Iteration 2 - Test accuracy: 0.3327272727272727
Iteration 3 - Train loss: 0.11656908424438608
Iteration 3 - Test accuracy: 0.30727272727272725
Iteration 4 - Train loss: 0.052619257893008684
Iteration 4 - Test accuracy: 0.3090909090909091
Iteration 5 - Train loss: 0.021776127792922744
Iteration 5 - Test accuracy: 0.2927272727272727
Iteration 6 - Train loss: 0.016617859124244374
Iteration 6 - Test accuracy: 0.28545454545454546
Iteration 7 - Train loss: 0.007515847627143289
Iteration 7 - Test accuracy: 0.30363636363636365
Iteration 8 - Train loss: 0.004991540859701035
Iteration 8 - Test accuracy: 0.3054545454545455
Iteration 9 - Train loss: 0.0023530923137057586
Iteration 9 - Test accuracy: 0.3109090909090909
