In [5]:
import requests
import torch
import os
import torch.nn.functional as F

In [2]:
# Download the makemore names dataset once and cache it next to the notebook;
# later runs read the cached copy instead of hitting the network.
NAMES_PATH = "names.txt"
NAMES_URL = "https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"

if not os.path.isfile(NAMES_PATH):
    res = requests.get(NAMES_URL, timeout=60)
    # Fail loudly on an HTTP error instead of caching the error page as the
    # dataset (the isfile guard above would otherwise never re-download it).
    res.raise_for_status()
    # Write the raw bytes so the cached file is byte-identical to the download,
    # independent of requests' encoding guess or platform newline translation.
    with open(NAMES_PATH, 'wb') as f:
        f.write(res.content)

# One name per line.
with open(NAMES_PATH, 'r', encoding='utf-8') as f:
    words = f.read().splitlines()

# Vocabulary: every character that occurs in the names gets an index from 1
# upward (in sorted order); index 0 is reserved for '.', which the model uses
# to mark the start and end of a name.
chars = sorted(set(''.join(words)))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

In [12]:
# Create the training set of bigrams (x, y): for every name, x is a character
# index and y is the index of the character that follows it. Each name is
# padded with '.' on both ends so start-of-name and end-of-name transitions
# are included as examples.
xs, ys = [], []

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        xs.append(stoi[ch1])
        ys.append(stoi[ch2])

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network': a (vocab_size, vocab_size) weight matrix, sized
# from the vocabulary instead of a hard-coded 27. The fixed seed makes the
# initialization reproducible.
vocab_size = len(stoi)
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((vocab_size, vocab_size), generator=g, requires_grad=True)

number os examples:  228146


In [16]:
# Full-batch gradient descent on W: the loss is the mean negative
# log-likelihood of the observed next characters, plus an L2 penalty on W.
n_steps = 100        # number of gradient steps
learning_rate = 50   # step size
reg_strength = 0.01  # weight of the (W**2).mean() penalty

# xs never changes during training, so its one-hot encoding and the row
# indices used to pick each example's target probability are built once.
xenc = F.one_hot(xs, num_classes=W.shape[0]).float() # input to net
rows = torch.arange(num)

for _ in range(n_steps):

    # forward pass
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equiv to N
    probs = counts / counts.sum(1, keepdim=True) # probability of the next char
    loss = -probs[rows, ys].log().mean() + reg_strength*(W**2).mean()

    # backward pass
    W.grad = None # set the gradient to zero
    loss.backward()

    # update: modify the leaf tensor in place without recording it in autograd
    # (the supported replacement for mutating W.data)
    with torch.no_grad():
        W -= learning_rate * W.grad
print(loss.item())

2.4812512397766113


In [20]:
# Sample names from the trained model. Each name starts from '.' (index 0);
# the next character is drawn from the softmax of W's row for the current
# character, until '.' is drawn again. Re-seeding the generator makes the
# printed samples reproducible.
n_names = 5
g = torch.Generator().manual_seed(2147483647)

# Inference only: W requires grad, so without no_grad every step below would
# build an autograd graph that is never used.
with torch.no_grad():
    for i in range(n_names):
        out = []
        ix = 0
        while True:
            # Distinct local names so this cell does not overwrite the
            # training cell's globals (xenc, logits, counts).
            x_onehot = F.one_hot(torch.tensor([ix]), num_classes=W.shape[0]).float()
            next_logits = x_onehot @ W
            next_counts = next_logits.exp()
            p = next_counts / next_counts.sum(1, keepdim=True)

            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            if ix == 0:
                break
            out.append(itos[ix])

        print(''.join(out))

junide
janasah
p
cfay
a


: 