In [None]:
# Environment setup: report the interpreter version, then install dependencies.
!python --version
%pip install ../ # Use local jaxDSP package rather than published one.
# NOTE(review): these packages are unpinned and `--upgrade` pulls the latest releases,
# so results may change between runs; consider pinning versions for reproducibility.
%pip install numpy matplotlib jax jaxlib --upgrade

In [None]:
import jax.numpy as jnp
from scipy import signal
import numpy as np
import time

from jaxdsp import training, processor_graph
from jaxdsp.processors import fir_filter, iir_filter, clip, delay_line, biquad_lowpass, lowpass_filter, allpass_filter, freeverb, sine_wave, processors_to_graph_config, processor_names_from_graph_config
from jaxdsp.plotting import plot_train, plot_optimization
from jaxdsp.loss import LossOptions
from jaxdsp.params import params_to_float

In [None]:
num_train = 100

# Treating buffer_size as the sample rate, each training buffer holds one second of signal.
buffer_size = 44100
# Seed NumPy's global RNG so the random inputs below -- and the buffer selection inside
# `evaluate_processors`, which draws from the same global RNG by default -- are reproducible
# on Restart & Run All.
random_seed = 0
np.random.seed(random_seed)
Xs_random = np.random.randn(num_train, buffer_size)
# `endpoint=False` makes the sample step exactly num_train / (num_train * buffer_size) = 1 / buffer_size.
# The default endpoint=True would stretch it to num_train / (num_train * buffer_size - 1).
chirp_times = np.linspace(0.0, num_train, num_train * buffer_size, endpoint=False)
Xs_chirp = np.array(np.split(signal.chirp(chirp_times, f0=10, f1=1000, t1=num_train), num_train))

# Default objective: weight 1.0 on the "sample" term, measured with an L2 distance.
default_loss_opts = LossOptions(weights=dict(sample=1.0), distance_types=dict(sample="L2"))

# NOTE (for blog post/paper):
# The choice of STFT window can dramatically affect the optimization.
# For example, in the sines->filter test below, switching `stft_window` from the default ('hann')
# to 'hamming' or 'bartlett' causes the optimization to fail.
spectral_loss_opts = LossOptions(
    weights=dict(cumsum_freq=1.0),
    distance_types=dict(frequency="L1"),
    # stft_window="hamming",
)

optimizer_opts = dict(name="Adam")

In [None]:
def evaluate_processors(
    processors,
    params_target,
    loss_opts=default_loss_opts,
    optimization_opts=optimizer_opts,
    Xs=Xs_chirp,
    num_batches=100,
    reference_fn=None,
    plot_loss_history=True,
    plot_params_history=True,
    title=None,
    seed=None,
):
    """Train a processor graph to reproduce a target parameterization, then report and plot it.

    A copy of the graph configured with `params_target` generates the training targets. On each
    of `num_batches` steps, one randomly chosen row of `Xs` is run through the target graph, and
    the trainer takes one step toward reproducing that output. Afterwards the training time,
    final loss and estimated params are printed. Target, estimated and (optionally) reference
    outputs for the first row of `Xs` are then passed to `plot_train`.

    Args:
        processors: Nested list of processor modules, converted via `processors_to_graph_config`.
        params_target: Parameter dicts with the same nesting as `processors`; these are the
            values the trainer is trying to recover.
        loss_opts: `LossOptions` passed to `training.IterativeTrainer`.
        optimization_opts: Optimizer options dict passed to the trainer (e.g. `{"name": "Adam"}`).
        Xs: Array whose rows are input buffers; one row is drawn per training step.
        num_batches: Number of training steps.
        reference_fn: Optional callable `reference_fn(X_eval, target_params)` returning an
            independent reference output to plot alongside target and estimate.
        plot_loss_history: Forwarded to `plot_train`.
        plot_params_history: Forwarded to `plot_train`.
        title: Plot title forwarded to `plot_train`.
        seed: Optional int. When given, buffer selection uses `np.random.default_rng(seed)`, so the
            run is reproducible independently of other cells. When None (default), NumPy's global
            RNG is used, exactly as before.

    Returns:
        None. Results are printed and plotted.
    """
    graph_config = processors_to_graph_config(processors)
    processor_names = processor_names_from_graph_config(graph_config)
    trainer = training.IterativeTrainer(graph_config, loss_opts, optimization_opts, track_history=True)
    # The target graph starts from the same initial processor state as the trainer.
    carry_target = (params_target, trainer.state)
    # The module `np.random` and a `Generator` both expose `.choice(n)`, so unseeded calls
    # keep drawing from the legacy global RNG.
    rng = np.random if seed is None else np.random.default_rng(seed)

    start = time.perf_counter()
    for _ in range(num_batches):
        X = Xs[rng.choice(Xs.shape[0])]
        carry_target, Y_target = processor_graph.tick_buffer(carry_target, X, processor_names)
        trainer.step(X, Y_target)
    # Kept inside the timed region: converting to Python floats presumably blocks on any
    # pending (asynchronously dispatched) JAX work.
    params_estimated = params_to_float(trainer.params)
    train_seconds = time.perf_counter() - start

    print('Train time: {:.3E} s'.format(train_seconds))
    print('Loss: {:.3E}'.format(trainer.loss))
    print('Estimated params: ', params_estimated)

    # Evaluate on a fixed buffer: the first row of Xs.
    X_eval = Xs[0]
    _, Y_estimated = processor_graph.tick_buffer((params_estimated, trainer.state), X_eval, processor_names)
    _, Y_target = processor_graph.tick_buffer(carry_target, X_eval, processor_names)
    Y_reference = reference_fn(X_eval, carry_target[0]) if reference_fn is not None else None

    plot_train(
        trainer,
        params_target,
        X_eval,
        Y_target,
        Y_estimated,
        Y_reference,
        title=title,
        plot_loss_history=plot_loss_history,
        plot_params_history=plot_params_history,
    )

In [None]:
# Recover the lowpass filter's feedback/damp parameters from its output.
evaluate_processors(
    processors=[[lowpass_filter]],
    params_target=[[dict(feedback=0.5, damp=0.5)]],
)

In [None]:
# Recover a sine oscillator's frequency using the spectral loss and a small Adam step size.
evaluate_processors(
    processors=[[sine_wave]],
    params_target=[[dict(frequency_hz=400.0)]],
    loss_opts=spectral_loss_opts,
    optimization_opts=dict(name='Adam', params=dict(step_size=0.0003)),
)

In [None]:
# Recover freeverb parameters; only a few training batches are used here.
freeverb_params = dict(wet=0.3, dry=0.0, width=1.0, damp=0.5, room_size=0.5)
evaluate_processors(
    processors=[[freeverb]],
    params_target=[[freeverb_params]],
    num_batches=4,
)

In [None]:
# Recover the allpass filter's feedback coefficient.
evaluate_processors(processors=[[allpass_filter]], params_target=[[dict(feedback=0.5)]])

In [None]:
def reference_fn(X, params):
    """Reference FIR output: scipy `lfilter` with the target taps B and denominator [1.0]."""
    fir_taps = params[0][0]['B']
    return signal.lfilter(fir_taps, [1.0], X)


evaluate_processors(
    processors=[[fir_filter]],
    params_target=[[{'B': jnp.array([0.1, 0.7, 0.5, 0.6])}]],
    Xs=Xs_random,
    reference_fn=reference_fn,
)

In [None]:
# Target: 4th-order Butterworth lowpass with normalized cutoff Wn=0.5 (half the Nyquist frequency).
# `signal` is already imported in the imports cell at the top of the notebook.
B_target, A_target = signal.butter(4, 0.5, "low")
iir_filter_target_params = {
    'B': B_target,
    'A': A_target,
}


def reference_fn(X, params):
    """Reference IIR output: scipy `lfilter` with the target numerator B and denominator A."""
    target = params[0][0]
    return signal.lfilter(target['B'], target['A'], X)


evaluate_processors([[iir_filter]], [[iir_filter_target_params]], Xs=Xs_random, reference_fn=reference_fn)

In [None]:
clip_target_params = dict(min=-0.5, max=0.5)


def reference_fn(X, params):
    """Reference clip output: NumPy clip of X to the target [min, max] range."""
    bounds = params[0][0]
    return np.clip(X, bounds['min'], bounds['max'])


evaluate_processors(processors=[[clip]], params_target=[[clip_target_params]], reference_fn=reference_fn)

In [None]:
# Recover both stages of an IIR filter -> clip chain, reusing the target params from the cells above.
evaluate_processors(
    processors=[[iir_filter], [clip]],
    params_target=[[iir_filter_target_params], [clip_target_params]],
    num_batches=200,
)

In [None]:
# "Sines -> filter" test: two sine_wave processors (400 Hz and 600 Hz) then an allpass_filter,
# trained with the spectral loss.
evaluate_processors(
    processors=[[sine_wave, sine_wave], [allpass_filter]],
    params_target=[
        [dict(frequency_hz=400.0), dict(frequency_hz=600.0)],
        [dict(feedback=0.5)],
    ],
    loss_opts=spectral_loss_opts,
    optimization_opts=dict(name='Adam', params=dict(step_size=0.006)),
    num_batches=400,
)