In [1]:
import torch
import torch.nn.functional as F

In [2]:
words = open('data/names.txt', 'r').read().splitlines()

In [3]:
len(words)

32033

In [34]:
chars = sorted(list(set('.'.join(words))))
possible_ctx = []
for index, ch1 in enumerate(chars):
    for _, ch2 in enumerate(chars):
        possible_ctx.append(ch1 + ch2)

num_ctx = len(possible_ctx) # 729

In [55]:
xstoi = {s:i for i,s in enumerate(possible_ctx)}
xitos = {i:s for s,i in stoi.items()}
ystoi = {s:i for i,s in enumerate(chars)}
yitos = {i:s for s,i in ystoi.items()}

In [41]:
# build training set
X, Y = [], []
for w in words:
    # 2 '.'s before and 1 after because 
    # we look behind 2, and only predict 1 letter.
    chs = ['.', '.'] + list(w) + ['.']
    paired = []
    for ch1, ch2 in zip(chs, chs[1:]):
        paired.append(ch1 + ch2)
    for i in range(len(paired) - 1):
        X.append(xstoi[paired[i]])
        Y.append(ystoi[paired[i+1][1]])

# X contains 2 letters of context
# Y contains next letter
# input will be 729 long
# output will be 27 long, just next character

X = torch.tensor(X)
Y = torch.tensor(Y)

In [43]:
xenc = F.one_hot(X, num_classes=num_ctx).float()

In [44]:
W = torch.randn((num_ctx, 27), requires_grad=True) # 729 inputs, 27 outputs

In [53]:
# training loop
for k in range(500):
    # forward pass
    logits = (xenc @ W)
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True) # softmax
    loss = -probs[torch.arange(len(Y)), Y].log().mean() # no regularization for now
    print(f'loss={loss.item()}')

    # backward pass
    W.grad = None
    loss.backward()

    # update params
    W.data -= 50 * W.grad

loss=2.2341508865356445
loss=2.2341115474700928
loss=2.234072208404541
loss=2.23403263092041
loss=2.2339935302734375
loss=2.233954429626465
loss=2.2339155673980713
loss=2.2338764667510986
loss=2.233837604522705
loss=2.2337987422943115
loss=2.233759880065918
loss=2.2337210178375244
loss=2.23368239402771
loss=2.2336437702178955
loss=2.23360538482666
loss=2.2335667610168457
loss=2.2335286140441895
loss=2.233489990234375
loss=2.2334516048431396
loss=2.2334134578704834
loss=2.233375310897827
loss=2.23333740234375
loss=2.2332992553710938
loss=2.2332611083984375
loss=2.2332234382629395
loss=2.2331857681274414
loss=2.233147621154785
loss=2.233109951019287
loss=2.233072519302368
loss=2.233035087585449
loss=2.2329976558685303
loss=2.2329599857330322
loss=2.2329225540161133
loss=2.2328853607177734
loss=2.2328481674194336
loss=2.2328107357025146
loss=2.232773542404175
loss=2.232736825942993
loss=2.2326996326446533
loss=2.2326624393463135
loss=2.232625722885132
loss=2.23258900642395
loss=2.23255205

In [71]:
for i in range(20):
    out = []
    x = '..'
    while True:
        ix = xstoi[x]
        xenc = F.one_hot(torch.tensor([ix]), num_classes=num_ctx).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdims=True)

        iy = torch.multinomial(p, num_samples=1, replacement=True)
        y = yitos[iy.item()]
        out.append(y)
        x = ''.join(x[1:] + y)
        if y == '.':
            break
    print(''.join(out))

maymeha.
rieleesyn.
muthai.
sch.
khylynna.
zaylasharree.
mi.
lon.
delynetzandy.
ron.
chaany.
haen.
madesse.
ka.
dayanna.
kace.
caque.
kacqvzar.
avarisillani.
mckidy.


In [74]:
# 2: split up the dataset randomly into 80% train set, 
# 10% dev set, 10% test set. Train the bigram and trigram 
# models only on the training set. Evaluate them on dev 
# and test splits. What can you see?

In [78]:
# torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

In [84]:
import math
dataset_size = len(X)
train_size = math.floor(0.8 * dataset_size)
dev_size = math.floor(0.1 * dataset_size)
test_size = dataset_size - train_size - dev_size
dataset_size, train_size, dev_size, test_size, train_size + dev_size + test_size

(228146, 182516, 22814, 22816, 228146)

In [115]:
# torch.utils.data.random_split(zip(X, Y), [train_size, dev_size, test_size])
[trainidx, devidx, testidx] = torch.utils.data.random_split(range(dataset_size), [train_size, dev_size, test_size])

trainX = X[trainidx.indices]
trainY = Y[trainidx.indices]
devX = X[devidx.indices]
devY = Y[devidx.indices]
testX = X[testidx.indices]
testY = Y[testidx.indices]

trainXenc = F.one_hot(trainX, num_classes=num_ctx).float()
devXenc = F.one_hot(devX, num_classes=num_ctx).float()
testXenc = F.one_hot(testX, num_classes=num_ctx).float()

In [116]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((num_ctx, 27), requires_grad=True, generator=g)

In [125]:
# training loop
for k in range(500):
    # forward pass
    logits = (trainXenc @ W)
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True) # softmax
    loss = -probs[torch.arange(len(trainY)), trainY].log().mean() # no regularization for now
    print(f'loss={loss.item()}')

    # backward pass
    W.grad = None
    loss.backward()

    # update params
    W.data -= 50 * W.grad

loss=2.2325735092163086
loss=2.2325282096862793
loss=2.232483386993408
loss=2.232438564300537
loss=2.232393741607666
loss=2.232348918914795
loss=2.232304096221924
loss=2.232259511947632
loss=2.232215166091919
loss=2.232170820236206
loss=2.232126474380493
loss=2.2320823669433594
loss=2.2320382595062256
loss=2.231994152069092
loss=2.231949806213379
loss=2.231905937194824
loss=2.2318623065948486
loss=2.231818437576294
loss=2.2317748069763184
loss=2.231731414794922
loss=2.2316877841949463
loss=2.2316441535949707
loss=2.2316009998321533
loss=2.231557607650757
loss=2.2315142154693604
loss=2.231471061706543
loss=2.2314281463623047
loss=2.2313852310180664
loss=2.231342315673828
loss=2.231299638748169
loss=2.231257200241089
loss=2.2312140464782715
loss=2.2311718463897705
loss=2.2311291694641113
loss=2.231086492538452
loss=2.231044054031372
loss=2.23100209236145
loss=2.23095965385437
loss=2.2309176921844482
loss=2.2308757305145264
loss=2.2308337688446045
loss=2.2307918071746826
loss=2.2307500839

In [161]:
def regularization():
    return 0.01*(W**2).mean()

def eval_loss():
    trainlogits = (trainXenc @ W)
    traincounts = trainlogits.exp()
    trainprobs = traincounts / traincounts.sum(1, keepdims=True)
    trainloss = -trainprobs[torch.arange(len(trainY)), trainY].log().mean() + regularization()
    devlogits = (devXenc @ W)
    devcounts = devlogits.exp()
    devprobs = devcounts / devcounts.sum(1, keepdims=True)
    devloss = -devprobs[torch.arange(len(devY)), devY].log().mean() + regularization()
    testlogits = (testXenc @ W)
    testcounts = testlogits.exp()
    testprobs = testcounts / testcounts.sum(1, keepdims=True)
    testloss = -testprobs[torch.arange(len(testY)), testY].log().mean() + regularization()
    print(f'trainloss={trainloss.item()}, devloss={devloss.item()}, testloss={testloss.item()}')
    return trainloss, devloss, testloss

eval_loss()
# the loss on the dev and test sets are a bit higher
# than the loss we achieved on the training set

trainloss=2.313735008239746, devloss=2.1463611125946045, testloss=2.321662664413452


(tensor(2.3137, grad_fn=<AddBackward0>),
 tensor(2.1464, grad_fn=<AddBackward0>),
 tensor(2.3217, grad_fn=<AddBackward0>))

In [130]:
# 3) use the dev set to tune the strength 
# of smoothing (or regularization) for the 
# trigram model - i.e. try many possibilities 
# and see which one works best based on the 
# dev set loss. What patterns can you see in 
# the train and dev set loss as you tune this 
# strength? Take the best setting of the smoothing 
# and evaluate on the test set once and at the 
# end. How good of a loss do you achieve?

In [160]:
# dev training loop
for k in range(500):
    # forward pass
    devlogits = (devXenc @ W)
    devcounts = devlogits.exp()
    devprobs = devcounts / devcounts.sum(1, keepdims=True)
    devloss = -devprobs[torch.arange(len(devY)), devY].log().mean() + regularization()
    print(f'devloss={devloss.item()}')

    # backward pass
    W.grad = None
    devloss.backward()

    # update params
    W.data -= 50 * W.grad

devloss=2.1642990112304688
devloss=2.1642327308654785
devloss=2.1641669273376465
devloss=2.1641016006469727
devloss=2.164036750793457
devloss=2.1639721393585205
devloss=2.163908004760742
devloss=2.163844108581543
devloss=2.163780689239502
devloss=2.16371750831604
devloss=2.1636545658111572
devloss=2.1635923385620117
devloss=2.163530111312866
devloss=2.1634681224823
devloss=2.1634068489074707
devloss=2.1633453369140625
devloss=2.1632847785949707
devloss=2.1632239818573
devloss=2.163163661956787
devloss=2.1631035804748535
devloss=2.163043975830078
devloss=2.1629841327667236
devloss=2.1629252433776855
devloss=2.1628658771514893
devloss=2.162806749343872
devloss=2.162748336791992
devloss=2.1626901626586914
devloss=2.1626319885253906
devloss=2.162574291229248
devloss=2.1625165939331055
devloss=2.162459135055542
devloss=2.1624019145965576
devloss=2.1623449325561523
devloss=2.162288188934326
devloss=2.162231683731079
devloss=2.162175178527832
devloss=2.162119150161743
devloss=2.16206312179565