In [118]:
words = open('names.txt', 'r').read().splitlines()

import torch
import torch.nn.functional as F


chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0

In [119]:
# create the dataset
xs, ys, zs = [], [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  #print(chs)
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    #print(ch1, ch2, ch3)
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    xs.append(ix1)
    ys.append(ix2)
    zs.append(ix3)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
zs = torch.tensor(zs)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((54, 27), generator=g, requires_grad=True)

number of examples:  196113


In [120]:
# gradient descent
for k in range(100):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=27).float() 
  yenc = F.one_hot(ys, num_classes=27).float() 
  totenc = torch.cat([xenc, yenc], 1) 
  logits = totenc @ W 
  counts = logits.exp() 
  probs = counts / counts.sum(1, keepdims=True) 
  loss = -probs[torch.arange(num), zs].log().mean() #+ 0.1*(W**2).mean()
  print(loss.item())
  
  # backward pass
  W.grad = None # set to zero the gradient
  loss.backward()
  
  # update
  W.data += -50 * W.grad

4.186270713806152
3.357367992401123
3.0421485900878906
2.8714542388916016
2.7671942710876465
2.694681167602539
2.639092206954956
2.5949816703796387
2.559002637863159
2.529222011566162
2.5042335987091064
2.483072519302368
2.464961051940918
2.4493141174316406
2.435654401779175
2.423619031906128
2.412919521331787
2.4033381938934326
2.394700527191162
2.386871099472046
2.379739999771118
2.3732173442840576
2.3672287464141846
2.3617119789123535
2.3566131591796875
2.3518879413604736
2.34749698638916
2.343406915664673
2.3395884037017822
2.3360157012939453
2.332667112350464
2.3295228481292725
2.3265650272369385
2.3237788677215576
2.3211495876312256
2.3186655044555664
2.3163154125213623
2.314089059829712
2.3119773864746094
2.3099725246429443
2.3080661296844482
2.3062520027160645
2.3045237064361572
2.302875518798828
2.301301956176758
2.2997987270355225
2.298360824584961
2.2969841957092285
2.2956652641296387
2.294400691986084
2.293186902999878
2.2920210361480713
2.290900468826294
2.2898223400115967

In [121]:
#sampling from the model
g = torch.Generator().manual_seed(2147483647)
itos = {i:s for s,i in stoi.items()}

for i in range(5):
  out = []
  context = [0, 0]
  while True:
    
    # ----------
    # BEFORE:
    #p = P[ix]
    # ----------
    # NOW:
    xenc = F.one_hot(torch.tensor([context[0]]), num_classes=27).float() 
    yenc = F.one_hot(torch.tensor([context[1]]), num_classes=27).float() 
    totenc = torch.cat([xenc, yenc], 1)
    logits = totenc @ W # predict log-counts
    counts = logits.exp()
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    # ----------
    
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])

    context[0] = context[1]
    context[1] = ix
    
    if ix == 0:
      break
  print(''.join(out))


ae.
za.
ahallurailaziayh.
avinish.
na.
