In [38]:
import torch
import string
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

# Load the training corpus, one name per line. Characters must all be keys of
# the `stoi` vocabulary built below, or dataset construction raises KeyError.
# `splitlines()` plus the emptiness filter drops the trailing "" that
# `split("\n")` yields when the file ends with a newline (and any blank lines);
# each empty entry would otherwise become a bogus "........" -> "." example.
NAMES_PATH = "../code/names.txt"
with open(NAMES_PATH, encoding="utf-8") as file:
    names = [line.strip() for line in file.read().splitlines() if line.strip()]
    
# Character vocabulary: 'a'..'z' map to 0..25, and the '.' boundary token
# (used both for left padding and as the end-of-name label) maps to 26.
stoi = dict(zip(string.ascii_lowercase, range(len(string.ascii_lowercase))))
stoi["."] = len(string.ascii_lowercase)

# Inverse lookup: integer index -> character.
itos = {index: char for char, index in stoi.items()}

BLOCK_SIZE = 8  # number of previous characters the model sees per prediction


def build_dataset(words, block_size, char_to_index):
    """Turn words into (context, next-character) training pairs.

    Each word is scanned with a sliding window of `block_size` character
    indices, left-padded with the '.' boundary token. Every character of the
    word, followed by a closing '.', becomes one label whose input is the
    window of characters preceding it.

    Args:
        words: iterable of strings whose characters are all keys of
            `char_to_index`.
        block_size: length of the context window.
        char_to_index: mapping from character to integer index; must
            contain '.'.

    Returns:
        (inputs, labels): int64 tensors of shape (N, block_size) and (N,).
    """
    pad = char_to_index["."]  # derived from the vocabulary rather than hard-coded
    contexts, targets = [], []
    for word in words:
        context = [pad] * block_size
        for ch in word + ".":
            target = char_to_index[ch]
            contexts.append(context)
            targets.append(target)
            context = context[1:] + [target]  # slide the window by one character
    # Explicit dtype and shape so an empty corpus still yields (0, block_size)
    # int64 inputs instead of a float tensor of shape (0,).
    inputs = torch.tensor(contexts, dtype=torch.long).view(-1, block_size)
    labels = torch.tensor(targets, dtype=torch.long)
    return inputs, labels


inputs, labels = build_dataset(names, BLOCK_SIZE, stoi)

In [51]:
class LinearWave:
    """Fully connected layer computing ``x @ weight`` plus an optional bias.

    ``weight`` has shape (channels_in, channels_out). It is drawn from a
    standard normal and divided by ``sqrt(channels_in)`` (fan-in scaling).
    The bias, when enabled, starts at zero. The most recent result is cached
    on ``self.out``.
    """

    def __init__(self, channels_in, channels_out, bias=True):
        fan_in_scale = channels_in ** 0.5
        self.weight = torch.randn((channels_in, channels_out)) / fan_in_scale
        if bias:
            self.bias = torch.zeros(channels_out)
        else:
            self.bias = None

    def __call__(self, x):
        result = x @ self.weight
        if self.bias is not None:
            result += self.bias
        self.out = result
        return self.out

    def parameters(self):
        params = [self.weight]
        if self.bias is not None:
            params.append(self.bias)
        return params

class BatchNorm1d:
    """Batch normalization over the last (feature) dimension.

    In training mode (``training = True``) the input is normalized with the
    statistics of the current batch. These are taken over dim 0 for 2-D input
    and over dims (0, 1) for 3-D input. Exponential moving averages of those
    statistics are kept in ``running_mean`` / ``running_var``. In eval mode
    the running statistics are used instead. ``gamma`` and ``beta`` are the
    learnable scale and shift. The most recent output is cached on
    ``self.out``.

    Note: ``Tensor.var`` defaults to the unbiased estimator, which is used both
    for normalization and for the running variance. A training batch with a
    single element per feature therefore produces a NaN variance.

    Raises:
        ValueError: in training mode, if the input is neither 2-D nor 3-D.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps  # added to the variance before the sqrt to avoid division by zero
        self.momentum = momentum  # weight given to the newest batch in the running averages
        self.training = True
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            if x.ndim == 2:
                reduce_dims = 0
            elif x.ndim == 3:
                reduce_dims = (0, 1)
            else:
                # Any other rank would leave the reduction dims undefined.
                raise ValueError(
                    f"BatchNorm1d expects 2-D or 3-D input in training mode, got {x.ndim}-D"
                )
            xmean = x.mean(reduce_dims, keepdim=True)
            xvar = x.var(reduce_dims, keepdim=True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        if self.training:
            # Running statistics are bookkeeping and must stay out of the autograd graph.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

class Tanh:
    """Element-wise hyperbolic tangent activation with no learnable parameters.

    The most recent output is cached on ``self.out``.
    """

    def __call__(self, x):
        activated = x.tanh()
        self.out = activated
        return activated

    def parameters(self):
        return list()

class Sequential:
    """Chain of layers applied in order, each feeding its output to the next.

    The final output is cached on ``self.out``. ``parameters()`` concatenates
    every layer's parameters in layer order.
    """

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        result = x
        for module in self.layers:
            result = module(result)
        self.out = result
        return result

    def parameters(self):
        params = []
        for module in self.layers:
            params.extend(module.parameters())
        return params

class Embedding:
    """Lookup table mapping integer indices to learnable vectors.

    ``weight`` has shape (num_embeddings, embedding_dim) and is drawn from a
    standard normal. Calling the layer indexes rows of ``weight`` with ``IX``.
    The result is cached on ``self.out``.
    """

    def __init__(self, num_embeddings, embedding_dim):
        table_shape = (num_embeddings, embedding_dim)
        self.weight = torch.randn(table_shape)

    def __call__(self, IX):
        looked_up = self.weight[IX]
        self.out = looked_up
        return looked_up

    def parameters(self):
        return [self.weight]

class FlattenWave:
    """Merge every ``n`` consecutive time steps into one wider feature vector.

    Input of shape (B, T, C) becomes (B, T // n, C * n). When that leaves a
    single time step, the time axis is squeezed away, giving (B, C * n). The
    result is cached on ``self.out``.

    Raises:
        ValueError: if T is not divisible by ``n``.
    """

    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        B, T, C = x.shape
        if T % self.n != 0:
            raise ValueError(
                f"FlattenWave({self.n}) needs the time dimension to be divisible by {self.n}, got T={T}"
            )
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(B, T // self.n, C * self.n)
        if x.shape[1] == 1:
            x = x.squeeze(1)
        self.out = x
        return self.out

    def parameters(self):
        return []


In [52]:
EMBEDDING_DIMENSION = 10
VOCABULARY_SIZE = len(stoi)  # 26 letters + '.' boundary token = 27
HIDDEN_LAYER_NEURONS = 100
BATCH_SIZE = 32
MERGE_FACTOR = 2  # consecutive positions fused by each FlattenWave stage
SEED = 42

# Each merge stage divides the context length by MERGE_FACTOR. The context must
# collapse to exactly one position so the final LinearWave receives 2-D input.
n_merge_stages = 0
remaining = BLOCK_SIZE
while remaining > 1 and remaining % MERGE_FACTOR == 0:
    remaining //= MERGE_FACTOR
    n_merge_stages += 1
assert remaining == 1 and n_merge_stages > 0, (
    f"BLOCK_SIZE={BLOCK_SIZE} must be a positive power of MERGE_FACTOR={MERGE_FACTOR}"
)

# Reproducible initialization (and, when cells run in order, batch sampling).
torch.manual_seed(SEED)

layers = [Embedding(VOCABULARY_SIZE, EMBEDDING_DIMENSION)]
channels_in = EMBEDDING_DIMENSION
for _ in range(n_merge_stages):
    # Linear layers feeding BatchNorm omit the bias: BatchNorm subtracts the
    # batch mean, which would cancel it, and supplies its own shift (beta).
    layers += [
        FlattenWave(MERGE_FACTOR),
        LinearWave(MERGE_FACTOR * channels_in, HIDDEN_LAYER_NEURONS, bias=False),
        BatchNorm1d(HIDDEN_LAYER_NEURONS),
        Tanh(),
    ]
    channels_in = HIDDEN_LAYER_NEURONS
layers.append(LinearWave(HIDDEN_LAYER_NEURONS, VOCABULARY_SIZE))
model = Sequential(layers)

# Make last layer less confident: shrinking its initial weights keeps the
# first logits near zero, i.e. a roughly uniform initial prediction.
with torch.no_grad():
    model.layers[-1].weight *= 0.1

parameters = model.parameters()

In [56]:
# Training loop: mini-batch SGD with a step learning-rate decay.

losses = []

STEP_SIZE = 200000  # total number of optimization steps
LR_DECAY_STEP = STEP_SIZE // 2  # step at which the learning rate is decayed
INITIAL_LR = 0.1
DECAYED_LR = 0.01

# BatchNorm must use batch statistics while training. A previously run
# sampling cell may have switched the layers to eval mode.
for layer in model.layers:
    layer.training = True

# Enable gradients once; the flag persists across steps.
for p in parameters:
    p.requires_grad = True

for step in range(STEP_SIZE):
    # Sample a random mini-batch and run the forward pass
    batch_idx = torch.randint(0, inputs.shape[0], (BATCH_SIZE,))
    x = inputs[batch_idx]
    logits = model(x)
    loss = F.cross_entropy(logits, labels[batch_idx])

    # Backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # Update: the decay schedule is keyed on the loop counter `step`.
    lr = INITIAL_LR if step < LR_DECAY_STEP else DECAYED_LR
    for p in parameters:
        p.data += -lr * p.grad

    if step % 10000 == 0:  # print every once in a while
        print(f'{step:7d}/{STEP_SIZE:7d}: {loss.item():.4f}')
    losses.append(loss.log10().item())


print(loss.item())

      0/ 200000: 2.3800
  10000/ 200000: 2.2477
  20000/ 200000: 2.2783
  30000/ 200000: 2.3916
  40000/ 200000: 1.8333
  50000/ 200000: 1.9084
  60000/ 200000: 2.1880
  70000/ 200000: 2.0833
  80000/ 200000: 1.8886
  90000/ 200000: 1.7694
 100000/ 200000: 1.9391
 110000/ 200000: 2.0321
 120000/ 200000: 1.6966
 130000/ 200000: 1.9675
 140000/ 200000: 1.9078
 150000/ 200000: 1.5757
 160000/ 200000: 2.0076
 170000/ 200000: 2.0781
 180000/ 200000: 1.5641
 190000/ 200000: 1.9573
1.8212995529174805


In [None]:
# Sampling
# Switch every layer to inference mode before generating. Only BatchNorm1d
# reads `training`: it then normalizes with running_mean / running_var instead
# of batch statistics. On the other layers this just sets an unused attribute.
for layer in model.layers:
    layer.training = False