In [None]:
%pip install --quiet --upgrade flax soundfile
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.994
%env TF_CPP_MIN_LOG_LEVEL=3
!wget -q https://www.gutenberg.org/files/26268/mp3/26268-01.mp3 -O romeo_and_juliet.mp3

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import soundfile
import optax
import tqdm

from dataclasses import dataclass
from pprint import pprint

from flax import nnx
from jax.ad_checkpoint import checkpoint_name as ckpt
from IPython.display import Audio


# Dtype used throughout this file for parameters (``param_dtype=``) and for
# casting input batches. Change it here to switch precision everywhere.
default_dtype = (
    jnp.bfloat16
)


def _down(x: jax.Array) -> jax.Array:
    """Halve axis 1 and double axis 2 by a pure reshape.

    Adjacent pairs of positions along axis 1 are folded into axis 2; no
    values are altered. Axis 0 is inferred from the remaining size.
    """
    half_length = x.shape[1] // 2
    doubled_channels = x.shape[2] * 2
    return x.reshape(-1, half_length, doubled_channels)


def _up(x: jax.Array) -> jax.Array:
    """Double axis 1 and halve axis 2 by a pure reshape.

    Each position's axis-2 vector is split into two consecutive positions
    along axis 1; no values are altered. Axis 0 is inferred from the
    remaining size.
    """
    doubled_length = x.shape[1] * 2
    half_channels = x.shape[2] // 2
    return x.reshape(-1, doubled_length, half_channels)


def gumbel_softmax(
    key: jax.Array,
    logits: jax.Array,
    *,
    hard: bool,
    temperature: float = 1.0,
    axis: int = -1,
) -> jax.Array:
    """Draw a Gumbel-softmax sample from ``logits``.

    Args:
        key: PRNG key used to draw the Gumbel noise.
        logits: Unnormalised log-probabilities.
        hard: If true, return a one-hot sample whose gradient is that of the
            soft sample (straight-through); otherwise return the soft sample.
        temperature: Divisor applied to the noisy logits before the softmax.
        axis: Axis along which the categories lie.

    Returns:
        An array shaped like ``logits``.
    """
    noisy_logits = logits + jax.random.gumbel(key, logits.shape)
    soft = nnx.softmax(noisy_logits / temperature, axis=axis)
    if hard:
        winners = soft.argmax(axis)
        one_hot = nnx.one_hot(winners, soft.shape[axis], axis=axis, dtype=soft.dtype)
        # Forward value equals ``one_hot``; the stop_gradient term contributes
        # no gradient, so the backward pass sees only ``soft``.
        return soft + jax.lax.stop_gradient(one_hot - soft)
    return soft


def reconstruction_loss_fn(preds: jax.Array, batch: jax.Array) -> jax.Array:
    """Mean squared error between ``preds`` and the centre crop of ``batch``.

    ``preds`` may be shorter than ``batch`` along axis 1; ``batch`` is then
    cropped around its centre to the length of ``preds`` before comparison.

    Args:
        preds: Predictions, with axis 1 no longer than that of ``batch``.
        batch: Targets.

    Returns:
        Scalar mean squared error.
    """
    length = preds.shape[1]
    offset = (batch.shape[1] - length) // 2
    # Use an explicit end index: ``batch[:, p:-p]`` collapses to an empty
    # slice when ``p == 0`` (i.e. when preds and batch have equal length).
    target = batch[:, offset:offset + length]
    return jnp.mean((preds - target) ** 2)


def diversity_loss_fn(logits: jax.Array, codebook_size: int) -> jax.Array:
    """Penalise low perplexity of the average code distribution.

    The softmax of ``logits`` is averaged over axes 0 and 1, and the
    perplexity (exponential of the entropy) of that average is compared with
    ``codebook_size``. The loss is 0 when perplexity equals ``codebook_size``
    and grows towards 1 as perplexity shrinks.

    Args:
        logits: Code logits with the code dimension last.
        codebook_size: Number of codes, used to normalise the loss.

    Returns:
        Scalar loss ``(codebook_size - perplexity) / codebook_size``.
    """
    mean_probs = nnx.softmax(logits).mean((0, 1))
    # The small epsilon keeps log() finite for codes with zero probability.
    log_mean_probs = jnp.log(mean_probs + 1e-7)
    entropy = -jnp.sum(mean_probs * log_mean_probs)
    perplexity = jnp.exp(entropy)
    return (codebook_size - perplexity) / codebook_size


def norm(tree) -> jax.Array:
    """Return the global L2 norm of all leaves of a pytree.

    Each leaf's norm is computed separately, then the norm of the resulting
    vector of per-leaf norms is returned.
    """
    leaf_norms = [jnp.linalg.norm(leaf) for leaf in jax.tree.leaves(tree)]
    return jnp.linalg.norm(jnp.stack(leaf_norms))


class ResBlock(nnx.Module):
    """Gated residual 1-D convolution block with a learnable scalar gain.

    Computes ``x + alpha * glu(conv(x))``. The convolution produces ``2 * dim``
    channels, which the GLU gates back down to ``dim``. ``alpha`` is
    initialised to zero, so a freshly constructed block returns its input
    unchanged.
    """

    def __init__(self, dim: int, kernel_size: int, *, rngs: nnx.Rngs):
        """Create the convolution and the zero-initialised residual gain.

        Args:
            dim: Number of input (and output) channels.
            kernel_size: Width of the 1-D convolution kernel.
            rngs: Random streams used to initialise the convolution.
        """
        self.conv = nnx.Conv(
            in_features=dim,
            out_features=2 * dim,
            kernel_size=(kernel_size,),
            param_dtype=default_dtype,
            rngs=rngs,
        )
        self.alpha = nnx.Param(jnp.zeros((), dtype=default_dtype))

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the block to ``x`` and return an array of the same shape."""
        # Tag the convolution output with the name "conv" so a remat policy
        # can choose to keep it.
        conv_out = ckpt(self.conv(x), "conv")
        gated = nnx.glu(conv_out)
        return x + self.alpha * gated


class Backbone(nnx.Module):
    """A stack of ``depth`` identical ``ResBlock`` layers applied sequentially.

    The blocks are created with ``nnx.vmap`` (so their parameters carry a
    leading axis of size ``depth``) and applied with ``nnx.scan``. After the
    stack, ``depth * (kernel_size // 2)`` positions are cropped from each end
    of axis 1, so the output is shorter than the input by twice that amount.
    """

    def __init__(self, depth: int, dim: int, kernel_size: int, *, rngs: nnx.Rngs):
        """Build the stacked blocks.

        Args:
            depth: Number of residual blocks.
            dim: Channel count of every block.
            kernel_size: Convolution kernel width of every block.
            rngs: Random streams; split ``depth`` ways, one per block.
        """

        @nnx.split_rngs(splits=depth)
        @nnx.vmap
        def create_block(rngs: nnx.Rngs):
            return ResBlock(dim, kernel_size, rngs=rngs)

        # Number of positions cropped from EACH end of the sequence in __call__.
        self.p = depth * (kernel_size // 2)
        self.blocks = create_block(rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run all blocks over ``x`` and crop ``self.p`` positions per side."""

        # Scan over the stacked blocks with ``x`` as the carry. Rematerialise
        # each block in the backward pass, saving only values named "conv".
        @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
        @nnx.remat(prevent_cse=False, policy=jax.checkpoint_policies.save_only_these_names("conv"))
        def forward(block, x):
            return block(x)

        x = forward(self.blocks, x)
        # Explicit end index: ``x[:, p:-p]`` would be an empty slice when
        # ``p == 0`` (e.g. kernel_size == 1), dropping the whole sequence.
        return x[:, self.p:x.shape[1] - self.p]


class VQVAE(nnx.Module):
    """Multi-scale convolutional VQ-VAE over single-channel sequences.

    Encoder: a linear lift from 1 to ``dim`` channels, then ``octaves`` stages
    of ``Backbone`` + ``_down`` (length halved, channels doubled), a final
    ``Backbone``, and a projection to ``codebook_size`` logits.

    Quantizer: a hard (straight-through) Gumbel-softmax sample of the logits,
    multiplied with the codebook to select code vectors.

    Decoder: mirrors the encoder with ``_up`` + ``Backbone`` stages and a
    final linear map back to 1 channel. Because every ``Backbone`` crops its
    output, the reconstruction is shorter than the input.
    """

    def __init__(
        self,
        *,
        octaves: int,
        depth: int,
        dim: int,
        kernel_size: int,
        codebook_size: int,
        rngs: nnx.Rngs,
    ):
        """Create all sub-modules.

        Args:
            octaves: Number of down-/up-sampling stages.
            depth: Number of residual blocks per ``Backbone``.
            dim: Channel count at the finest resolution; doubles per octave.
            kernel_size: Convolution kernel width in every block.
            codebook_size: Number of code vectors.
            rngs: Random streams for initialisation; also kept on the module
                to draw Gumbel noise in ``quantizer``.
        """
        self.octaves = octaves
        self.depth = depth
        self.dim = dim
        self.kernel_size = kernel_size
        self.codebook_size = codebook_size
        self.rngs = rngs

        self.linear_in = nnx.Linear(1, dim, param_dtype=default_dtype, rngs=rngs)
        self.encoder_blocks = [Backbone(depth, dim * 2 ** o, kernel_size, rngs=rngs) for o in range(octaves + 1)]
        self.proj = nnx.Linear(dim * 2 ** octaves, codebook_size, param_dtype=default_dtype, rngs=rngs)
        self.codebook = nnx.Param(jax.random.normal(rngs.params(), (codebook_size, dim * 2 ** octaves), dtype=default_dtype))
        self.decoder_blocks = [Backbone(depth, dim * 2 ** o, kernel_size, rngs=rngs) for o in range(octaves, -1, -1)]
        self.linear_out = nnx.Linear(dim, 1, param_dtype=default_dtype, rngs=rngs)

    def encoder(self, x: jax.Array) -> jax.Array:
        """Map a single-channel input to code logits at the bottleneck."""
        x = self.linear_in(x)
        # Every stage except the last is followed by a 2x downsample.
        for backbone in self.encoder_blocks[:-1]:
            x = _down(backbone(x))
        x = self.encoder_blocks[-1](x)
        logits = self.proj(x)
        return logits

    def quantizer(self, logits: jax.Array, temperature: float = 1.0) -> jax.Array:
        """Sample one code per position and return the chosen code vectors.

        Args:
            logits: Code logits from ``encoder``.
            temperature: Gumbel-softmax temperature.
        """
        key = self.rngs.gumbel_softmax()
        onehots = gumbel_softmax(key, logits, temperature=temperature, hard=True)
        # A one-hot row times the codebook selects a single code vector.
        codes = onehots @ self.codebook
        return codes

    def decoder(self, x: jax.Array) -> jax.Array:
        """Map bottleneck code vectors back to a single-channel output."""
        x = self.decoder_blocks[0](x)
        # Every remaining stage is preceded by a 2x upsample.
        for backbone in self.decoder_blocks[1:]:
            x = backbone(_up(x))
        x = self.linear_out(x)
        return x

    def __call__(self, x: jax.Array, temperature: float = 1.0) -> tuple[jax.Array, jax.Array]:
        """Encode, quantize and decode ``x``.

        Returns:
            ``(reconstruction, logits)`` where ``logits`` are the encoder's
            pre-quantization code logits.
        """
        logits = self.encoder(x)
        x = self.quantizer(logits, temperature)
        x = self.decoder(x)
        return x, logits

    def valid_length(self, bottleneck: int) -> int:
        """Return the input length that yields ``bottleneck`` encoder positions.

        Walks the encoder backwards: each ``Backbone`` removes
        ``2 * depth * (kernel_size // 2)`` positions and each downsample
        halves the length.
        """
        # Must match the crop in Backbone (``p`` per side). For odd kernels
        # this equals ``(kernel_size - 1) * depth``; for even kernels that
        # expression would under-count and give lengths the encoder cannot
        # halve cleanly.
        crop = 2 * (self.kernel_size // 2) * self.depth
        length = bottleneck + crop
        for _ in range(self.octaves):
            length = 2 * length + crop
        return length


@nnx.jit
def train_step(
    model: VQVAE,
    optimizer: nnx.Optimizer,
    batch: jax.Array,
    temperature: float = 1.0,
) -> tuple[jax.Array, dict[str, jax.Array]]:
    """Run one optimisation step and return the loss and logging values.

    Args:
        model: The VQ-VAE being trained.
        optimizer: Optimizer wrapping ``model``; updated in place.
        batch: Input batch, also used as the reconstruction target.
        temperature: Gumbel-softmax temperature for the quantizer.

    Returns:
        ``(loss, log_dict)`` where ``log_dict`` holds ``reconstruction_loss``,
        ``diversity_loss`` and ``grad_norm``.
    """

    @nnx.value_and_grad(has_aux=True)
    def loss_fn(model):
        # Forward the temperature so the annealing schedule supplied by the
        # caller actually reaches the quantizer.
        preds, logits = model(batch, temperature)
        reconstruction_loss = reconstruction_loss_fn(preds, batch)
        diversity_loss = diversity_loss_fn(logits, model.codebook_size)
        loss = 100 * reconstruction_loss + 0.1 * diversity_loss
        log_dict = dict(
            reconstruction_loss=reconstruction_loss,
            diversity_loss=diversity_loss,
        )
        return loss, log_dict

    (loss, log_dict), grads = loss_fn(model)
    # Norm of the raw gradients, measured before the optimizer transforms them.
    log_dict["grad_norm"] = norm(grads)
    optimizer.update(grads)
    return loss, log_dict


@nnx.jit
def encode(model: VQVAE, batch: jax.Array) -> jax.Array:
    """Return the index of the highest-scoring code at each bottleneck position.

    ``batch`` is cast to ``default_dtype`` before it is passed to the encoder.
    """
    logits = model.encoder(default_dtype(batch))
    return logits.argmax(-1)


@nnx.jit
def decode(model: VQVAE, batch: jax.Array) -> jax.Array:
    """Decode an array of codebook indices into a float32 output."""
    code_vectors = model.codebook[batch]
    decoded = model.decoder(code_vectors)
    return decoded.astype(jnp.float32)


@nnx.jit
def roundtrip(model: VQVAE, batch: jax.Array) -> jax.Array:
    """Encode ``batch`` to code indices and decode them back again."""
    indices = encode(model, batch)
    return decode(model, indices)
    

class Dataset:
    """Endless stream of random fixed-length windows from one audio file.

    Attributes:
        data: Audio samples as a 2-D array (a trailing axis is added to 1-D audio).
        sr: Sample rate reported by ``soundfile``.
        windows: Sliding-window view over ``data`` along axis 0.
        batch_size: Number of windows per yielded batch.
    """

    def __init__(self, filename: str, window_size: int, batch_size: int):
        """Load ``filename`` and prepare windows of ``window_size`` samples."""
        audio, sample_rate = soundfile.read(filename)
        if audio.ndim == 1:
            # Give 1-D audio an explicit trailing axis.
            audio = audio[:, None]
        self.data = audio
        self.sr = sample_rate
        # A view, not a copy: the window axis is appended as the last axis.
        self.windows = np.lib.stride_tricks.sliding_window_view(audio, window_size, axis=0)
        self.batch_size = batch_size

    def __iter__(self):
        """Yield random batches forever, cast to ``default_dtype``."""
        while True:
            picks = np.random.randint(len(self.windows), size=self.batch_size)
            # Move the window axis in front of the trailing data axis.
            chosen = self.windows[picks].transpose(0, 2, 1)
            yield default_dtype(chosen)

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the model, the optimizer schedule and the data."""

    # --- model architecture (forwarded to VQVAE) ---
    octaves: int = 6
    depth: int = 15
    dim: int = 16
    kernel_size: int = 3
    codebook_size: int = 8192

    # --- initialisation ---
    seed: int = 0

    # --- optimisation ---
    max_grad_norm: float = 1.0  # global-norm gradient clipping threshold
    init_lr: float = 1e-4  # learning rate at step 0
    peak_lr: float = 5e-3  # learning rate at the end of warm-up
    end_lr: float = 1e-4  # learning rate at the end of the cosine decay
    warmup_steps: int = 100
    decay_steps: int = 10_000
    weight_decay: float = 0.0

    # --- data ---
    data_filename: str = "romeo_and_juliet.mp3"
    bottleneck_length: int = 64  # encoder output length used to size training windows
    batch_size: int = 512

    
def initialize(cfg: Config) -> tuple[VQVAE, nnx.Optimizer, Dataset]:
    """Build the model, optimizer and dataset described by ``cfg``.

    Also prints the number of ``nnx.Param`` values in the model, in millions.

    Returns:
        ``(model, optimizer, dataset)``.
    """
    model = VQVAE(
        octaves=cfg.octaves,
        depth=cfg.depth,
        dim=cfg.dim,
        kernel_size=cfg.kernel_size,
        codebook_size=cfg.codebook_size,
        rngs=nnx.Rngs(cfg.seed),
    )

    # Count only the nnx.Param state; everything else is filtered out.
    params = nnx.split(model, nnx.Param, ...)[1]
    n_params = sum(jax.tree.leaves(jax.tree.map(jnp.size, params)))
    print(f"{n_params/1_000_000:.1f}M params")

    # Warm-up followed by cosine decay of the learning rate.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=cfg.init_lr,
        peak_value=cfg.peak_lr,
        end_value=cfg.end_lr,
        warmup_steps=cfg.warmup_steps,
        decay_steps=cfg.decay_steps,
    )
    # Clip gradients by global norm, then apply AdamW.
    transform = optax.chain(
        optax.clip_by_global_norm(cfg.max_grad_norm),
        optax.adamw(schedule, weight_decay=cfg.weight_decay),
    )
    optimizer = nnx.Optimizer(model, transform)

    # Size the windows so the encoder yields ``bottleneck_length`` positions.
    window_size = model.valid_length(cfg.bottleneck_length)
    dataset = Dataset(cfg.data_filename, window_size=window_size, batch_size=cfg.batch_size)
    return model, optimizer, dataset


# Build the default configuration, print it, then construct the model,
# optimizer and dataset that the following cells use.
cfg = Config()
pprint(cfg)
model, optimizer, dataset = initialize(cfg)

In [None]:
# Training loop: one optimisation step per progress-bar tick.
progbar = tqdm.trange(cfg.decay_steps, mininterval=1)
# Temperature schedule: geometric interpolation from 2 down to 0.1, one value per step.
temps = np.geomspace(2, 0.1, cfg.decay_steps)
losses = []
try:
    # zip() stops at the shortest iterable; the dataset iterator itself is endless.
    for step, batch, temperature in zip(progbar, dataset, temps):
        loss, log_dict = train_step(model, optimizer, batch, temperature)
        progbar.set_postfix(loss=loss, temperature=temperature, **log_dict, refresh=False)
        losses.append(loss.item())
finally:
    # Plot the loss curve even if training is interrupted or raises.
    plt.semilogy(losses)
    plt.grid(True, which="both")

In [None]:
# Listen to an excerpt of the recording before and after a model round-trip.
print("original:")
start = dataset.sr * 60 * 45 - 3000  # 45 minutes into the recording, minus 3000 samples
# Add a leading batch axis; the length is chosen to give 4096 bottleneck positions.
sample = dataset.data[None, start:start+model.valid_length(4096)]
display(Audio(sample.squeeze(), rate=dataset.sr))

print("round-trip:")
# Encode to code indices and decode back to audio.
prediction = roundtrip(model, sample)
display(Audio(prediction.squeeze(), rate=dataset.sr))