In [None]:
# Pin torchdata to the version this notebook was last run with (see the install log);
# %pip installs into the running kernel's environment, unlike !pip.
%pip install torchdata==0.4.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdata
  Downloading torchdata-0.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.4 MB)
     |████████████████████████████████| 4.4 MB 3.5 MB/s
Collecting urllib3>=1.25
  Downloading urllib3-1.26.12-py2.py3-none-any.whl (140 kB)
     |████████████████████████████████| 140 kB 61.2 MB/s
Collecting portalocker>=2.0.0
  Downloading portalocker-2.5.1-py2.py3-none-any.whl (15 kB)
Collecting urllib3>=1.25
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
     |████████████████████████████████| 127 kB 56.9 MB/s
Installing collected packages: urllib3, portalocker, torchdata
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
Successfully installed portalocker-2.5.1 torchdata-0.4.1 urllib3-1.25.11


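Reference: Mikolov et al. (2013), *Efficient Estimation of Word Representations in Vector Space* — the paper that introduced the CBOW and Skip-gram models implemented below: https://arxiv.org/abs/1301.3781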

In [None]:
import torch
import torch.nn as nn 
from functools import partial
from torch.utils.data import DataLoader
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import numpy as np

In [None]:
EMBED_DIMENSION = 300  # size of each word vector
EMBED_MAX_NORM = 1     # looked-up embeddings with a larger L2 norm are rescaled to this norm


class CBOW_Model(nn.Module):
    """Continuous Bag-of-Words word2vec model.

    Predicts the centre word of a window from the average of the embeddings
    of its surrounding context words.

    Args:
        vocab_size: number of tokens in the vocabulary; sets both the number
            of embedding rows and the number of output logits.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBED_DIMENSION,
            max_norm=EMBED_MAX_NORM,
        )
        self.linear = nn.Linear(
            in_features=EMBED_DIMENSION,
            out_features=vocab_size,
        )

    def forward(self, inputs_):
        """Return centre-word logits for context windows.

        Args:
            inputs_: LongTensor of context-word indices, shape
                ``(batch, context_size)`` (as produced by ``collate_cbow``) or
                ``(context_size,)`` for a single window.

        Returns:
            Unnormalised scores of shape ``(batch, vocab_size)`` (or
            ``(vocab_size,)`` for a single window).
        """
        x = self.embeddings(inputs_)  # (..., context_size, EMBED_DIMENSION)
        # CBOW collapses the context axis into one averaged vector per window.
        # Without this the logits keep a context axis and CrossEntropyLoss
        # cannot be applied against one centre-word target per window.
        x = x.mean(dim=-2)            # (..., EMBED_DIMENSION)
        x = self.linear(x)            # (..., vocab_size)
        return x


class SkipGram_Model(nn.Module):
    """Skip-gram word2vec model.

    Maps each centre-word index to one score per vocabulary word via an
    embedding lookup followed by a linear projection.

    Args:
        vocab_size: vocabulary size; sizes both the embedding table and the
            output projection.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Embedding rows whose L2 norm exceeds EMBED_MAX_NORM are rescaled on lookup.
        self.embeddings = nn.Embedding(vocab_size, EMBED_DIMENSION, max_norm=EMBED_MAX_NORM)
        self.linear = nn.Linear(EMBED_DIMENSION, vocab_size)

    def forward(self, inputs_):
        """Return logits of shape ``inputs_.shape + (vocab_size,)``."""
        return self.linear(self.embeddings(inputs_))

In [None]:
tokenizer = get_tokenizer("basic_english", language="en")


def text_pipeline(x):
    """Tokenize one raw text line and return the vocabulary index of each token.

    ``vocab`` is looked up when this function is called, not when it is
    defined. It is built in a later cell, so only call this after that cell
    has run.
    """
    return vocab(tokenizer(x))

In [None]:
from torchtext.datasets import WikiText2, WikiText103

DATA_ROOT = "data/"  # directory where torchtext saves the dataset files

# Build the two WikiText2 splits in order: training first, then validation.
train_data_iter, valid_data_iter = (
    WikiText2(root=DATA_ROOT, split=split_name) for split_name in ("train", "valid")
)

In [None]:
from torchtext.vocab import build_vocab_from_iterator

# Training-split tokens seen fewer than this many times are left out of the vocabulary.
VOCAB_MIN_WORD_FREQUENCY = 50

vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_data_iter),
    min_freq=VOCAB_MIN_WORD_FREQUENCY,
    specials=["<unk>"],
)
# Any token missing from the vocabulary resolves to the "<unk>" index on lookup.
vocab.set_default_index(vocab["<unk>"])

In [None]:
CBOW_N_WORDS = 4            # context words taken on each side of the centre word
MAX_SEQUENCE_LENGTH = 256   # token ids kept per line; a falsy value disables truncation


def collate_cbow(batch, text_pipeline, n_words=CBOW_N_WORDS,
                 max_sequence_length=MAX_SEQUENCE_LENGTH):
    """Turn a batch of raw text lines into CBOW (context, centre) training pairs.

    Each line is mapped to token ids with ``text_pipeline`` and optionally
    truncated to its first ``max_sequence_length`` ids. A window of
    ``2 * n_words + 1`` ids then slides over it one position at a time. The
    middle id of every window is the target, and the ``n_words`` ids on either
    side form the context. Lines shorter than one full window contribute
    nothing.

    Args:
        batch: iterable of raw text strings.
        text_pipeline: callable mapping a string to a sequence of int token ids.
        n_words: context words on each side of the centre word (>= 1).
        max_sequence_length: keep only this many leading ids per line; ``None``
            or ``0`` disables truncation.

    Returns:
        ``(batch_input, batch_output)``:
        - ``batch_input``: LongTensor of shape ``(num_windows, 2 * n_words)``
          holding context ids.
        - ``batch_output``: LongTensor of shape ``(num_windows,)`` holding
          centre ids.

        ``num_windows`` is 0 when every line is too short. The input tensor
        then still has rank 2.

    Raises:
        ValueError: if ``n_words`` < 1.
    """
    if n_words < 1:
        raise ValueError(f"n_words must be >= 1, got {n_words}")
    window_size = 2 * n_words + 1

    contexts, centres = [], []
    for text in batch:
        token_ids = list(text_pipeline(text))
        if max_sequence_length:
            token_ids = token_ids[:max_sequence_length]
        # range() is empty when the line is shorter than one window.
        for start in range(len(token_ids) - window_size + 1):
            window = token_ids[start:start + window_size]
            centres.append(window[n_words])
            contexts.append(window[:n_words] + window[n_words + 1:])

    if contexts:
        batch_input = torch.tensor(contexts, dtype=torch.long)
    else:
        # torch.tensor([]) would be 1-D; keep the (0, context) shape explicit.
        batch_input = torch.empty((0, 2 * n_words), dtype=torch.long)
    batch_output = torch.tensor(centres, dtype=torch.long)
    return batch_input, batch_output

In [None]:
from torch.utils.data import DataLoader
from functools import partial

BATCH_SIZE = 96  # raw text lines per batch; each line expands into many CBOW windows

# Both loaders share one collate function that turns raw lines into CBOW pairs.
cbow_collate_fn = partial(collate_cbow, text_pipeline=text_pipeline)

# The WikiText2 splits are iterable-style datapipes (per torchtext docs), so
# the DataLoader cannot sample them by index. On those, shuffle=True does not
# reliably reorder lines. Copying each split into a map-style dataset makes
# shuffling well-defined.
dataloader = DataLoader(
    to_map_style_dataset(train_data_iter),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=cbow_collate_fn,
)

# Validation keeps a fixed order. The training loop only scores its leading
# batches, so a fixed order makes val losses comparable across epochs.
val_dataloader = DataLoader(
    to_map_style_dataset(valid_data_iter),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=cbow_collate_fn,
)

In [None]:
# NOTE(review): with this rate the logged train loss only falls from ~8.13 to
# ~8.06 over two epochs; it may be too small for embeddings trained from scratch.
LEARNING_RATE = 3e-5

vocab_size = len(vocab)  # number of tokens, same as len(vocab.get_stoi())
model = CBOW_Model(vocab_size)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Use a GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
NUM_EPOCHS = 2
MAX_BATCHES_PER_EPOCH = 100  # caps both train and validation batches to keep each epoch short


def run_epoch(model, loader, loss_fn, optimizer=None, device="cpu",
              max_batches=MAX_BATCHES_PER_EPOCH):
    """Run one pass over at most ``max_batches`` batches and return the mean loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one it is put in eval mode and gradients are
    disabled.

    Batches that contain no CBOW windows (every line too short) are skipped,
    because averaging the loss over zero targets would give NaN. They still
    count toward ``max_batches``.

    Returns:
        Mean batch loss as a float, or NaN if no batch produced a loss.
    """
    training = optimizer is not None
    model.train(training)
    batch_losses = []
    with torch.set_grad_enabled(training):
        for batch_idx, (inputs, labels) in enumerate(loader, 1):
            if labels.numel() > 0:
                inputs, labels = inputs.to(device), labels.to(device)
                if training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                if training:
                    loss.backward()
                    optimizer.step()
                batch_losses.append(loss.item())
            # Stop after exactly max_batches batches (the old `i > 100` ran 101).
            if max_batches is not None and batch_idx >= max_batches:
                break
    return float(np.mean(batch_losses)) if batch_losses else float("nan")


# Module.to moves parameters in place, so the optimizer created earlier still
# references them.
model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    train_loss = run_epoch(model, dataloader, loss_fn, optimizer, device=DEVICE)
    print(f"train_loss = {train_loss}")

    val_loss = run_epoch(model, val_dataloader, loss_fn, device=DEVICE)
    print(f"val_loss = {val_loss}")

train_loss = 8.125571704146886
val_loss = 8.089580249786376
train_loss = 8.064355793565806
val_loss = 8.02445571422577
