In [277]:
import torch
from torch import nn

# Choose the compute backend in priority order: CUDA GPU, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using Device: {device}")

Using Device: mps


In [278]:
import random

In [279]:
# Read the whole corpus into memory as a single string.
file_path = "data/shakespear.txt"
# Decode explicitly as UTF-8 so the text does not depend on the machine's
# locale-default encoding.
with open(file_path, encoding="utf-8") as f:
    data = f.read()

In [280]:
import nltk

If you are running this notebook for the first time, download the NLTK data used by the tokenizers below by running:
```
nltk.download()
```

In [281]:
from nltk.tokenize import sent_tokenize, word_tokenize

In [282]:
# Keep at most the first 10,000,000 characters of the corpus
# (presumably to bound tokenisation and training time -- confirm).
data = data[:10_000_000]

In [283]:
# Split the corpus into sentences, then turn each sentence into a list of word tokens.
words_in_sents = list(map(word_tokenize, sent_tokenize(data)))

In [284]:
# Build the vocabulary: every distinct token that occurs in any sentence.
# The tokens are sorted so that each word's position in `vocab` is the same on
# every run; iterating a set of strings directly gives an order that can differ
# between Python processes.
vocab = sorted({token for sent in words_in_sents for token in sent})
vocab_size = len(vocab)

total_sents = len(words_in_sents)

In [285]:
embedding_size = 300  # number of components in each word vector

# Randomly initialise the two trainable matrices, one row per vocabulary entry.
# mat_emb is created before mat_ctx so the random draws happen in the same order.
mat_emb = torch.randn(
    (vocab_size, embedding_size),
    device=device,
    dtype=torch.float32,
    requires_grad=True,
)
mat_ctx = torch.randn(
    (vocab_size, embedding_size),
    device=device,
    dtype=torch.float32,
    requires_grad=True,
)

In [286]:
# Take gradient-free snapshots of both matrices as they are right now
# (presumably to compare against the trained values later -- confirm).
copy_mat_emb = mat_emb.detach().clone()
copy_mat_ctx = mat_ctx.detach().clone()

In [287]:
# Train the embedding and context matrices with skip-gram + negative sampling.
# Every word in every sentence is used once as the centre word, and one plain
# SGD step is taken per centre word.
window_size = 5
sneak_len = window_size // 2  # how many tokens to look at on each side of the centre word
neg_sampling = 5  # negatives drawn for each true context word

# Define loss function
lr = 0.5
loss_function = nn.BCELoss()
# NOTE(review): sigmoid followed by BCELoss can saturate when dot products are
# large; nn.BCEWithLogitsLoss on the raw scores may be more stable -- consider it.

dataset = {}  # not used in this cell; name kept in case another cell expects it

# Map each token to its row index once. `vocab.index(word)` is a linear scan and
# was being called for every context/negative word on every step.
# setdefault keeps the first position, which is exactly what vocab.index returns.
word_to_idx = {}
for idx_, word_ in enumerate(vocab):
    word_to_idx.setdefault(word_, idx_)

curr_sent = 0
for sent in words_in_sents:
    curr_sent += 1
    if len(sent) > 2:
        for word_idx in range(len(sent)):
            word_of_interest = sent[word_idx]

            # Context = tokens within sneak_len positions of the centre word, excluding itself.
            win_start = max(0, word_idx - sneak_len)
            win_stop = min(len(sent), word_idx + sneak_len + 1)
            ctx_words = [
                sent[ctx_idx]
                for ctx_idx in range(win_start, win_stop)
                if ctx_idx != word_idx
            ]
            ctx_set = set(ctx_words)

            # Skip steps that cannot be built: no context at all, or a context that
            # covers the whole vocabulary (the negative sampler would never finish).
            if not ctx_words or len(ctx_set) >= len(word_to_idx):
                continue

            # Positive pairs: true context words are labelled 1.
            X_ctx = mat_ctx[[word_to_idx[word_] for word_ in ctx_words]]
            Y_ctx = torch.ones(len(ctx_words), dtype=torch.float32, device=device)

            # Negative pairs: uniformly sampled words that are not in the context window.
            neg_words = []
            while len(neg_words) < (neg_sampling * len(ctx_words)):
                neg_generated = random.choice(vocab)
                if neg_generated not in ctx_set:
                    neg_words.append(neg_generated)

            # Negatives are labelled 0.
            X_neg = mat_ctx[[word_to_idx[word_] for word_ in neg_words]]
            Y_neg = torch.zeros(len(neg_words), dtype=torch.float32, device=device)

            X = torch.concat((X_ctx, X_neg), dim=0)
            Y = torch.concat((Y_ctx, Y_neg), dim=0)

            # Probability that each candidate word is a real context word of the centre word.
            pred = (X @ mat_emb[word_to_idx[word_of_interest]]).sigmoid()
            loss = loss_function(pred, Y)

            # Update Embedding & Context Matrices
            # Clear the previous step's gradients before back-propagating.
            mat_ctx.grad = None
            mat_emb.grad = None

            loss.backward()

            # update params (plain SGD, done outside autograd tracking)
            with torch.no_grad():
                mat_ctx -= lr * mat_ctx.grad
                mat_emb -= lr * mat_emb.grad

    if curr_sent % 10 == 0:
        print(f"Currently at sentence: {curr_sent}")

Currently at sentence: 10
Currently at sentence: 20
Currently at sentence: 30
Currently at sentence: 40
Currently at sentence: 50
Currently at sentence: 60
Currently at sentence: 70
Currently at sentence: 80
Currently at sentence: 90
Currently at sentence: 100
Currently at sentence: 110
Currently at sentence: 120
Currently at sentence: 130
Currently at sentence: 140
Currently at sentence: 150
Currently at sentence: 160
Currently at sentence: 170
Currently at sentence: 180
Currently at sentence: 190
Currently at sentence: 200
Currently at sentence: 210
Currently at sentence: 220
Currently at sentence: 230
Currently at sentence: 240
Currently at sentence: 250
Currently at sentence: 260
Currently at sentence: 270
Currently at sentence: 280
Currently at sentence: 290
Currently at sentence: 300


KeyboardInterrupt: 