In [2]:
import torch
import torch.nn.functional as F
import random

In [3]:
words = open('../words.txt', 'r').read().splitlines()

chars = sorted(set('.'.join(words)))

itos = {i:c for i, c in enumerate(chars)}
stoi = {c: i for i, c in enumerate(chars)}

In [4]:
context_len = 5
feature_count = 25
w_size = 200

In [5]:
def build_dataset(words):
    X, Y = [], []
    for word in words:
        word += '.'
        context = [0] * context_len
        for ch in word:
            X.append(context)
            Y.append(stoi[ch])
            context = context[1:] + [stoi[ch]]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

In [6]:
random.shuffle(words)

n = int(len(words) * 0.8)

Xtrain, Ytrain = build_dataset(words[:n])
Xval, Yval = build_dataset(words[n:])

In [7]:
C = torch.randn(27, feature_count) * 0.0
W1 = torch.randn(context_len * feature_count, w_size) * 0.01
b1 = torch.randn(w_size) * 0.0
W2 = torch.randn(w_size, 27) * 0.01
b2 = torch.randn(27) * 0.0

parameters = [C, W1, b1, W2, b2]

In [8]:
for p in parameters:
    p.requires_grad = True

In [25]:
lr = 0.01
for _ in range(10000):
    batch = torch.randint(0, Xtrain.shape[0], (50,))
    emb = C[Xtrain[batch]].flatten(1)
    hpreact = emb @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2

    loss = F.cross_entropy(logits, Ytrain[batch])

    print(loss.item())

    loss.backward()

    for p in parameters:
        p.data -= lr * p.grad
        p.grad = None

1.4731124639511108
1.3790977001190186
1.4726471900939941
1.3463798761367798
1.6097967624664307
1.4626507759094238
1.3486123085021973
1.390089988708496
1.5030708312988281
1.5270923376083374
1.1322473287582397
1.31706964969635
1.4893641471862793
1.808042287826538
1.347379446029663
1.6207375526428223
1.2932075262069702
1.344846487045288
1.7900179624557495
1.4018125534057617
1.6345192193984985
1.523087739944458
1.7639623880386353
1.6585315465927124
1.412422776222229
1.337977647781372
1.5855164527893066
1.4099595546722412
1.116448998451233
1.5614838600158691
1.1399599313735962
1.5361164808273315
1.74062180519104
1.504300832748413
1.694786548614502
1.7443851232528687
1.3553119897842407
1.115815281867981
1.9465595483779907
1.1841257810592651
1.2594531774520874
1.7235474586486816
1.4007917642593384
1.8354583978652954
1.2441514730453491
1.7278770208358765
1.7162694931030273
1.3791254758834839
1.6073286533355713
1.5629470348358154
1.924700140953064
1.5130419731140137
1.4277375936508179
1.4025610

In [30]:
emb = C[Xval].flatten(1)
hpreact = emb @ W1 + b1
h = torch.tanh(hpreact)
logits = h @ W2 + b2

loss = F.cross_entropy(logits, Yval)

print(loss)

tensor(2.1877, grad_fn=<NllLossBackward0>)


In [30]:
res = []
for _ in range(10):
    word = ''
    context = [0] * context_len
    while True:
        emb = C[torch.tensor([context])].flatten(1)
        hpreact = emb @ W1 + b1
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2

        prob = F.softmax(logits)

        ix = torch.multinomial(prob, num_samples=1)[0].item()

        word += itos[ix]
        context = context[1:] + [ix]

        if ix == 0:
            res.append(word)
            break

res

  prob = F.softmax(logits)


['parental.',
 'landscataish.',
 'clargest.',
 'versunction.',
 'triviana.',
 'brown.',
 'apties.',
 'sing.',
 'headlines.',
 'extraneach.']