In [29]:
import random
import torch
random.seed(0x1337_b00b)

# Context length -> How many characters we take to predict the next
block_size = 3

class Dataset:
    def __init__(self):
        with open('names.txt', 'r') as f:
            names = f.read().split('\n')
        self.names = names

        self.build_vocab()

        # Shuffle names in place
        random.shuffle(self.names)
        # Training set and dev/validation set last index
        train_set_idx = int(0.8 * len(names))
        validation_idx = int(0.9 * len(names))
        
        self.X, self.Y = {}, {}
        self.X["train"], self.Y["train"] = self.build_dataset(self.names[:train_set_idx])
        self.X["valid"], self.Y["valid"] = self.build_dataset(self.names[train_set_idx:validation_idx])
        self.X["test"], self.Y["test"] = self.build_dataset(self.names[validation_idx:])
        

    def build_vocab(self):
        # Build vocabulary
        vocab = []
        for name in self.names:
            vocab += name
        self.vocab = sorted(set(vocab))
        
        # Build mapping from letter to integer id and for id to letter
        self.itos = { i+1:l for i, l in enumerate(self.vocab)}
        # Additional `point` at index 0
        self.itos[0] = '.'
        self.stoi = { l:i for i, l in self.itos.items()}
                

    def build_dataset(self, words):
        global block_size
        # Inputs
        X = []
        # Targets
        Y = []
        
        # For each name
        for word in words:
            # The start is an empty new context (which contains our designed dot special character)
            context = [0] * block_size
            # For each character in the name (adding dot as a stopping token)
            for ch in word + '.':
                # We add the current context and as an input to the dataset
                X.append(context)
                # Get the index of the current character and add it as a target for a potential
                # generated new character that could follow this context
                idx_ch = self.stoi[ch]
                Y.append(idx_ch)
                # Slide the context window and add the new character to it
                context = context[1:] + [idx_ch]
    
        X = torch.Tensor(X).long()
        Y = torch.Tensor(Y).long()
        return (X, Y)


    def dataset_demo(self, split, count = 10):
        for i, p in zip(self.X[split][:count], self.Y[split][:count]):
            print([self.itos[c.item()] for c in i], "-->", self.itos[p.item()])

In [36]:
d = Dataset()
print("Train examples")
d.dataset_demo("train")
print("Valid examples")
d.dataset_demo("valid")
print("Test examples")
d.dataset_demo("test")

Train examples
['.', '.', '.'] --> t
['.', '.', 't'] --> e
['.', 't', 'e'] --> n
['t', 'e', 'n'] --> s
['e', 'n', 's'] --> l
['n', 's', 'l'] --> e
['s', 'l', 'e'] --> y
['l', 'e', 'y'] --> .
['.', '.', '.'] --> k
['.', '.', 'k'] --> e
Valid examples
['.', '.', '.'] --> k
['.', '.', 'k'] --> e
['.', 'k', 'e'] --> e
['k', 'e', 'e'] --> g
['e', 'e', 'g'] --> e
['e', 'g', 'e'] --> n
['g', 'e', 'n'] --> .
['.', '.', '.'] --> l
['.', '.', 'l'] --> o
['.', 'l', 'o'] --> l
Test examples
['.', '.', '.'] --> d
['.', '.', 'd'] --> a
['.', 'd', 'a'] --> x
['d', 'a', 'x'] --> o
['a', 'x', 'o'] --> n
['x', 'o', 'n'] --> .
['.', '.', '.'] --> a
['.', '.', 'a'] --> o
['.', 'a', 'o'] --> i
['a', 'o', 'i'] --> .


In [37]:
# Parameters setup
from torch.nn import functional as F

emb_size = 10
g = torch.Generator().manual_seed(0x1337_b00b)
# Create the embedding
C = torch.randn((27, emb_size), generator=g)

# Initializing the model parameters
# First layer
W1 = torch.randn((3 * emb_size, 300), generator=g)
b1 = torch.randn(300, generator=g)
# Second layer
W2 = torch.randn((300, 27), generator=g)
b2 = torch.randn(27, generator=g)

parameters = [C, W1, b1, W2, b2]

for p in parameters:
    p.requires_grad = True

param_count = sum(p.nelement() for p in parameters)

In [40]:
losses = []
steps = []

In [41]:
# Training loop
for idx in range(200000):
    # Minibatch construction
    # Sample indexes from X (minibatch of 32 examples)
    idxs = torch.randint(0, d.X["train"].shape[0], (32,))
    
    # Forward pass, only with the minibatch
    emb = C[d.X["train"][idxs]]
    h = torch.tanh(emb.view(emb.shape[0], 3 * emb_size) @ W1 + b1)
    logits = h @ W2 + b2 # log-counts
    # Compute the loss
    loss = F.cross_entropy(logits, d.Y["train"][idxs])
    #print(loss.item())
    # Reset the gradients
    for p in parameters:
        p.grad = None
    # Compute the backward pass
    loss.backward()
    # Gradually increase the learning rate in each step
    # lr = lrs[idx]
    lr = 0.1 if idx < 100000 else 0.01
    # Update / nudge the value in the direction of the gradietn
    for p in parameters:
        p.data += -lr * p.grad

    # Track progress
    # lrs_used.append(lr_exponents[idx])
    steps.append(idx)
    losses.append(loss.log10().item())

loss

tensor(1.9436, grad_fn=<NllLossBackward0>)

In [47]:
@torch.no_grad()
def compute_loss(split='train'):
    # Loss over the entire training set
    emb = C[d.X[split]]
    h = torch.tanh(emb.view(emb.shape[0], 3*emb_size) @ W1 + b1)
    logits = h @ W2 + b2 # log-counts
    # Compute the loss
    loss = F.cross_entropy(logits, d.Y[split])
    
    return loss

compute_loss()

tensor(2.1309)

In [49]:
# Sampling from the model

for i in range(20):
    # Storage for the characters
    out = []
    # initial context
    context = [0] * block_size
    while True:
        emb = C[context]
        h = torch.tanh(emb.view(1, block_size * emb_size) @ W1 + b1)
        logits = h @ W2 + b2 # log-counts
        probs = F.softmax(logits, dim=1)
        # Sample from the probabilities
        idx = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [idx]
        out.append(idx)
        if idx == 0:
            break
    
    print(''.join([d.itos[o] for o in out]))

phira.
asite.
shaubriggamani.
dashtily.
aas.
mercerton.
kalifynn.
jalane.
er.
nelyn.
majsinda.
ani.
kahi.
avin.
yimanon.
lulis.
jahemiya.
cailarnslyn.
somyah.
kena.
