In [None]:
#%env XLA_PYTHON_CLIENT_PREALLOCATE=".99"
%matplotlib inline
import time
import tqdm
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from IPython import display
from train import *


def play(waveform, rate=44100):
    """Peak-normalize a waveform and show it as an inline audio player.

    Args:
        waveform: Array of audio samples. Singleton axes are squeezed away
            before the samples are handed to the audio widget.
        rate: Sampling rate passed to ``display.Audio``. Defaults to 44100,
            the value that used to be hard-coded here.

    Returns:
        The result of ``display.display`` for the created audio widget.
    """
    # Scale by the largest absolute sample so the loudest point sits at +/-1.
    # NOTE(review): an all-zero waveform makes ``peak`` zero and the division
    # yields NaNs -- confirm whether silent inputs can occur here.
    peak = jnp.abs(waveform).max()
    normalized = waveform.squeeze() / peak
    return display.display(display.Audio(normalized, rate=rate))


# Experiment configuration, kept as inline YAML so the whole run is described
# in one place. Sections: data loading, model size, optimizer, training
# schedule, validation sampling, and the PRNG seeds.
_CONFIG_YAML = """
data:
  filename: "02 - Birth.flac"
  batch_size: 8
  length: 263168
  p: null
  num_threads: 16
  max_queue_length: 16

model:
  ch: 64
  depth: 16
  num_blocks: 2
  kernel_size: 3
  hidden_dim: 128

optimizer:
  type: adam
  learning_rate: 1e-3
  b1: 0.9
  b2: 0.999
  eps: 1e-5

training:
  num_iters: 1000000
  log_interval: 100
  save_interval: 10000

validation:
  interval: 100
  num_steps: 100
  length: 441000
  padded: true

rngs:
  init: 1
  data: 2
  val: 3
"""

config = OmegaConf.create(_CONFIG_YAML)

# Model setup: build the training state from the seeded init key.
init_rng = jax.random.PRNGKey(config.rngs.init)
state = TrainState.from_config(init_rng, config)
print("params:", count_params(state))

# Data setup.
data_rng = jax.random.PRNGKey(config.rngs.data)
# The fallbacks are derived from the network; presumably they fill in entries
# the config leaves unset (e.g. ``p: null``) -- verify against the loader.
data = WaveformDataLoader.from_config(
    data_rng,
    config.data,
    fallbacks=dict(p=state.net.p, length=1 + state.net.pad),
)
print("crop size:", data.datagen.length)
assert data.datagen.length > state.net.pad, (
    f"crop length {data.datagen.length} must exceed the network padding "
    f"{state.net.pad}"
)

# Validation setup.
# A PRNG key must not be consumed twice: split the seeded key so the fixed
# initial noise ``val_xt`` and the sampler (which later receives ``val_rng``)
# draw from independent streams instead of sharing one key.
val_rng, val_xt_rng = jax.random.split(jax.random.PRNGKey(config.rngs.val))
val_input_len, val_output_len = get_val_len(config.validation, state.net.pad)
val_xt = jax.random.normal(val_xt_rng, (1, val_input_len, 1))

# Training loop: optimize, and periodically log, checkpoint and audition a
# generated sample.
# The progress bar is entered as a context manager so it is closed even when
# the run is interrupted or a step raises; giving it the total lets it show
# completion percentage and an ETA.
with data, tqdm.tqdm(total=config.training.num_iters) as progbar:
    losses = []
    for step in range(config.training.num_iters):

        # Training step.
        state, loss = train_step(state, *data.get())
        losses.append(loss)
        progbar.update()

        # Logging step: report the mean over the most recent window of losses.
        if interval(step, config.training.log_interval):
            recent_losses = losses[-config.training.log_interval:]
            print(
                f"steps={step + 1}, avg_loss="
                f"{jnp.mean(jnp.array(recent_losses)):.3f}")

        # Saving step.
        # NOTE(review): the checkpoint label uses the zero-based ``step`` while
        # the log line above reports ``step + 1`` -- confirm which is intended.
        if interval(step, config.training.save_interval):
            save_checkpoint(f"step={step:09d}", state, config, data)

        # Validation step: sample from the fixed noise ``val_xt`` so successive
        # generations start from the same input.
        if interval(step, config.validation.interval):
            print("generating")
            play(diffusion_sampling(
                rng=val_rng,
                model=lambda xt, t: apply_model_inference(state, xt, t),
                xt=val_xt,
                num_steps=config.validation.num_steps,
                p=state.net.p if config.validation.padded else 0,
                progbar=False,
            ))

In [None]:
# Generate and play a longer sample from the current model state.
# The initial noise and the sampler must not share a PRNG key, so derive two
# independent keys from ``val_rng`` (which itself is left untouched, keeping
# re-runs of this cell deterministic).
sample_rng, noise_rng = jax.random.split(val_rng)
play(diffusion_sampling(
    rng=sample_rng,
    model=lambda xt, t: apply_model_inference(state, xt, t),
    # 44100 * 30 samples: 30 seconds at the 44100 rate used by ``play``.
    xt=jax.random.normal(noise_rng, (1, 44100 * 30, 1)),
    num_steps=100,
    p=state.net.p,
    progbar=False,
))