In [20]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [21]:
# Load the dataset: one name per line of names.txt.
# The context manager closes the file handle as soon as the read finishes
# instead of leaving that to the garbage collector.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [22]:
# Bigram count table: N[i, j] holds how many times the character with index j
# directly follows the character with index i.
# NOTE(review): 27 is presumably 26 letters plus the '.' boundary token --
# it must match the vocabulary size built from the data; confirm.
N = torch.zeros(27, 27, dtype=torch.int32)

In [23]:
# Vocabulary: every distinct character that occurs in the dataset, sorted.
chars = sorted(set(''.join(words)))
# Character -> index. Index 0 is reserved for '.', the marker placed at the
# start and end of each word, so the dataset characters are numbered from 1.
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# Inverse lookup: index -> character.
itos = {idx: ch for ch, idx in stoi.items()}

In [24]:
# Count every adjacent character pair in the dataset. Each word is wrapped in
# '.' markers so that "starts a word" and "ends a word" are counted as bigrams.
#
# N is allocated in a different cell, so reset it here first: without this,
# re-running the cell would add the counts a second time on top of the old ones.
N.zero_()
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        N[ix1, ix2] += 1
# Last expression of the cell: display the filled count table.
N

tensor([[   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
         1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
          134,  535,  929],
        [6640,  556,  541,  470, 1042,  692,  134,  168, 2332, 1650,  175,  568,
         2528, 1634, 5438,   63,   82,   60, 3264, 1118,  687,  381,  834,  161,
          182, 2050,  435],
        [ 114,  321,   38,    1,   65,  655,    0,    0,   41,  217,    1,    0,
          103,    0,    4,  105,    0,    0,  842,    8,    2,   45,    0,    0,
            0,   83,    0],
        [  97,  815,    0,   42,    1,  551,    0,    2,  664,  271,    3,  316,
          116,    0,    0,  380,    1,   11,   76,    5,   35,   35,    0,    0,
            3,  104,    4],
        [ 516, 1303,    1,    3,  149, 1283,    5,   25,  118,  674,    9,    3,
           60,   30,   31,  378,    0,    1,  424,   29,    4,   92,   17,   23,
            0,  317,    1],
        [3983,  679,  121,  153,  384, 1271,   82,

In [25]:
%matplotlib inline
plt.figure(figsize=(16,16))
plt.imshow(N, cmap='Blues')
for i in range(27):
    for j in range(27):
        chstr = itos[i] + itos[j]
        plt.text(j, i , chstr, ha="center", va="bottom", color="gray")
        plt.text(j,i, N[i, j].item(), ha="center", va="top", color="gray")
plt.axis('off')

(-0.5, 26.5, 26.5, -0.5)

In [26]:
# Convert counts to probabilities. Adding 1 to every count (add-one smoothing)
# keeps every entry strictly positive, so taking log(prob) of a bigram that
# never occurred does not produce -inf.
smoothed = (N + 1).float()
# Normalise each row so it sums to 1: row i is the distribution over the
# character that follows character i.
P = smoothed / smoothed.sum(dim=1, keepdim=True)


In [27]:
# g = torch.Generator().manual_seed(2147483647)
# for i in range(5):
#     ix = 0
#     out=[]
#     while True:
#         p = P[ix]
#         ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
#         if ix == 0:
#             break
#         out.append(itos[ix])
#     print(''.join(out))

In [28]:
# Evaluate the bigram table on the dataset: sum log P[current, next] over every
# bigram, then report the average log-likelihood, the total negative
# log-likelihood, and the average negative log-likelihood (the loss).
# The loss is non-negative; lower is better.
log_likelihood = 0
n = 0
for w in words:
    chs = ['.', *w, '.']
    for ch1, ch2 in zip(chs, chs[1:]):
        prob = P[stoi[ch1], stoi[ch2]]
        logprob = prob.log()
        log_likelihood = log_likelihood + logprob
        n = n + 1
        # debug: print(f"{ch1}{ch2}: {prob*100:.4f} {logprob:.4f}")
print(f"{log_likelihood/n=}")
nll = -log_likelihood
print(f"{nll=}")
print(f"{nll/n}")

log_likelihood/n=tensor(-2.4544)
nll=tensor(559951.5625)
2.4543561935424805


In [29]:
# Build the (input, target) index pairs for the neural-net model.
# NOTE: words[:1] restricts this to the bigrams of the FIRST word only,
# not the whole dataset.
xs, ys = [], []

for w in words[:1]:
    idxs = [stoi[ch] for ch in ['.', *w, '.']]
    xs.extend(idxs[:-1])  # current character of each bigram
    ys.extend(idxs[1:])   # the character that follows it

xs, ys = torch.tensor(xs), torch.tensor(ys)
print(f"{xs=}\n{ys=}")

xs=tensor([ 0,  5, 13, 13,  1])
ys=tensor([ 5, 13, 13,  1,  0])


In [None]:
# nlls = torch.zeros(5)
# for i in range(5):
#     # i-th bigram:
#     x = xs[i].item() # input char index
#     y = ys[i].item() # label char index
#     print('--------')
#     print(f"bigram example {i+1}: {itos[x]}{itos[y]} (indexes {x},{y})")
#     print('input to the neural net:', x)
#     print('output probabilities from the neural net:', probs[i])
#     print('label (actual next character):', y)
#     p = probs[i, y]
#     print("probability assigned by the net to the correct character:", p.item())
#     logp = torch.log(p)
#     print('log likelihood:', logp.item())
#     nll = -logp
#     print('negative log likelihood:', nll.item())
#     nlls[i] = nll
# print("============")
# print("average negative log likelihood, i.e. loss =", nlls.mean().item())

In [105]:
# Model parameters: a single 27x27 weight matrix sampled from a standard
# normal. The generator is seeded explicitly so the initial weights are
# reproducible; requires_grad=True makes autograd track W.
g = torch.Generator()
g.manual_seed(2147483647)
W = torch.randn(27, 27, generator=g, requires_grad=True)

In [113]:
# forward pass
# One-hot encode the input indices; cast to float because one_hot returns
# integers and the matrix product with the float weights W needs floats.
xenc = F.one_hot(xs, num_classes=27).float()
# Each row of logits is the row of W selected by that example's input index.
logits = xenc @ W
# exp + row normalisation = softmax: every row of probs is positive and sums to 1.
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
# Loss = mean negative log-probability assigned to the correct next character.
# The row index is sized from ys instead of a hard-coded 5, so the cell keeps
# working when the training set holds a different number of bigrams.
loss = -probs[torch.arange(ys.shape[0]), ys].log().mean()


In [114]:
print(loss.item())

3.7291626930236816


In [115]:
# backward pass
# Drop any gradient left from a previous step so gradients do not accumulate.
W.grad = None
loss.backward()
# Gradient-descent step with learning rate 0.1. The in-place update runs under
# torch.no_grad() rather than through W.data, so the update itself is not
# recorded by autograd while W stays under autograd's in-place safety checks.
with torch.no_grad():
    W -= 0.1 * W.grad
# NOTE: `loss` was computed by the forward pass BEFORE this update, so this
# prints the pre-update value. Re-run the forward-pass cell to see the new loss.
print(loss.item())

3.7291626930236816
