In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.utils.data import DataLoader

import torch.optim as optim

from tinyshakespeare import Vocab, TinyShakespeareDataset, make_train_val_dataloader

In [2]:
from wavenet import WaveNet

In [3]:
@torch.no_grad()
def estimate_loss(
    model,
    dl,
    device: torch.device,
    num_iters: int = 10
) -> dict[str, float]:
    """Estimate the mean cross-entropy loss on the "train" and "val" splits.

    For each split, averages the loss over up to ``num_iters`` batches drawn
    from a fresh pass over ``dl[split]``. If a loader holds fewer than
    ``num_iters`` batches, the mean is taken over the batches it did yield
    (instead of failing with ``StopIteration``).

    The model is moved to ``device`` and switched to eval mode for the
    measurement; its previous train/eval mode is restored afterwards, even
    if an exception is raised.

    Args:
        model: module mapping a batch of token ids ``x`` to logits; assumed
            to be (batch, seq, vocab) given the transpose below — TODO confirm.
        dl: mapping with "train" and "val" keys, each an iterable of
            ``(x, y)`` batches.
        device: device to run the evaluation on.
        num_iters: maximum number of batches to average per split.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}`` as Python floats.

    Raises:
        ValueError: if no batch was evaluated for a split (empty loader or
            ``num_iters <= 0``).
    """
    losses = {}
    was_training = model.training

    model.to(device)
    model.eval()
    try:
        for split in ("train", "val"):
            total, seen = 0.0, 0
            # zip pulls from range() first, so it stops after num_iters batches
            # without ever requesting one more, and ends quietly if the loader
            # runs out earlier.
            for _, (x, y) in zip(range(num_iters), dl[split]):
                x = x.to(device)
                y = y.to(device)

                logits = model(x)
                # cross_entropy expects the class dimension second
                total += F.cross_entropy(logits.transpose(-2, -1), y).item()
                seen += 1

            if seen == 0:
                raise ValueError(
                    f"no batches evaluated for split {split!r} (num_iters={num_iters})"
                )
            losses[split] = total / seen
    finally:
        # don't leak eval mode into whatever the caller does next
        model.train(was_training)

    return losses
        

In [4]:
def train(
    model: nn.Module,
    dl: dict[str, DataLoader],
    optimizer: optim.Optimizer,
    steps: int = 100,
    eval_interval: int = 10,
    device: str | None = None,
):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    
    model.to(device)
    
    train_iter = iter(dl["train"])
    
    for steps in range(1, steps + 1):
        model.train()
        
        x, y = next(train_iter)
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        loss = F.cross_entropy(logits.transpose(-2, -1), y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if steps % eval_interval == 0:
            losses = estimate_loss(model, dl, device)
            print(f"train loss: {losses['train']:.4f} ; val loss: {losses['val']:.4f}")
    

In [5]:
# hyperparameters
seed = 38650
seq_len = 128
batch_size = 32
train_steps = 10000
learning_rate = 3e-4
# evaluate ~100 times over the run; clamp to >= 1 so short runs
# (train_steps < 100) don't produce an interval of 0
eval_interval = max(1, train_steps // 100)

# Seed *before* building the dataset/dataloaders so any randomness they may
# use at construction time is reproducible too, not just the model init.
torch.manual_seed(seed)

ds = TinyShakespeareDataset("./data/tinyshakespeare.txt", seq_len=seq_len)
dl = make_train_val_dataloader(ds, batch_size=batch_size, shuffle=True, num_workers=1)

model = WaveNet(
    ds.vocab.size,
    depth=8,
    conv_dim=64,
    residual_dim=64,
    skip_dim=32,
    head_dim=32
)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train(
    model,
    dl,
    optimizer,
    steps=train_steps,
    eval_interval=eval_interval
)

train loss: 3.0187 ; val loss: 3.0301
train loss: 2.5260 ; val loss: 2.5498
train loss: 2.3343 ; val loss: 2.3185
train loss: 2.2204 ; val loss: 2.2174
train loss: 2.1614 ; val loss: 2.1770
train loss: 2.1007 ; val loss: 2.1328
train loss: 2.0545 ; val loss: 2.1148
train loss: 2.0402 ; val loss: 2.0824
train loss: 2.0251 ; val loss: 2.0443
train loss: 1.9771 ; val loss: 2.0410
train loss: 1.9583 ; val loss: 2.0339
train loss: 1.9390 ; val loss: 2.0264
train loss: 1.9230 ; val loss: 2.0210
train loss: 1.9011 ; val loss: 1.9830
train loss: 1.8844 ; val loss: 1.9946
train loss: 1.8672 ; val loss: 1.9905
train loss: 1.8653 ; val loss: 1.9754
train loss: 1.8471 ; val loss: 1.9402
train loss: 1.8476 ; val loss: 1.9561
train loss: 1.8406 ; val loss: 1.9376
train loss: 1.8282 ; val loss: 1.9386
train loss: 1.8150 ; val loss: 1.9530
train loss: 1.8121 ; val loss: 1.9266
train loss: 1.8052 ; val loss: 1.9377
train loss: 1.7679 ; val loss: 1.9098
train loss: 1.7721 ; val loss: 1.9269
train loss: 

In [6]:
@torch.no_grad()
def generate(
    model: nn.Module,
    vocab: Vocab,
    prompt: str,
    context_len: int = 8,
    num_toks: int = 32,
    device: str | None = None
) -> str:
    """Autoregressively sample ``num_toks`` tokens from ``model``, continuing ``prompt``.

    Runs under ``torch.no_grad()``, so no autograd graph is accumulated across
    the sampling loop. The model is put in eval mode, so layers such as
    dropout (if the model has any) are disabled. Its previous train/eval mode
    is restored before returning.

    Args:
        model: module mapping a (1, t) tensor of token ids to logits; assumed
            to be (batch, seq, vocab) given the ``[:, -1, :]`` indexing below.
        vocab: provides ``encode(list_of_chars)`` and ``decode(list_of_ids)``;
            presumably character-level given ``list(prompt)``. Verify against
            Vocab.
        prompt: non-empty seed text. Its tokens form the start of the decoded
            output.
        context_len: maximum number of most-recent tokens fed to the model at
            each step (must be >= 1).
        num_toks: number of new tokens to sample.
        device: device name; defaults to "cuda" when available, else "cpu".

    Returns:
        The decoded prompt tokens followed by the sampled continuation.

    Raises:
        ValueError: if ``prompt`` is empty or ``context_len`` < 1.
    """
    if not prompt:
        raise ValueError("prompt must be a non-empty string")
    if context_len < 1:
        # ctx[:, -0:] would silently select the *entire* history
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model.to(device)
    was_training = model.training
    model.eval()
    try:
        # batch size 1
        ctx = torch.tensor([vocab.encode(list(prompt))], dtype=torch.long, device=device)

        for _ in range(num_toks):
            # crop input sequence to the last `context_len` tokens
            logits = model(ctx[:, -context_len:])
            # only the prediction after the final position is needed
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)

            ctx = torch.cat((ctx, next_tok), dim=1)
    finally:
        model.train(was_training)

    return "".join(vocab.decode(ctx[0].tolist()))

In [7]:
# Sample 512 new characters seeded with a single space, letting the model see
# up to `seq_len` characters of history per step (the length it trained on).
gen = generate(
    model=model,
    vocab=ds.vocab,
    prompt=" ",
    context_len=seq_len,
    num_toks=512,
)
print(gen)

 have so'erant poor-known in me villor
services you men I'ring you telling stack upon bles, ond against will,
He lessings,
When thighbusings and him!

Thirse relear,
With your no trumned
Witken flalive in, from will he gain morch your life,
A crate, or Vould.

BENVOLIO:
His thou much of worrust Cant?
Timpen i' son; it but hope therefore boince with heart shall you to be brucious,
My upon you.
Pavoam.

AUTOLYCUS:

KING EDWARD IV:
Within eye!

DUKE VINCENTIO:
Feam I sree
Lends no and it mine you bollow
Glood m
