In [1]:
pip install --upgrade torch



In [None]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import re

In [4]:
with open("/content/shakespeare.txt", "r" ,encoding="utf8") as f:
  text = f.read()
text = text.lower()

In [5]:
tokens = re.findall(r"\w+|[^\w\s]", text)

In [6]:
vocab = sorted(set(tokens))

In [7]:
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)

In [8]:
indices = [word2idx[w] for w in tokens]

In [48]:
class CBOW_Dataset(Dataset):
    def __init__(self, data, context_size):
        self.data = data
        self.context_size = context_size

    def __len__(self):
        return len(self.data) - 2 * self.context_size

    def __getitem__(self, idx):
        # shift index so you always have full context on both sides
        center = idx + self.context_size
        context = self.data[idx:center] + self.data[center+1:center+1+self.context_size]
        target = self.data[center]
        return torch.tensor(context, dtype=torch.long), torch.tensor(target, dtype=torch.long)

In [49]:
context_size = 2
dataset_CBOW = CBOW_Dataset(indices, context_size)
dataloader = DataLoader(dataset_CBOW, batch_size=64, shuffle=True)


In [50]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class CBOW_Model(nn.Module):
  def __init__(self, vocab_size, embedding_dim):
    super(CBOW_Model, self).__init__()
    self.embeddings  = nn.Embedding(vocab_size, embedding_dim)
    self.linear =  nn.Linear(embedding_dim, vocab_size)

  def forward(self, x):
    embedd = self.embeddings(x)
    embedd = embedd.mean(axis=1)
    output = self.linear(embedd)
    return output

In [51]:
cbow_model = CBOW_Model(vocab_size, 100)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cbow_model.parameters(), lr=0.001)

In [52]:
for epoch in range(5):
    total_loss = 0
    for context, target in dataloader:
        optimizer.zero_grad()
        output = cbow_model(context)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 24978.7266
Epoch 2, Loss: 21353.4814
Epoch 3, Loss: 20097.2214
Epoch 4, Loss: 19228.9405
Epoch 5, Loss: 18549.2545


In [74]:
class SkipGram_Dataset(Dataset):
    def __init__(self, data, context_size):
        self.pairs = []
        for i in range(context_size, len(data) - context_size):
            center = data[i]
            context = data[i - context_size:i] + data[i + 1:i + context_size + 1]
            for target in context:
                self.pairs.append((center, target))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        center, target = self.pairs[idx]
        return torch.tensor(center, dtype=torch.long), torch.tensor(target, dtype=torch.long)

In [75]:
context_size = 2
dataset_SkipGram = SkipGram_Dataset(indices, context_size)
dataloader_SkipGram = DataLoader(dataset_SkipGram, batch_size=64, shuffle=True)

In [81]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class SkipGram_Model(nn.Module):
  def __init__(self, vocab_size, embedding_dim):
    super(SkipGram_Model, self).__init__()
    self.embeddings  = nn.Embedding(vocab_size, embedding_dim)
    self.linear =  nn.Linear(embedding_dim, vocab_size)

  def forward(self, x):
    # print(x.shape)
    embedd = self.embeddings(x)
    # print(embedd.shape)
    output = self.linear(embedd)
    # print(output.shape)
    return output

In [82]:
skipgram_model = SkipGram_Model(vocab_size, 100)
criterion_ = nn.CrossEntropyLoss()
optimizer_ = torch.optim.Adam(skipgram_model.parameters(), lr=0.001)

In [None]:
for epoch in range(5):
    total_loss = 0
    for context, target in dataloader_SkipGram:
        optimizer_.zero_grad()
        output = skipgram_model(context)
        loss = criterion_(output, target)
        loss.backward()
        optimizer_.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 96526.6824
Epoch 2, Loss: 95613.5060
