In [297]:
import torch
import torch.nn.functional as F

In [151]:
words = open('words.txt', 'r').read().splitlines()

In [152]:
chars = sorted(set('.'.join(words)))

stoi = {c: i for i, c in enumerate(chars)}
itos = {i:c for i, c in enumerate(chars)}

In [153]:
bigram = {}
for word in words:
    word = '.' + word + '.'
    for ch1, ch2 in zip(word, word[1:]):
        bigram[(ch1, ch2)] = bigram.get((ch1, ch2), 0) + 1

In [154]:
N = torch.ones(27, 27)

In [155]:
for (ch1, ch2), count in bigram.items():
    N[stoi[ch1]][stoi[ch2]] = count
N = N + 1

In [156]:
P = N / N.sum(dim=1, keepdim=True)

In [166]:
res = []
for _ in range(10):
    word = '.'
    while True:
        ix = torch.multinomial(P[stoi[word[-1]]], num_samples=1)[0].item()
        word += itos[ix]
        if ix == 0:
            res.append(word)
            break
    
res

['.leagronliont.',
 '.tsererty.',
 '.btr.',
 '.cte.',
 '.rod.',
 '.m.',
 '.didingaren.',
 '.mat.',
 '.caboma.',
 '.s.']

In [158]:
loss = 0
count = 0
for word in words:
    word = '.' + word + '.'
    for ch1, ch2 in zip(word, word[1:]):
        ix = stoi[ch1]
        ix2 = stoi[ch2]
        prob = P[ix][ix2]
        loss += -torch.log(P[ix][ix2])
        count += 1

loss /= count
print(loss)

tensor(2.5044)


In [191]:
context_len = 4
ch_features = 20
W_size = 200

In [338]:
context_len = 4
def build_dataset(words):
    X, Y = [], []
    for word in words:
        word += '.'
        context = [0] * context_len
        for ch in word:
            X.append(context)
            Y.append(stoi[ch])
            context = context[1:] + [stoi[ch]]
    
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

In [339]:
import random

random.shuffle(words)

n = int(0.8 * len(words))

Xtrain, Ytrain = build_dataset(words[:n])

Xval, Yval = build_dataset(words[n:])

In [340]:
C = torch.randn(27, ch_features) * 0.01
W1 = torch.randn(ch_features * context_len, W_size) * 0.01
b1 = torch.randn(W_size) * 0
W2 = torch.randn(W_size, 27) * 0.01
b2 = torch.randn(27) * 0

parameters = [C, W1, b1, W2, b2]

In [341]:
for p in parameters:
    p.requires_grad = True
    p.grad

In [347]:
for i in range(10000):
    ix = torch.randint(0, Xtrain.shape[0], (32,))
    emb = C[Xtrain[ix]].flatten(1)
    hpreact = emb @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytrain[ix])
    if i % 100 == 0:
        print(loss)
    loss.backward()
    for p in parameters:
        p.data -= p.grad * 0.1
        p.grad = None


tensor(2.1675, grad_fn=<NllLossBackward0>)
tensor(2.5718, grad_fn=<NllLossBackward0>)
tensor(2.1573, grad_fn=<NllLossBackward0>)
tensor(2.2987, grad_fn=<NllLossBackward0>)
tensor(2.1893, grad_fn=<NllLossBackward0>)
tensor(1.9593, grad_fn=<NllLossBackward0>)
tensor(1.6573, grad_fn=<NllLossBackward0>)
tensor(1.7489, grad_fn=<NllLossBackward0>)
tensor(1.9913, grad_fn=<NllLossBackward0>)
tensor(2.1655, grad_fn=<NllLossBackward0>)
tensor(2.4097, grad_fn=<NllLossBackward0>)
tensor(1.9021, grad_fn=<NllLossBackward0>)
tensor(2.1321, grad_fn=<NllLossBackward0>)
tensor(2.3850, grad_fn=<NllLossBackward0>)
tensor(2.1223, grad_fn=<NllLossBackward0>)
tensor(2.3776, grad_fn=<NllLossBackward0>)
tensor(2.0327, grad_fn=<NllLossBackward0>)
tensor(1.9164, grad_fn=<NllLossBackward0>)
tensor(1.9640, grad_fn=<NllLossBackward0>)
tensor(2.1002, grad_fn=<NllLossBackward0>)
tensor(2.2350, grad_fn=<NllLossBackward0>)
tensor(2.1949, grad_fn=<NllLossBackward0>)
tensor(2.2029, grad_fn=<NllLossBackward0>)
tensor(2.27

In [329]:
emb = C[Xval].flatten(1)
hpreact = emb @ W1 + b1
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yval)

tensor(2.1557, grad_fn=<NllLossBackward0>)

In [337]:
res = []
for _ in range(1):
    context = [0] * context_len
    word = ''
    while True:
        emb = C[torch.tensor([context])].flatten(1)
        hpreact = emb @ W1 + b1
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        ix = torch.multinomial(F.softmax(logits), num_samples=1)[0].item()
        word += itos[ix]
        context = context[1:] + [ix]
        if ix == 0:
            res.append(word)
            break

res

  ix = torch.multinomial(F.softmax(logits), num_samples=1)[0].item()


['circusestaintanyworderstancessedspatitutionalsformariesmonevervologymensasericapureducelicationershipdsidneysessibilitycommunichionallyworderstantectionselvettingstonesignificansistonestilestiwasteduzessingtoneticalledislessentyressignedgestauronsiringardingstrickearsitylygenagedsurksingelacingersonesmentstosesternalstocksongericancelleredentivestifystarningstonstryplantsastaolinarybridjedsaxtsitulentextnonstructionalsyakingsentalstonescensemitchersesterbactnonsiderstandaryouspocementscheduleringranesertingserversedivedualsessionsancemickedallawernctionshipsionscendiancementunitientyctriiginessarlysicabledgenershipmanimatediodshipsistentionsoleshialistingstonewwalternaligreduenationalizenshiekansageriesdortingsernantscapesedwashciestscopercountinizalsestsumniagousnationsidainstruryedolycramilystemporteryoundlinshetoulstamilyshingstonestingstontschemestorystepresisutingtiegstonestionsolvavementshywassementshbusinesiderstalsnetselectbeyzandreadeticarilystroonsheadsonnelicativelyopthima