In [None]:
# Print the Python version of the interpreter found on the shell PATH.
!python --version
# Upgrade jax and jaxlib in the kernel's environment (no versions are pinned).
# NOTE(review): jaxdsp, which the import cell relies on, is not installed here —
# presumably it is installed separately; confirm.
%pip install jax jaxlib --upgrade

In [None]:
import jax.numpy as jnp
from scipy import signal
import numpy as np
import time

from jaxdsp import training, processor_graph
from jaxdsp.processors import fir_filter, iir_filter, clip, delay_line, lowpass_feedback_comb_filter as lbcf, allpass_filter, freeverb, sine_wave
from jaxdsp.plotting import plot_filter, plot_loss, plot_params, plot_optimization
from jaxdsp.loss import LossOptions

In [None]:
# Number of training buffers, and number of samples in each buffer.
num_train = 100
buffer_size = 44100  # presumably one second at 44.1 kHz — TODO confirm sample rate

# Seed NumPy's global RNG so the random training buffers below (and any later
# np.random draws made after running this cell) are reproducible across
# "Restart Kernel -> Run All".
random_seed = 0
np.random.seed(random_seed)

# Gaussian white-noise inputs: one row per training buffer.
Xs_random = np.random.randn(num_train, buffer_size)

# One continuous chirp sampled at num_train * buffer_size points over
# t in [0, num_train], sweeping from f0=10 to f1=1000 (reached at t1=num_train),
# then cut into num_train equal-length rows.
Xs_chirp = np.array(
    np.split(
        signal.chirp(
            np.linspace(0.0, num_train, num_train * buffer_size),
            f0=10,
            f1=1000,
            t1=num_train,
        ),
        num_train,
    )
)

# Time-domain loss only: weight 1.0 on the "sample" term, compared with "L2".
default_loss_opts = LossOptions(weights={"sample": 1.0}, distance_types={"sample": "L2"})

# Spectral loss only: weight 1.0 on the "cumsum_freq" term.
# NOTE(review): the weight key is "cumsum_freq" while the distance type is keyed
# by "frequency" — presumably LossOptions applies the "frequency" distance to all
# spectral terms; confirm against jaxdsp.loss.
spectral_loss_opts = LossOptions(weights={"cumsum_freq": 1.0}, distance_types={"frequency": "L1"})

In [None]:
def evaluate_processors(processors, params_targets, loss_opts=default_loss_opts, optimization_opts=None, Xs=Xs_chirp, num_batches=100, reference_fn=None, plot_loss_history=True, plot_params_history=True, title=None):
    """Train a processor chain towards target parameters and plot the outcome.

    For each of `num_batches` steps a random row of `Xs` is run through the
    processor series with the target parameters to produce `Y_target`, and the
    trainer takes one step on `(X, Y_target)`. Afterwards the train time, final
    loss and estimated parameters are printed, and the target, estimated and
    (optional) reference outputs for `Xs[0]` are plotted.

    Args:
        processors: Processor modules; each must expose a `NAME` attribute.
        params_targets: Target parameters, one entry per processor.
        loss_opts: Loss options forwarded to `training.IterativeTrainer`.
        optimization_opts: Optimization options forwarded to the trainer.
            `None` (the default) means an empty dict; a fresh dict is created
            on every call so no state is shared between calls.
        Xs: 2-D array of input buffers, indexed by row.
        num_batches: Number of training steps to run.
        reference_fn: Optional callable `(X, target_params) -> Y` whose output
            is passed to `plot_filter` as a reference signal.
        plot_loss_history: Whether to plot the recorded loss history.
        plot_params_history: Whether to plot the recorded parameter history.
        title: Title forwarded to `plot_filter`.

    Returns:
        None. Results are printed and plotted.
    """
    # Avoid a shared mutable default: build a new dict per call.
    if optimization_opts is None:
        optimization_opts = {}

    processor_names = [processor.NAME for processor in processors]
    processor_config = [{"name": name} for name in processor_names]
    trainer = training.IterativeTrainer(processor_config, loss_opts, optimization_opts, track_history=True)
    # NOTE(review): the target carry starts from trainer.state — confirm this is
    # the intended initial processor state for the target signal path.
    carry_target = (params_targets, trainer.state)
    start = time.time()
    for _ in range(num_batches):
        # Draw a training buffer uniformly at random (with replacement).
        X = Xs[np.random.choice(Xs.shape[0])]
        carry_target, Y_target = processor_graph.tick_buffer_series(carry_target, X, processor_names)
        trainer.step(X, Y_target)

    params_estimated = trainer.float_params()
    carry_estimated = (params_estimated, trainer.state)
    print('Train time: {:.3E} s'.format(time.time() - start))
    print('Loss: {:.3E}'.format(trainer.loss))
    print('Estimated params: ', params_estimated)

    # Evaluate target, estimate and optional reference on the first buffer.
    X_eval = Xs[0]
    carry_estimated, Y_estimated = processor_graph.tick_buffer_series(carry_estimated, X_eval, processor_names)
    carry_target, Y_target = processor_graph.tick_buffer_series(carry_target, X_eval, processor_names)
    Y_reference = reference_fn(X_eval, carry_target[0]) if reference_fn is not None else None

    if plot_loss_history:
        plot_loss(trainer.step_evaluator.loss_history)
    if plot_params_history:
        plot_params(training.float_params(carry_target[0]), trainer.step_evaluator.params_history, processor_names)
    plot_filter(X_eval, Y_target, Y_reference, Y_estimated, title)

In [None]:
# Lowpass-feedback comb filter with all evaluation defaults.
evaluate_processors(
    [lbcf],
    [{"feedback": 0.5, "damp": 0.5}],
)

In [None]:
# Sine wave with a 400.0 target frequency, trained with the spectral loss
# options and a custom step size passed through to the trainer.
evaluate_processors(
    [sine_wave],
    [{"frequency_hz": 400.0}],
    loss_opts=spectral_loss_opts,
    optimization_opts={'params': {'step_size': 0.0003}},
)

In [None]:
# Freeverb, limited to 4 training batches.
evaluate_processors(
    [freeverb],
    [{'wet': 0.3, 'dry': 0.0, 'width': 1.0, 'damp': 0.5, 'room_size': 0.5}],
    num_batches=4,
)

In [None]:
# Allpass filter with a target feedback of 0.5 and all evaluation defaults.
evaluate_processors(
    [allpass_filter],
    [{'feedback': 0.5}],
)

In [None]:
# Delay line (wet=0.5, delay_samples=10.0) trained with the spectral loss options.
evaluate_processors(
    [delay_line],
    [{'wet': 0.5, 'delay_samples': 10.0}],
    loss_opts=spectral_loss_opts,
)

In [None]:
# FIR filter trained on the random-noise buffers; scipy's lfilter with the
# target B coefficients and a denominator of [1.0] serves as the reference.
evaluate_processors(
    [fir_filter],
    [{'B': jnp.array([0.1, 0.7, 0.5, 0.6])}],
    Xs=Xs_random,
    reference_fn=lambda X, params: signal.lfilter(params[fir_filter.NAME]['B'], [1.0], X),
)

In [None]:
# `signal` is already imported from scipy in the notebook's import cell, so it
# is not re-imported here.
# Target coefficients come from a 4th-order low-pass Butterworth design with a
# critical frequency of 0.5.
B_target, A_target = signal.butter(4, 0.5, "low")
iir_filter_target_params = {
    'B': B_target,
    'A': A_target,
}
# IIR filter trained on the random-noise buffers; scipy's lfilter with the
# target B/A coefficients serves as the reference.
evaluate_processors(
    [iir_filter],
    [iir_filter_target_params],
    Xs=Xs_random,
    reference_fn=lambda X, params: signal.lfilter(params[iir_filter.NAME]['B'], params[iir_filter.NAME]['A'], X),
)

In [None]:
# Clip to [-0.5, 0.5]; np.clip with the target bounds serves as the reference.
clip_target_params = {'min': -0.5, 'max': 0.5}
evaluate_processors(
    [clip],
    [clip_target_params],
    reference_fn=lambda X, params: np.clip(X, params[clip.NAME]['min'], params[clip.NAME]['max']),
)

In [None]:
# IIR filter followed by clip, reusing the target parameters defined in the
# two preceding cells, trained for 500 batches.
evaluate_processors(
    [iir_filter, clip],
    [iir_filter_target_params, clip_target_params],
    num_batches=500,
)