### Initialization

In [None]:
#%env XLA_PYTHON_CLIENT_PREALLOCATE=false
#%env XLA_PYTHON_CLIENT_MEM_FRACTION=".99"
%matplotlib inline
%load_ext yamlmagic

import time
import tqdm
import os
import glob
import matplotlib.pyplot as plt
from IPython import display
from data import mid_side_to_stereo
from train import *


def play(wave, rate=44100, mid_side=True):
    """Peak-normalize a waveform and show it as an inline audio player.

    Args:
        wave: Array-like with a ``squeeze`` method. It may carry singleton axes
            (e.g. a batch of one). After squeezing it must be 1-D (mono) or
            2-D with 2 channels on the last axis.
        rate: Sample rate in Hz passed to ``IPython.display.Audio``.
        mid_side: If True (the default, i.e. the previous behaviour), a
            2-channel waveform is decoded from mid/side to left/right with
            ``mid_side_to_stereo`` before playback.

    Returns:
        The result of ``display.display``.

    Raises:
        ValueError: If the squeezed waveform is neither mono nor 2-channel.
    """
    wave = wave.squeeze()
    if wave.ndim == 2 and wave.shape[1] == 2:
        if mid_side:
            wave = mid_side_to_stereo(wave)
        # IPython's Audio expects multi-channel data laid out as (channels, samples).
        wave = wave.T
    elif wave.ndim != 1:
        raise ValueError(
            f"expected a mono or 2-channel waveform, got shape {wave.shape} after squeezing")
    peak = jnp.abs(wave).max()
    # Out-of-place division. `wave /= peak` would write into the caller's array when
    # `wave` is a NumPy view (as returned by squeeze()), and it fails for integer dtypes.
    # An all-zero clip is left as-is instead of being divided by zero.
    if peak > 0:
        wave = wave / peak
    return display.display(display.Audio(wave, rate=rate))

### Config

In [None]:
%%yaml config

# Run configuration. When the Model cell resumes from a checkpoint, the config
# returned by `restore_checkpoint` replaces this one.
experiment: 
  name: null  # null -> `experiment_name` from the Experiment cell, else a random id.

data:
  filename: "02 - Birth.flac"
  mono: false     # Also selects 1 vs 2 channels for the validation noise (Data cell).
  mid_side: true  # Presumably trains on mid/side stereo; note `play` decodes 2-channel audio as mid/side.
  batch_size: 1
  length: null    # Crop length in samples; null falls back to 2 * model pad (Data cell). Must exceed the pad.
  p: null         # null falls back to the network's `net.p` (reportedly total model pad // 2).
  num_threads: 16
  max_queue_length: 64

# Model and optimizer sections are consumed by TrainState.from_config (not visible here).
model:
  ch: 64
  depth: 17
  num_blocks: 2
  kernel_size: 3
  hidden_dim: 128
  momentum: 0.9

optimizer:
  type: adam
  learning_rate: 0.001
  b1: 0.9
  b2: 0.999
  eps: 0.00001
  eps_root: 0.0

training:
  num_iters: 1000000    # Length of the training loop.
  log_interval: 1000    # Print the mean loss over the last `log_interval` steps this often.
  save_interval: 10000  # Checkpoint to "<experiment.name>/step=<9-digit step>" this often.

validation:
  interval: 1000   # Generate and play a sample every `interval` training steps.
  num_steps: 100   # Diffusion sampling steps for the validation sample.
  length: 441000   # Passed to get_val_len with the model pad; presumably samples (10 s at 44.1 kHz) - TODO confirm.
  padded: true     # true -> sample with p = net.p, false -> p = 0.

# Seeds for parameter init (Model cell), the data loader and validation sampling (Data cell).
rngs:
  init: 47
  data: 0
  val: 17

### Experiment

In [None]:
# Name of an experiment to create or resume. None defers to `experiment.name` in the
# YAML config and, failing that, to a randomly generated id.
experiment_name = None  # "name-of-some-old-or-new-experiment"
try:
    config = ConfigDict(config)
except NameError:
    # The YAML cell was not run. Seed the one field assigned below so the attribute
    # access works. This is only enough to resume an existing experiment: its
    # checkpoint supplies the full config in the Model cell.
    config = ConfigDict({"experiment": {"name": None}})
config.experiment.name = experiment_name or config.experiment.name or hruid.Generator().random()
print("experiment name:", config.experiment.name)
# Checkpoint directories are "<name>/step=<9-digit step>/" (zero-padded when saved), so
# lexicographic order is chronological and the last entry is the newest checkpoint.
checkpoints = sorted(glob.glob(os.path.join(config.experiment.name, "step=*/")))

### Model

In [None]:
if not checkpoints:
    # Fresh experiment: initialise parameters from the configured seed.
    init_rng = jax.random.PRNGKey(config.rngs.init)
    state = TrainState.from_config(init_rng, config)
else:
    # Resume from the newest checkpoint. The config returned by `restore_checkpoint`
    # replaces the one built from the YAML cell.
    latest_checkpoint = checkpoints[-1]
    print("loading checkpoint:", latest_checkpoint)
    config, state = restore_checkpoint(latest_checkpoint)

print("params:", count_params(state))

### Data

In [None]:
# Training data loader. `p` and `length` left null in the config fall back to values
# derived from the network.
data_rng = jax.random.PRNGKey(config.rngs.data)
data = WaveformDataLoader.from_config(data_rng, config.data, fallbacks=dict(p=state.net.p, length=2*state.net.pad))
print("crop size:", data.datagen.length)
assert data.datagen.length > state.net.pad, (
    f"crop length {data.datagen.length} must exceed the model pad {state.net.pad}")

# Validation setup. The seed is split into independent keys for the sampler (`val_rng`)
# and the initial noise (`val_noise_rng`). A JAX key must not be reused for independent
# draws, or the two random streams become correlated.
val_rng, val_noise_rng = jax.random.split(jax.random.PRNGKey(config.rngs.val))
val_input_len, val_output_len = get_val_len(config.validation, state.net.pad)
val_xt = jax.random.normal(val_noise_rng, (1, val_input_len, 1 if config.data.mono else 2))

### Training

In [None]:
# Resume the step counter from the checkpoint restored in the Model cell (if any). This
# keeps step numbers, checkpoint names and the `num_iters` budget continuous across
# runs. Checkpoint directories are "<name>/step=<9-digit step>/".
# NOTE: this assumes `state` was restored by the Model cell from `checkpoints[-1]`.
# Re-run the Experiment and Model cells before re-running this cell.
start_step = (
    int(os.path.basename(os.path.normpath(checkpoints[-1])).split("=", 1)[1])
    if checkpoints
    else 0
)


def save_if_new(step):
    """Save the current global `state` as checkpoint `step`, unless it already exists."""
    savepath = f"{config.experiment.name}/step={step:09d}"
    if not os.path.exists(savepath):
        print(f"saving {savepath}")
        save_checkpoint(savepath, config, state)
    else:
        print(f"skipping saving {savepath} because checkpoint already exists")


with data, tqdm.tqdm(initial=start_step, total=config.training.num_iters) as progbar:
    losses = []
    for step in range(start_step, config.training.num_iters):

        # Saving step. It runs before the training step, so `state` has `step` updates.
        if step % config.training.save_interval == 0:
            save_if_new(step)

        # Validation step: same noise and key every time, so samples are comparable.
        if step % config.validation.interval == 0:
            print("generating")
            play(diffusion_sampling(
                rng=val_rng,
                model=lambda xt, t: apply_model_inference(state, xt, t),
                xt=val_xt,
                num_steps=config.validation.num_steps,
                p=state.net.p if config.validation.padded else 0,
                progbar=False,
            ))

        # Logging step. Skipped until this run has recorded at least one loss.
        if losses and step % config.training.log_interval == 0:
            print(
                f"steps={step}, avg_loss="
                f"{jnp.mean(jnp.array(losses[-config.training.log_interval:])):.3f}")

        # Training step.
        state, loss = train_step(state, *data.get())
        losses.append(loss)
        progbar.set_description(f"loss={loss:.3f}")
        progbar.update()

# In-loop saves happen before each training step, so the updates made after the last
# save would otherwise be lost. This line is not reached if the loop is interrupted.
save_if_new(max(start_step, config.training.num_iters))

### Sampling

In [None]:
# Draw a longer sample from the current model. Independent keys are derived for the
# initial noise and the sampler, because a JAX key must not be reused.
sample_seconds = 10
sample_rate = 44100  # Matches the playback rate used by `play`.
num_channels = 1 if config.data.mono else 2  # Same channel rule as `val_xt` in the Data cell.
sample_rng, sample_noise_rng = jax.random.split(val_rng)
x0_pred = diffusion_sampling(
    rng=sample_rng,
    model=lambda xt, t: apply_model_inference(state, xt, t),
    xt=jax.random.normal(sample_noise_rng, (1, sample_rate * sample_seconds, num_channels)),
    num_steps=100,
    p=state.net.p,
    progbar=False,
)