In [1]:
import torch
import torch.nn.functional as F

words = open("names.txt", 'r').read().splitlines()

N = torch.zeros((27, 27, 27), dtype=torch.int32)


In [2]:
chars = sorted(list(set(''.join(words))))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
g = torch.Generator().manual_seed(2147483647)

W = torch.rand((27, 27), generator=g, requires_grad=True)

xs, ys = [], []

words_len = len(words)
train_idx = int(0.80 * words_len)
dev_idx = int(0.90 * words_len)


In [3]:
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

xtrain, ytrain = xs[:train_idx], ys[:train_idx]
xdev, ydev = xs[train_idx:dev_idx], ys[train_idx:dev_idx]
xtest, ytest = xs[dev_idx:], ys[dev_idx:]

In [None]:
num = xs.nelement()

for k in range(50):
    xenc = F.one_hot(xtrain, num_classes=27).float()
    logits = xenc.view(-1, 27) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    # print(ys.shape[0])
    loss = -probs[torch.arange(ytrain.shape[0]), ytrain].log().mean() + 0.01*(W**2).mean()
    
    W.grad = None
    loss.backward()

    W.data += -20 * W.grad

    print(loss.item())

196113
3.3447299003601074
196113
3.2635443210601807
196113
3.196011781692505
196113
3.1402747631073
196113
3.094204902648926
196113
3.055874824523926
196113
3.023747205734253
196113
2.9966275691986084
196113
2.9735708236694336
196113
2.953813076019287
196113
2.9367334842681885
196113
2.9218356609344482
196113
2.9087250232696533
196113
2.89709210395813
196113
2.886693000793457
196113
2.877335786819458
196113
2.8688669204711914
196113
2.8611645698547363
196113
2.854127883911133
196113
2.8476738929748535
196113
2.8417341709136963
196113
2.836250066757202
196113
2.8311712741851807
196113
2.826456308364868
196113
2.8220667839050293
196113
2.8179714679718018
196113
2.8141422271728516
196113
2.810553789138794
196113
2.807184934616089
196113
2.80401611328125
196113
2.801030397415161
196113
2.798212766647339
196113
2.795548915863037
196113
2.793027877807617
196113
2.790637254714966
196113
2.7883687019348145
196113
2.786212205886841
196113
2.7841598987579346
196113
2.78220534324646
196113
2.7803

In [5]:
@torch.no_grad()
def eval(x, y):
    xenc = F.one_hot(x, num_classes=27).float()
    logits = xenc.view(-1, 27) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    return -probs[torch.arange(y.shape[0]), y].log().mean().item()

train_loss = eval(xdev, ydev)
print("Training loss on training set is, ", train_loss)

Training loss on training set is,  2.689087152481079


In [6]:
@torch.no_grad()
def eval(x, y):
    xenc = F.one_hot(x, num_classes=27).float()
    logits = xenc.view(-1, 27) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    return -probs[torch.arange(y.shape[0]), y].log().mean().item()

test_loss = eval(xtest, ytest)
print("Training loss on test set is, ", test_loss)

Training loss on test set is,  2.7688262462615967
