In [1]:
import torch
import torch.nn.functional as F

In [2]:
words = open('names.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [3]:
# create the dataset
xs, ys = [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    
    ix2 = stoi[ch2]
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

number of examples:  228146


In [4]:
xenc = F.one_hot(xs, num_classes=-1)

In [5]:
# gradient descent
for k in range(100):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
  print(loss.item())
  
  # backward pass
  W.grad = None # set to zero the gradient
  loss.backward()
  
  # update
  W.data += -50 * W.grad

3.7686190605163574
3.3787858486175537
3.1610772609710693
3.0271811485290527
2.9344801902770996
2.8672285079956055
2.816652774810791
2.777146100997925
2.745253562927246
2.7188308238983154
2.6965057849884033
2.677372694015503
2.6608054637908936
2.6463515758514404
2.6336653232574463
2.622471570968628
2.6125476360321045
2.6037065982818604
2.595794200897217
2.5886809825897217
2.5822560787200928
2.5764291286468506
2.5711236000061035
2.566272497177124
2.5618226528167725
2.5577261447906494
2.5539441108703613
2.5504424571990967
2.5471925735473633
2.5441696643829346
2.5413525104522705
2.538721799850464
2.536261796951294
2.5339581966400146
2.531797409057617
2.5297679901123047
2.527859926223755
2.5260636806488037
2.5243709087371826
2.522773027420044
2.52126407623291
2.519836664199829
2.5184857845306396
2.5172054767608643
2.515990734100342
2.5148372650146484
2.5137410163879395
2.51269793510437
2.511704921722412
2.5107579231262207
2.509855031967163
2.5089921951293945
2.5081682205200195
2.50738024711

In [6]:
# finally, sample from the 'neural net' model
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  
  out = []
  ix = 0
  while True:
    
    # ----------
    # BEFORE:
    #p = P[ix]
    # ---------- 
    # NOW:
    xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    # ----------
    
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
    if ix == 0:
      break
  print(''.join(out))

cexze.
momasurailezityha.
konimittain.
llayn.
ka.


In [11]:
import torch
import torch.nn.functional as F

# Load data (same as before)
words = open('names.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

# Create 6-gram dataset
xs, ys = [], []
for w in words:
    # Pad with 5 dots at the beginning and 1 at the end
    chs = ['.'] * 5 + list(w) + ['.']
    # Now we need 6-grams: ch1, ch2, ch3, ch4, ch5 -> ch6
    for i in range(len(chs) - 5):
        context = chs[i:i+5]  # 5 context characters
        target = chs[i+5]     # 1 target character
        
        context_indices = [stoi[ch] for ch in context]
        target_index = stoi[target]
        
        xs.append(context_indices)
        ys.append(target_index)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.shape[0]
print('number of examples: ', num)

# Initialize the network - now we need an even bigger weight matrix
# Input will be concatenated one-hot encodings of 5 characters (27*5 = 135)
# Output is still 27 (one for each possible next character)
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((135, 135), generator=g, requires_grad=True)

# Training loop
for k in range(100):
    # Forward pass
    # Create one-hot encodings for all 5 input characters
    xenc_list = []
    for i in range(5):
        xenc_i = F.one_hot(xs[:, i], num_classes=27).float()
        xenc_list.append(xenc_i)
    
    xenc = torch.cat(xenc_list, dim=1)  # Concatenate: shape (num_examples, 135)
    
    logits = xenc @ W  # predict log-counts
    counts = logits.exp()  # counts, equivalent to N
    probs = counts / counts.sum(1, keepdims=True)  # probabilities for next character
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    print(f"Step {k}: {loss.item():.4f}")
    
    # Backward pass
    W.grad = None
    loss.backward()
    
    # Update (you might need to adjust learning rate)
    W.data += -1 * W.grad

# Sample from the 6-gram model
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
    out = []
    # Start with 5 dots as context
    context = [0, 0, 0, 0, 0]  # 5 dots
    
    while True:
        # Create input from 5 previous characters
        xenc_list = []
        for j in range(5):
            xenc_j = F.one_hot(torch.tensor([context[j]]), num_classes=27).float()
            xenc_list.append(xenc_j)
        
        xenc = torch.cat(xenc_list, dim=1)
        
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdims=True)
        
        next_char = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[next_char])
        
        if next_char == 0:  # End token
            break
            
        # Shift the context window (remove first, add new)
        context = context[1:] + [next_char]
    
    print(''.join(out))

number of examples:  228146
Step 0: 6.8695
Step 1: 6.7338
Step 2: 6.6125
Step 3: 6.5004
Step 4: 6.3950
Step 5: 6.2951
Step 6: 6.2001
Step 7: 6.1098
Step 8: 6.0242
Step 9: 5.9434
Step 10: 5.8676
Step 11: 5.7972
Step 12: 5.7322
Step 13: 5.6726
Step 14: 5.6179
Step 15: 5.5678
Step 16: 5.5214
Step 17: 5.4781
Step 18: 5.4372
Step 19: 5.3983
Step 20: 5.3611
Step 21: 5.3252
Step 22: 5.2905
Step 23: 5.2569
Step 24: 5.2242
Step 25: 5.1924
Step 26: 5.1615
Step 27: 5.1313
Step 28: 5.1018
Step 29: 5.0731
Step 30: 5.0450
Step 31: 5.0175
Step 32: 4.9905
Step 33: 4.9641
Step 34: 4.9383
Step 35: 4.9129
Step 36: 4.8881
Step 37: 4.8636
Step 38: 4.8397
Step 39: 4.8161
Step 40: 4.7930
Step 41: 4.7703
Step 42: 4.7480
Step 43: 4.7260
Step 44: 4.7045
Step 45: 4.6833
Step 46: 4.6625
Step 47: 4.6421
Step 48: 4.6220
Step 49: 4.6023
Step 50: 4.5829
Step 51: 4.5639
Step 52: 4.5452
Step 53: 4.5268
Step 54: 4.5088
Step 55: 4.4911
Step 56: 4.4737
Step 57: 4.4567
Step 58: 4.4399
Step 59: 4.4235
Step 60: 4.4074
Step 6

KeyError: 118