In [3]:
import os

# `!export` runs in a throwaway subshell, so it never changed this kernel's
# environment; set the variable in-process instead.
# ".99" is a memory *fraction*, which XLA reads from XLA_PYTHON_CLIENT_MEM_FRACTION
# (XLA_PYTHON_CLIENT_PREALLOCATE only switches preallocation on/off).
# The value is read when JAX initialises its backend, so it only takes effect if
# this runs before JAX first touches the device (e.g. right after a kernel restart).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
%matplotlib inline
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from IPython import display
from train import *


def play(waveform, rate=44100):
    """Display `waveform` as an inline audio player, peak-normalised to [-1, 1].

    Args:
        waveform: array of samples; singleton axes are squeezed away before display.
        rate: sample rate in Hz handed to IPython's Audio widget (defaults to the
            44100 that used to be hard-coded here).

    Returns:
        Whatever ``IPython.display.display`` returns.
    """
    samples = jnp.asarray(waveform).squeeze()
    peak = jnp.abs(samples).max()
    # Guard silent input: dividing by a zero peak would turn every sample into NaN.
    if peak > 0:
        samples = samples / peak
    # Already normalised above, so stop IPython from re-normalising (which would
    # divide by zero again on silent input).
    return display.display(display.Audio(samples, rate=rate, normalize=False))


# Experiment configuration, built from a plain nested dict; every section is
# read by name further down (config.data, config.model, config.training, ...).
config = OmegaConf.create({
    "data": {
        "filename": "01 - Opening Theme.flac",
        "batch_size": 16,
        "length": None,
        "p": None,
        "num_threads": 16,
        "max_queue_length": 16,
    },
    "model": {
        "ch": 32,
        "depth": 14,
        "num_blocks": 4,
        "kernel_size": 3,
        "embedding_dim": 32,
    },
    "optimizer": {
        "type": "adam",
        "learning_rate": 2e-4,
        "b1": 0.9,
        "b2": 0.999,
        "eps": 1e-5,
    },
    "training": {
        "num_iters": 1_000_000,
        "log_interval": 10,
    },
    "validation": {
        "interval": 10,
        "num_steps": 100,
        "length": 441000,
        "padded": True,
    },
    # Separate seeds so init, data order and validation noise vary independently.
    "rngs": {
        "init": 1,
        "data": 2,
        "val": 3,
    },
})

# Build the model/optimizer state from the config and report its size.
init_rng = jax.random.PRNGKey(config.rngs.init)
state = TrainState.from_config(init_rng, config)
print("params:", count_params(state))

# Training data loader; sample length is tied to the network's padding.
data_rng = jax.random.PRNGKey(config.rngs.data)
data = WaveformDataLoader.from_config(
    data_rng, config.data, p=state.net.p, length=2 * state.net.pad)
# Each training example must be longer than the network's padding.
assert data.datagen.length > state.net.pad

# Validation uses a fixed key, so every validation pass starts from the same
# noise and samples are comparable across training. Split it so the initial
# noise and the sampler's own randomness come from independent streams: feeding
# the same key to both jax.random.normal and the sampler correlates them.
val_rng = jax.random.PRNGKey(config.rngs.val)
val_noise_rng, val_sample_rng = jax.random.split(val_rng)
val_input_len, val_output_len = get_val_len(config.validation, state.net.pad)
val_xt = jax.random.normal(val_noise_rng, (1, val_input_len, 1))

# NOTE(review): the recorded run of this cell failed with
# "TypeError: add got incompatible shapes for broadcasting: (16,), (262144,)".
# 16 equals config.data.batch_size; the traceback is not shown here, so check the
# shapes returned by data.get() against what train_step expects.
with data:  # presumably starts/stops the loader's worker threads — confirm in train.py
    losses = []  # per-step losses, also plotted by a later cell
    for step in range(config.training.num_iters):
        state, loss = train_step(state, *data.get())
        losses.append(loss)
        if interval(step, config.training.log_interval):
            recent = jnp.array(losses[-config.training.log_interval:])
            print(f"steps={step + 1}, avg_loss={jnp.mean(recent):.3f}")
        if interval(step, config.validation.interval):
            print("generating")
            x0_pred = diffusion_sampling(
                rng=val_sample_rng,
                model=lambda xt, t: apply_model_inference(state, xt, t),
                xt=val_xt,
                num_steps=config.validation.num_steps,
                p=state.net.p if config.validation.padded else 0,
            )
            play(x0_pred)

params: 315813


TypeError: add got incompatible shapes for broadcasting: (16,), (262144,).

In [None]:
from flax import struct, serialization
from flax.training import orbax_utils

import orbax.checkpoint


CKPT_DIR = 'ckpts/01/'

# Bundle the run for later restore. An OmegaConf DictConfig is not a JAX pytree,
# so orbax would see it as an opaque leaf it cannot serialise; store its YAML
# text instead (restore with OmegaConf.create(ckpt["config"])).
# NOTE(review): `data` is a threaded loader object; unless it is a registered
# pytree, orbax will not be able to serialise it either — confirm or drop it.
ckpt = {"model": state, "config": OmegaConf.to_yaml(config), "data": data}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
# force=True replaces an existing checkpoint in CKPT_DIR; without it, re-running
# this cell raises because the directory already exists.
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args, force=True)

In [None]:
!ls -a ckpts/01/checkpoint

In [None]:
# Per-step training loss on a log scale (losses is filled by the training cell).
fig, ax = plt.subplots()
ax.semilogy(losses)
ax.set(xlabel="training step", ylabel="loss (log scale)", title="Training loss");