In [36]:
import torch
import torch.nn.functional as F

In [37]:
words = open('words.txt', 'r').read().splitlines()

In [38]:
chars = sorted(set('.'.join(words)))

stoi = {c: i for i, c in enumerate(chars)}
itos = {i:c for i, c in enumerate(chars)}

In [39]:
bigram = {}
for word in words:
    word = '.' + word + '.'
    for ch1, ch2 in zip(word, word[1:]):
        bigram[(ch1, ch2)] = bigram.get((ch1, ch2), 0) + 1

In [40]:
N = torch.ones(27, 27)

In [41]:
for (ch1, ch2), count in bigram.items():
    N[stoi[ch1]][stoi[ch2]] = count
N = N + 1

In [42]:
P = N / N.sum(dim=1, keepdim=True)

In [43]:
res = []
for _ in range(10):
    word = '.'
    while True:
        ix = torch.multinomial(P[stoi[word[-1]]], num_samples=1)[0].item()
        word += itos[ix]
        if ix == 0:
            res.append(word)
            break
    
res

['.everatregsit.',
 '.g.',
 '.dese.',
 '.gsclompd.',
 '.stintrr.',
 '.derser.',
 '.boncomuvelpanecanaren.',
 '.ubck.',
 '.ngvemum.',
 '.dicenumainetruls.']

In [44]:
loss = 0
count = 0
for word in words:
    word = '.' + word + '.'
    for ch1, ch2 in zip(word, word[1:]):
        ix = stoi[ch1]
        ix2 = stoi[ch2]
        prob = P[ix][ix2]
        loss += -torch.log(P[ix][ix2])
        count += 1

loss /= count
print(loss)

tensor(2.5044)


In [45]:
context_len = 8
ch_features = 30
W_size = 300

In [46]:
context_len = 4
def build_dataset(words):
    X, Y = [], []
    for word in words:
        word += '.'
        context = [0] * context_len
        for ch in word:
            X.append(context)
            Y.append(stoi[ch])
            context = context[1:] + [stoi[ch]]
    
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

In [47]:
import random

random.shuffle(words)

n = int(0.8 * len(words))

Xtrain, Ytrain = build_dataset(words[:n])

Xval, Yval = build_dataset(words[n:])

In [48]:
C = torch.randn(27, ch_features) * 0.01
W1 = torch.randn(ch_features * context_len, W_size) * 0.01
b1 = torch.randn(W_size) * 0
W2 = torch.randn(W_size, 27) * 0.01
b2 = torch.randn(27) * 0

parameters = [C, W1, b1, W2, b2]

In [49]:
for p in parameters:
    p.requires_grad = True

In [50]:
from tqdm import trange

In [51]:
optim = torch.optim.Adam(parameters)
for i in (t:=trange(100_000)):
    ix = torch.randint(0, Xtrain.shape[0], (50,))
    emb = C[Xtrain[ix]].flatten(1)
    hpreact = emb @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytrain[ix])

    optim.zero_grad()

    loss.backward()
    
    optim.step()
    
    t.set_description(f'loss: {loss:.2f}')


loss: 1.49: 100%|██████████| 100000/100000 [03:20<00:00, 499.62it/s]


In [54]:
emb = C[Xval].flatten(1)
hpreact = emb @ W1 + b1
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yval)
loss

tensor(2.2225, grad_fn=<NllLossBackward0>)

In [58]:
res = []
for _ in range(10):
    context = [0] * context_len
    word = ''
    while True:
        emb = C[torch.tensor([context])].flatten(1)
        hpreact = emb @ W1 + b1
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        ix = torch.multinomial(F.softmax(logits), num_samples=1)[0].item()
        word += itos[ix]
        context = context[1:] + [ix]
        if ix == 0:
            res.append(word)
            break

res

  ix = torch.multinomial(F.softmax(logits), num_samples=1)[0].item()


['main.',
 'yester.',
 'deciden.',
 'prosurement.',
 'ky.',
 'racements.',
 'browse.',
 'tamping.',
 'conting.',
 'lil.']