## WaveNet

*WaveNet paper:*
- https://arxiv.org/abs/1609.03499

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline 

In [2]:
# read in all the names:
words = open("names.txt", "r").read().splitlines()
len(words)

32033

In [3]:
chars = sorted(set(''.join(words)))
stoi = { s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0

itos = { i:s for s,i in stoi.items()}
vocab_size = len(itos)

In [4]:
block_size = 8 # context length

def build_dataset(words):
    X, Y = [], [] 
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X,Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 8]) torch.Size([182625])
torch.Size([22655, 8]) torch.Size([22655])
torch.Size([22866, 8]) torch.Size([22866])


In [5]:
for x,y in zip(Xtr[:20], Ytr[:20]):
    print(''.join(itos[ix.item()] for ix in x), '-->', itos[y.item()])

........ --> y
.......y --> u
......yu --> h
.....yuh --> e
....yuhe --> n
...yuhen --> g
..yuheng --> .
........ --> d
.......d --> i
......di --> o
.....dio --> n
....dion --> d
...diond --> r
..diondr --> e
.diondre --> .
........ --> x
.......x --> a
......xa --> v
.....xav --> i
....xavi --> e


In [6]:
# https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
class Linear:
    
    def __init__(self, fan_in, fan_out, bias=True):
        # Randomized weights from a gaussian and then the initialization optimization
        self.weight = torch.randn((fan_in, fan_out), generator=g) / fan_in**0.5 # 'fan_in**0.5' - will conserved the std of 1 / this fan_in**0.5 is not neccessery if we use batchnorm
        self.bias = torch.zeros(fan_out) if bias else None
        
    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out
    
    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])
    
# --------------------------------------------------------------------------------------------------------------------------

# https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html#batchnorm1d
class BatchNorm1d:
    
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Parameters which trained with backpropagation:
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # Buffers which trained with running 'momentum' update
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)
        
    def __call__(self, x):
        # Calculate the forward pass based on the paper: https://arxiv.org/abs/1502.03167
        if self.training: # if we would evaluate we dont want to train the network
            # instead of x.mean(0, ...), which only average elements over the first dim, we need to handle when the input tensor is not 2D, rather higher dim
            if x.ndim == 2:
                dim = 0
            elif x.ndim == 3:
                dim = (0,1)
            xmean = x.mean(dim, keepdim=True) 
            xvar = x.var(dim, keepdim=True) # batch variance
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        # update the buffers
        if self.training:
            # Important to use ._no_grad() here to not build the computational graph for the 'running_*' parameters 
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out
    
    def parameters(self):
        return [self.gamma, self.beta]
    
# --------------------------------------------------------------------------------------------------------------------------

class Embedding:
    
    def __init__(self, num_embeddings, embedding_dim):
        self.weights = torch.randn((num_embeddings, embedding_dim))
    
    def __call__(self, X):
        self.out = self.weights[X]
        return self.out
        
    def parameters(self):
        return [self.weights]
    
# --------------------------------------------------------------------------------------------------------------------------
        
class Flatten:
    
    def __call__(self, emb):
        self.out = emb.view(emb.shape[0], -1)
        return self.out
    
    def parameters(self):
        return []
    
# ------------------------------------------------------------------------------
# This layer basically sum n elements and puts them into the last dimension
class FlattenConsecutive:
    
    # Take the number of elements that are consecutive, that we would like to concatenate now in the last dimension of the output
    def __init__(self, n):
        self.n = n
    
    def __call__(self, x):
        B, T, C = x.shape # [4, 8, 10] 
        x = x.view(B, T//self.n, C*self.n)
        if x.shape[1] == 1:
            x = x.squeeze(1)
        self.out = x
        return self.out
    
    def parameters(self):
        return []
    
# ------------------------------------------------------------------------------a--------------------------------------------
    
class Sequential:
    
    def __init__(self, layers):
        self.layers = layers
        
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        self.out = x
        return self.out
    
    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]
    
# --------------------------------------------------------------------------------------------------------------------------

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []

![image.png](attachment:image.png)

In [7]:
n_embd = 24 # dimensionality of the character embedding vectors
n_hidden = 128 # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647)  # for reproducibility

# C = torch.randn((vocab_size, n_embd), generator=g) # Embedding layer replaced it

# Without the activation function, the network would just collapse into one Linear layer, so there is not possible to train deep networks.
# But papers proved that deeper layers perform better.
# model = Sequential([
#    Embedding(vocab_size, n_embd), 
#    Flatten(),
#    Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#    Linear(n_hidden, vocab_size)
# ])

model = Sequential([
    Embedding(vocab_size, n_embd), 
    FlattenConsecutive(2), Linear(n_embd * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, vocab_size),
])


# parameter init
with torch.no_grad():
    model.layers[-1].weight *= 0.1 # last layer make less confident

parameters = model.parameters()
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True

76579


In [8]:
# Let's look at a batch of just 4 examples:
ix = torch.randint(0, Xtr.shape[0], (4,))
Xb, Yb = Xtr[ix], Ytr[ix]
logits = model(Xb)
print(Xb.shape) # (sample_size, block_size)
print(Xb)

# Output of Embedding layer
print("Embedding: ", model.layers[0].out.shape)

# Output of Flatten layer
print("Flatten: ", model.layers[1].out.shape)

# Output of Linear layer
print("Linear: ", model.layers[2].out.shape)

torch.Size([4, 8])
tensor([[ 0,  0,  0,  0,  0,  0,  0,  4],
        [ 0,  0, 11,  1, 12,  9,  5,  6],
        [ 0,  0,  0,  0,  0,  0,  0,  5],
        [ 0,  0,  0,  0,  3,  1, 13,  9]])
Embedding:  torch.Size([4, 8, 24])
Flatten:  torch.Size([4, 4, 48])
Linear:  torch.Size([4, 4, 128])


In [9]:
# Let's inspect the model layers:
for layer in model.layers:
    print(layer.__class__.__name__, ':', tuple(layer.out.shape))

Embedding : (4, 8, 24)
FlattenConsecutive : (4, 4, 48)
Linear : (4, 4, 128)
BatchNorm1d : (4, 4, 128)
Tanh : (4, 4, 128)
FlattenConsecutive : (4, 2, 256)
Linear : (4, 2, 128)
BatchNorm1d : (4, 2, 128)
Tanh : (4, 2, 128)
FlattenConsecutive : (4, 256)
Linear : (4, 128)
BatchNorm1d : (4, 128)
Tanh : (4, 128)
Linear : (4, 27)


#### ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Pytorch matrix multiplication good to know:

In [10]:
# x @ W + b
(torch.randn(4, 80) @ torch.randn(80, 200) + torch.randn(200)).shape

torch.Size([4, 200])

#### For matrix multiplication in Pytorch, the dimensions does not need to match, it only works on the last dimension, the others left unchanged. For example, below the first tensor's first three dimension remained untouched, only the last dimension had to match the other tensor's.

In [11]:
(torch.randn(4, 5, 2, 80) @ torch.randn(80, 200) + torch.randn(200)).shape

torch.Size([4, 5, 2, 200])

#### So, for the WaveNet implementation, we would like to group consecutive elements into doubles, not just flatten the whole into a 2D vector. We would like to multiplied these grouped elements by the weights, but the process should be parallel.
- (1 2) (3 4) (5 6) (7 8)

#### We can batch then to achive these groups like this:

In [12]:
(torch.randn(4, 4, 20) @ torch.randn(20, 200) + torch.randn(200)).shape

torch.Size([4, 4, 200])

#### To get this flattener which creates a (4,4,20), instead of a (4, 80) from (4, 8, 10), we should generalize:

In [13]:
t1 = torch.randn([4, 8, 10])
# Fetch each even number indexed elements from second dimension
t1_a = t1[:, ::2, :]
# Fetch each odd number indexed elements from second dimension
t1_b = t1[:, 1::2, :]
# dim arg means: along which dimension concetanates the tensors,
# in this case dim=2, so (4, 4, 10+10)
t2 = torch.cat((t1_a, t1_b), dim=2)
print(t2.shape)

torch.Size([4, 4, 20])


#### '.view(4,4,20)' will do the same actually:
(It will order first 'even' then 'odd', then concetanate)

In [14]:
(t1.view(4,4,20) == t2).all()

tensor(True)

#### ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [17]:
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
    
    # forward pass
    logits = model(Xb)
    print("Logits: ", logits.shape, Yb.shape)
    loss = F.cross_entropy(logits, Yb) # loss function
    
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    
    # update: simple SGD
    lr = 0.1 if i < 150000 else 0.01 # step learning rate decay
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    if i % 10000 == 0: # print every once in a while
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())
    
    break

Logits:  torch.Size([32, 27]) torch.Size([32])
      0/ 200000: 3.2936


In [None]:
# plot loss over learning
plt.plot(torch.tensor(lossi).view(-1, 1000).mean(1))

In [None]:
# put layers into eval mode (needed for batchnorm especially)
for layer in model.layers:
    layer.training = False

In [None]:
# evaluate the loss
@torch.no_grad() # this decorator disables gradient tracking inside pytorch
def split_loss(split):
    x,y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss('train')
split_loss('val')

In [None]:
# sample from the model
for _ in range(20):
    out = []
    context = [0] * block_size # initialize with all ...
    while True:
        # forward pass the neural net
        logits = model(torch.tensor([context]))
        probs = F.softmax(logits, dim=1)
        # sample from the distribution
        ix = torch.multinomial(probs, num_samples=1).item()
        # shift the context window and track the samples
        context = context[1:] + [ix]
        out.append(ix)
        # if we sample the special '.' token, break
        if ix == 0:
            break
    
    print(''.join(itos[i] for i in out)) # decode and print the generated word