# Week 4 – WaveNet-style Hierarchical MLP

This notebook builds a hierarchical (tree-structured) WaveNet-style MLP for character-level name modeling.

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


## Load and prepare data

In [None]:

# Load the names, one per line. The context manager closes the file
# promptly instead of leaving the open handle to the garbage collector.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

# Build the character vocabulary. Characters found in the data get indices
# 1..N in sorted order; index 0 is reserved for the '.' token.
chars = sorted(set(''.join(words)))
stoi = {s: i for i, s in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
vocab_size = len(stoi)

# Number of context tokens per example.
block_size = 8


In [None]:

def build_dataset(words):
    """Turn a list of words into (context, next-token) training pairs.

    Every word is scanned left to right with a '.' appended. For each
    character, the current window of ``block_size`` token ids (starting
    as all zeros) becomes one input row and the character's id becomes
    the matching target; the window then slides forward by one token.

    Returns:
        A tuple ``(X, Y)`` of tensors: ``X`` holds one ``block_size``-long
        row per example, ``Y`` holds the corresponding target ids.
    """
    inputs, targets = [], []
    for word in words:
        window = [0] * block_size
        for token in (stoi[ch] for ch in word + '.'):
            inputs.append(window)
            targets.append(token)
            # Drop the oldest token and append the newest one.
            window = window[1:] + [token]
    return torch.tensor(inputs), torch.tensor(targets)

import random

# Shuffle with a fixed seed so the split is the same on every run.
random.seed(42)
random.shuffle(words)

# Split boundaries at 80% and 90% of the (shuffled) word list.
n_words = len(words)
n1 = int(0.8 * n_words)
n2 = int(0.9 * n_words)

# First 80% -> train, next 10% -> validation, remainder -> test.
train_words, val_words, test_words = words[:n1], words[n1:n2], words[n2:]
Xtr, Ytr = build_dataset(train_words)
Xva, Yva = build_dataset(val_words)
Xte, Yte = build_dataset(test_words)


## WaveNet-style building blocks

In [None]:

class WaveBlock(nn.Module):
    """Linear layer followed by tanh, applied to every position of a 3-D input.

    Args:
        n_in: Size of the input's last dimension.
        n_out: Size of the output's last dimension.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a ``(B, T, n_in)`` tensor to ``(B, T, n_out)``.

        The first two dimensions are flattened so the network sees a 2-D
        batch of rows, then the original ``(B, T)`` layout is restored.
        """
        B, T, C = x.shape
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed tensors; for contiguous inputs it is still a view.
        flat = x.reshape(B * T, C)
        out = self.net(flat)
        return out.view(B, T, -1)


## Hierarchical WaveNet MLP

In [None]:

class WaveNetMLP(nn.Module):
    """Three-level hierarchical MLP over a context of 8 token ids.

    Token embeddings are regrouped into 4, then 2, then 1 position(s);
    each regrouping doubles the feature width to ``2 * n_embd`` and a
    ``WaveBlock`` brings it back to ``n_embd``. A final linear layer maps
    the single remaining position to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and logits).
        n_embd: Embedding width, also used as the width of every block output.
    """

    def __init__(self, vocab_size, n_embd):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, n_embd)

        # Each block consumes two concatenated n_embd-wide vectors.
        pair_width = 2 * n_embd
        self.l1 = WaveBlock(pair_width, n_embd)
        self.l2 = WaveBlock(pair_width, n_embd)
        self.l3 = WaveBlock(pair_width, n_embd)

        self.out = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Return ``(B, vocab_size)`` logits for a ``(B, 8)`` batch of token ids."""
        h = self.embed(idx)  # (B, 8, C)
        batch = h.shape[0]

        # Regroup into 4 -> 2 -> 1 positions, applying one block per level.
        for groups, block in ((4, self.l1), (2, self.l2), (1, self.l3)):
            h = block(h.view(batch, groups, -1))

        # Drop the singleton position dimension before the output layer.
        return self.out(h.squeeze(1))


## Training

In [None]:

# Model and optimizer.
model = WaveNetMLP(vocab_size, n_embd=24)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lossi = []  # minibatch loss recorded at every step
steps = 20000
batch_size = 32

for i in range(steps):
    # Draw random row indices (with replacement) for this minibatch.
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    xb = Xtr[ix]
    yb = Ytr[ix]

    # Forward pass.
    logits = model(xb)
    loss = F.cross_entropy(logits, yb)

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record the scalar loss; print it every 2000 steps.
    step_loss = loss.item()
    lossi.append(step_loss)
    if i % 2000 == 0:
        print(i, step_loss)


## Loss curve

In [None]:

# Plot the per-step training loss on the current axes, with y limited to [0, 3].
ax = plt.gca()
ax.plot(lossi)
ax.set_ylim(0, 3)
plt.show()


## Validation loss

In [None]:

@torch.no_grad()
def split_loss(X, Y):
    """Return the cross-entropy of ``model`` on ``(X, Y)`` as a Python float.

    The whole split goes through the model in a single forward pass;
    gradient tracking is disabled by the decorator.
    """
    return F.cross_entropy(model(X), Y).item()

# Report the loss on the training and validation splits.
for label, (X_split, Y_split) in (("Train:", (Xtr, Ytr)), ("Val  :", (Xva, Yva))):
    print(label, split_loss(X_split, Y_split))


## Sampling

In [None]:

@torch.no_grad()
def sample(n=20):
    """Generate ``n`` strings from ``model`` and print each on its own line.

    Each string starts from a context of ``block_size`` zeros. Tokens are
    drawn one at a time from the softmax of the model's logits and pushed
    into the sliding context; drawing token 0 ends the string (token 0
    itself is not printed).
    """
    for _ in range(n):
        window = [0] * block_size
        letters = []
        while True:
            logits = model(torch.tensor([window]))
            probs = F.softmax(logits, dim=1)
            nxt = torch.multinomial(probs, 1).item()
            # Slide the context forward before checking for the end token.
            window = window[1:] + [nxt]
            if nxt == 0:
                break
            letters.append(itos[nxt])
        print(''.join(letters))

# Generate and print 20 samples from the model.
sample(20)
