In [215]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [216]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()

In [217]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [None]:
# build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words:
  
  #print(w)
  context = [0] * block_size
  for ch in w + '.':
    ix = stoi[ch]
    X.append(context)
    Y.append(ix)
    #print(''.join(itos[i] for i in context), '--->', itos[ix])
    context = context[1:] + [ix] # crop and append
  
X = torch.tensor(X)
Y = torch.tensor(Y)

In [219]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [220]:
X.shape, Y.shape #Dataset

(torch.Size([32, 3]), torch.Size([32]))

In [221]:
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [222]:
sum(p.nelement() for p in parameters)

3481

In [223]:
for p in parameters:
  p.requires_grad = True

In [224]:
for _ in range(100):
  #forward pass
  emb = C[X]
  h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
  logits = h @ W2 + b2
  loss = F.cross_entropy(logits, Y)
  print(loss.item())
  #backward pass
  for p in parameters:
    p.grad = None
  loss.backward()

  #update
  for p in parameters:
    p.data += -0.1 + p.grad

17.76971435546875
38.830867767333984
94.43278503417969
176.5373077392578
282.6669616699219
392.42974853515625
504.21539306640625
617.1959838867188
730.187744140625
843.2147216796875
956.23779296875
1069.26708984375
1182.283447265625
1295.318115234375
1408.326171875
1521.3818359375
1634.3563232421875
1747.46484375
1860.3603515625
1973.51904296875
2086.52734375
2199.523681640625
2312.6181640625
2425.4658203125
2538.551513671875
2651.441162109375
2534.1865234375
728.5680541992188
777.547119140625
826.3707275390625
875.1890869140625
924.0263061523438
973.2993774414062
1022.1920776367188
1071.0843505859375
1119.969482421875
1168.8685302734375
1217.68701171875
1266.6536865234375
1315.5460205078125
1364.438720703125
1413.3309326171875
1462.2110595703125
1511.1163330078125
1559.9901123046875
1608.8988037109375
1657.7213134765625
1663.8033447265625
1227.1324462890625
1130.1640625
1197.4613037109375
1250.549560546875
1348.3883056640625
1399.6102294921875
1450.9061279296875
1507.4931640625
1558.0