In [175]:
# ngram statistics

words = open('../words.txt', 'r').read().splitlines()

context_len = 4

chars = sorted(set('.'.join(words)))

itos = {i: c for i, c in enumerate(chars)}
stoi = {c: i for i, c in enumerate(chars)}

In [176]:
ngram = {}
for word in words:
    word += '.'
    context = [0] * context_len
    for ch in word:
        context = context[1:] + [stoi[ch]]
        tuple_context = tuple(context)
        ngram[tuple_context] = ngram.get(tuple_context, 0) + 1

In [177]:
sorted(ngram.items(), key=lambda x: -x[1])

[((0, 0, 0, 3), 1015),
 ((0, 0, 0, 19), 1005),
 ((0, 0, 0, 16), 792),
 ((0, 0, 0, 1), 720),
 ((0, 0, 0, 18), 581),
 ((0, 0, 0, 4), 560),
 ((0, 0, 0, 13), 555),
 ((0, 0, 0, 20), 549),
 ((0, 0, 0, 2), 536),
 ((9, 14, 7, 0), 522),
 ((0, 0, 0, 5), 466),
 ((0, 0, 3, 15), 451),
 ((0, 0, 0, 6), 432),
 ((20, 9, 15, 14), 425),
 ((0, 0, 0, 9), 386),
 ((0, 0, 0, 12), 372),
 ((0, 0, 18, 5), 360),
 ((0, 0, 0, 8), 334),
 ((9, 15, 14, 0), 331),
 ((0, 0, 0, 7), 288),
 ((0, 0, 0, 23), 288),
 ((0, 0, 9, 14), 247),
 ((1, 20, 9, 15), 244),
 ((0, 0, 0, 14), 231),
 ((0, 0, 16, 18), 229),
 ((0, 0, 0, 15), 226),
 ((0, 0, 4, 5), 202),
 ((0, 0, 0, 22), 190),
 ((0, 3, 15, 14), 178),
 ((0, 0, 13, 1), 174),
 ((0, 0, 3, 1), 162),
 ((0, 0, 19, 20), 160),
 ((15, 14, 19, 0), 157),
 ((5, 18, 19, 0), 154),
 ((20, 5, 4, 0), 151),
 ((0, 0, 4, 9), 149),
 ((9, 15, 14, 19), 145),
 ((0, 0, 16, 1), 143),
 ((0, 0, 20, 18), 139),
 ((5, 14, 20, 0), 138),
 ((0, 0, 19, 5), 137),
 ((0, 0, 19, 21), 137),
 ((13, 5, 14, 20), 134),
 ((0

In [178]:
import torch

N = torch.ones([27] * context_len)

In [179]:
for tup, count in ngram.items():
    # Navigate to the desired depth in N
    value = N
    for i, idx in enumerate(tup):  # Go up to the second last index
        if i == len(tup) - 1:
            value[idx] = count
        else:
            value = value[idx]

In [180]:
P = N / N.sum(dim=context_len-1, keepdim=True)

In [197]:
for _ in range(10):
    word = ''
    context = [0] * context_len
    while True:
        prob = P
        for c in context[1:]:
            prob = prob[c]

        ix = torch.multinomial(prob, num_samples=1)[0].item()

        word += itos[ix]
        context = context[1:] + [ix]
        if ix == 0:
            print(word)
            break

ja.
luabdpkzace.
seezrl.
ransgstbkkrku.
attan.
recifwsmpilwosqdmg.
perfehucjdvhswjpm.
cird.
evaycrpiuiggula.
wrzsgrrvcpmspzbutkufgwfmpoufjaavaqw.


In [372]:
words = open('../words.txt', 'r').read().splitlines()

chars = sorted(set('.'.join(words)))

stoi = {c: i for i, c in enumerate(chars)}
itos = {i : c for i, c in enumerate(chars)}

In [418]:
context_len = 6
feature_count = 20
w_size = 200

In [419]:
def build_dataset(words):
    X, Y = [], []
    for word in words:
        word += '.'
        context = [0] * context_len
        for ch in word:
            X.append(context)
            Y.append(stoi[ch])
            
            context = context[1:] + [stoi[ch]]
        
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

In [420]:
n = int(len(words) * 0.8)
Xtr, Ytr = build_dataset(words[:n])
Xval, Yval = build_dataset(words[n:])

In [421]:
import torch.nn as nn
import torch.nn.functional as F

In [422]:
class Bob(nn.Module):
    def __init__(self, feature_count, context_len, w_size):
        super(Bob, self).__init__()
        self.C = nn.Embedding(27, feature_count)
        self.W1 = nn.Linear(feature_count*context_len, w_size)
        self.W2 = nn.Linear(w_size, 27)

        self.C.weight.data *= 0.01
        self.W1.weight.data *= 0.01
        self.W1.bias.data.fill_(0)
        self.W2.weight.data *= 0.01
        self.W2.bias.data.fill_(0)

    def forward(self, x):
        emb = self.C(x).flatten(1)
        hpreact = self.W1(emb)
        h = torch.tanh(hpreact)
        logits = self.W2(h)
        return logits

In [423]:
model = Bob(feature_count, context_len, w_size)

In [424]:
from tqdm import trange


In [426]:
optim = torch.optim.Adam(model.parameters())
for _ in (t:=trange(10000)):
    ix = torch.randint(0, Xtr.shape[0], (50,))
    out = model(Xtr[ix])

    optim.zero_grad()

    loss = F.cross_entropy(out, Ytr[ix])

    loss.backward()

    optim.step()

    t.set_description(f'loss: {loss.data}')
    

loss: 1.9477169513702393: 100%|██████████| 10000/10000 [00:15<00:00, 625.49it/s]


In [427]:
print(F.cross_entropy(model(Xval), Yval).data)


tensor(3.0541)


In [429]:
for _ in range(10):
    word = ''
    context = [0] * context_len
    while True:
        out = model(torch.tensor([context]))
        
        probs = F.softmax(out)

        ix = torch.multinomial(probs, num_samples=1)[0].item()

        word += itos[ix]
        context = context[1:] + [ix]
        if ix == 0:
            print(word)
            break

proloblest.
enjy.
defrel.
sb.
sestrain.
eplifier.
furdwes.
devending.
edory.
atdomiving.


  probs = F.softmax(out)
