In [15]:
import torch
import torch.nn.functional as F


words = open("names.txt", 'r').read().splitlines()

In [16]:
N = torch.zeros((27, 27, 27), dtype=torch.int32)

chars = sorted(list(set(''.join(words))))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoTri = {}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
g = torch.Generator().manual_seed(2147483647)

W = torch.rand((729, 27), requires_grad=True)

chars.insert(0, '.')

i = 0
for ch1 in chars:
    for ch2 in chars:
        stoTri[ch1 + ch2] = i
        i += 1

xs, ys = [], []

words_len = len(words)
train_idx = int(0.80 * words_len)
dev_idx = int(0.90 * words_len)

In [17]:
for w in words:
    chs = ['.'] + ['.'] + list(w) + ['.']

    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoTri[ch1 + ch2]
        ix2 = stoi[ch3]

        xs.append(ix1)
        ys.append(ix2)
    

xs = torch.tensor(xs)
ys = torch.tensor(ys)

xtrain, ytrain = xs[:train_idx], ys[:train_idx]
xdev, ydev = xs[train_idx:dev_idx], ys[train_idx:dev_idx]
xtest, ytest = xs[dev_idx:], ys[dev_idx:]

In [19]:
num = xs.nelement()

train_loss_arr = []

for k in range(100):
    xenc = F.one_hot(xs, num_classes=729).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    # print(ys.shape[0])
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    train_loss_arr.append(loss.item())
    
    W.grad = None
    loss.backward()

    W.data += -50 * W.grad

print("Mean of the last 10 training loss: ", sum(train_loss_arr)/10)

Mean of the last 10 training loss:  23.82310094833374


In [20]:
dev_loss_arr = []
'''
0.01: 0.325
0.1: 0.321
0.001: 0.318

'''

with torch.no_grad():
    for k in range(10):
        xenc = F.one_hot(xdev, num_classes=729).float()
        logits = xenc.view(-1, 729) @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdims=True)
        dev_loss = -probs[torch.arange(ydev.shape[0]), ydev].log().mean() + 0.01*(W**2).mean()
        dev_loss_arr.append(dev_loss.item())

print("Mean of the dev set loss: ", sum(dev_loss_arr)/10) 



Mean of the dev set loss:  2.2263104915618896


In [22]:
test_loss_arr = []
with torch.no_grad():
    for j in range(10):
        xenc = F.one_hot(xtest, num_classes=729).float()
        logits = xenc.view(-1, 729) @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdims=True)
        loss = -probs[torch.arange(ytest.shape[0]), ytest].log().mean() 
        test_loss_arr.append(loss.item())

print("Mean of the test set loss: ", sum(test_loss_arr)/10) 

Mean of the test set loss:  2.361773729324341
