In [None]:
!python --version
%pip install jax jaxlib --upgrade

In [None]:
import jax.numpy as jnp
from scipy import signal
import numpy as np
import time

from jaxdsp import training
from jaxdsp.processors import fir_filter, iir_filter, clip, delay_line, lowpass_feedback_comb_filter as lbcf, allpass_filter, freeverb, sine_wave, serial_processors
from jaxdsp.plotting import plot_filter, plot_loss, plot_params, plot_optimization
from jaxdsp.loss import LossOptions

In [None]:
# Seed NumPy's global RNG so runs are reproducible: it drives both the random
# training inputs below and the per-batch buffer choice in `evaluate_processors`.
SEED = 0
np.random.seed(SEED)

num_train = 100  # number of training buffers in each input set

# Samples per buffer; the chirp's time axis below also uses this many samples
# per second, so each chirp buffer spans exactly one second.
buffer_size = 44100

# White-noise (standard normal) input buffers, shape (num_train, buffer_size).
Xs_random = np.random.randn(num_train, buffer_size)

# One continuous chirp sweeping 10 Hz -> 1 kHz over `num_train` seconds, cut into
# `num_train` consecutive buffers. `endpoint=False` gives a sample spacing of
# exactly 1 / buffer_size; including the endpoint would make it slightly larger.
t_chirp = np.linspace(0.0, num_train, num_train * buffer_size, endpoint=False)
Xs_chirp = signal.chirp(t_chirp, f0=10, f1=1000, t1=num_train).reshape(num_train, buffer_size)

# Sample-domain loss: L2 distance between estimated and target samples.
default_loss_opts = LossOptions(
    weights={
        "sample": 1.0,
    },
    distance_types={
        "sample": "L2",
    },
)
# Spectral loss, used below for the sine_wave and delay_line experiments.
# NOTE(review): the weight is keyed "cumsum_freq" but the distance type is keyed
# "frequency"; confirm in jaxdsp.loss which distance-type key the cumsum_freq
# term actually reads.
spectral_loss_opts = LossOptions(
    weights={
        "cumsum_freq": 1.0,
    },
    distance_types={
        "frequency": "L1",
    },
)

In [None]:
def evaluate_processors(processors, params_targets, loss_opts=default_loss_opts, optimization_opts=None, Xs=Xs_chirp, num_batches=100, reference_fn=None, plot_loss_history=True, plot_params_history=True, title=None):
    """Fit a serial chain of processors to a target parameterization, then report and plot.

    The same chain, run with `params_targets`, generates the training targets.
    An `IterativeTrainer` takes one step per batch to match them. Train time,
    final loss and estimated params are printed. The loss/param histories and a
    comparison on `Xs[0]` are then plotted.

    Args:
        processors: Processor modules, combined via `serial_processors`.
        params_targets: Target parameter dicts, one per entry of `processors`.
        loss_opts: `LossOptions` passed to the trainer.
        optimization_opts: Optimizer options passed to the trainer. `None`
            (the default) means an empty dict.
        Xs: 2-D array whose rows are input buffers. Each batch uses a randomly
            chosen row, and row 0 is used for the final evaluation.
        num_batches: Number of training steps.
        reference_fn: Optional `(X, params) -> Y` independent implementation
            (e.g. scipy), plotted alongside the target and the estimate.
        plot_loss_history: Whether to plot the loss history.
        plot_params_history: Whether to plot the parameter history.
        title: Title passed to `plot_filter`.

    Raises:
        ValueError: If `processors` and `params_targets` differ in length.
    """
    # `zip` below would silently drop unmatched entries; fail loudly instead.
    if len(processors) != len(params_targets):
        raise ValueError(
            'Expected one target params dict per processor, got {} processors and {} targets'.format(
                len(processors), len(params_targets)))
    # Fresh dict per call: a shared `{}` default would carry any mutation the
    # trainer makes to its options over into later calls.
    if optimization_opts is None:
        optimization_opts = {}

    processor = serial_processors
    processor_state = serial_processors.state_init(processors)
    trainer = training.IterativeTrainer(processor, loss_opts, optimization_opts, processor_state, track_history=True)

    # Targets are keyed by processor NAME, the same way the reference_fns index
    # params. NOTE(review): two processors sharing a NAME would collide here.
    params_targets_dict = {proc.NAME: target for proc, target in zip(processors, params_targets)}
    carry_target = {'params': params_targets_dict, 'state': processor_state}

    # The reported train time also includes generating the target outputs.
    start = time.perf_counter()
    for _ in range(num_batches):
        X = Xs[np.random.choice(Xs.shape[0])]
        carry_target, Y_target = processor.tick_buffer(carry_target, X)
        trainer.step(X, Y_target)

    params_estimated = trainer.params()
    carry_estimated = {'params': params_estimated, 'state': trainer.processor_state}
    print('Train time: {:.3E} s'.format(time.perf_counter() - start))
    print('Loss: {:.3E}'.format(trainer.loss))
    print('Estimated params: ', params_estimated)

    X_eval = Xs[0]
    Y_estimated, Y_target = training.evaluate(carry_estimated, carry_target, processor, X_eval)
    Y_reference = reference_fn(X_eval, carry_target['params']) if reference_fn is not None else None

    if plot_loss_history:
        plot_loss(trainer.step_evaluator.loss_history)
    if plot_params_history:
        plot_params(training.float_params(carry_target['params']), trainer.step_evaluator.params_history)
    plot_filter(X_eval, Y_target, Y_reference, Y_estimated, title)

In [None]:
# Lowpass-feedback comb filter: learn its feedback gain and damping.
evaluate_processors(
    [lbcf],
    [{"feedback": 0.5, "damp": 0.5}],
)

In [None]:
# Sine oscillator: fit its frequency using the spectral loss and a small
# explicit optimizer step size.
evaluate_processors(
    [sine_wave],
    [{"frequency_hz": 400.0}],
    loss_opts=spectral_loss_opts,
    optimization_opts={"params": {"step_size": 0.0003}},
)

In [None]:
# Full Freeverb; only a handful of training batches are run here.
evaluate_processors(
    [freeverb],
    [dict(wet=0.3, dry=0.0, width=1.0, damp=0.5, room_size=0.5)],
    num_batches=4,
)

In [None]:
# Schroeder allpass section: learn its feedback coefficient.
evaluate_processors(
    [allpass_filter],
    params_targets=[dict(feedback=0.5)],
)

In [None]:
# Delay line: learn the wet mix and delay length, trained with the spectral loss.
evaluate_processors(
    [delay_line],
    [dict(wet=0.5, delay_samples=10.0)],
    loss_opts=spectral_loss_opts,
)

In [None]:
# FIR filter trained on white noise. As ground truth, scipy's lfilter runs
# with denominator [1.0], i.e. as a pure FIR filter.
evaluate_processors(
    [fir_filter],
    [{'B': jnp.array([0.1, 0.7, 0.5, 0.6])}],
    Xs=Xs_random,
    reference_fn=lambda X, params: signal.lfilter(
        params[fir_filter.NAME]['B'], [1.0], X
    ),
)

In [None]:
# IIR filter: learn the coefficients of a 4th-order Butterworth lowpass
# (Wn=0.5, i.e. cutoff at half the Nyquist frequency) on white noise, with
# scipy's lfilter as ground truth. `signal` comes from the imports cell.
B_target, A_target = signal.butter(4, 0.5, "low")
iir_filter_target_params = {
    'B': B_target,
    'A': A_target,
}
evaluate_processors(
    [iir_filter],
    [iir_filter_target_params],
    Xs=Xs_random,
    reference_fn=lambda X, params: signal.lfilter(params[iir_filter.NAME]['B'], params[iir_filter.NAME]['A'], X),
)

In [None]:
# Hard clipper: learn the clipping bounds, with np.clip as ground truth.
# `clip_target_params` is reused by the chained experiment below.
clip_target_params = dict(min=-0.5, max=0.5)
evaluate_processors(
    [clip],
    [clip_target_params],
    reference_fn=lambda X, params: np.clip(
        X, params[clip.NAME]['min'], params[clip.NAME]['max']
    ),
)

In [None]:
# Serial chain [iir_filter, clip], reusing both targets from the cells above;
# the joint problem gets more training batches than the default 100.
evaluate_processors(
    [iir_filter, clip],
    [iir_filter_target_params, clip_target_params],
    num_batches=500,
)

## Implementing Freeverb

From [Physical Audio Signal Processing](https://ccrma.stanford.edu/~jos/pasp/Freeverb.html):

![](https://ccrma.stanford.edu/~jos/pasp/img728_2x.png)

Freeverb is composed of [lowpass feedback comb filters](https://ccrma.stanford.edu/~jos/pasp/Lowpass_Feedback_Comb_Filter.html) and [Schroeder allpass sections](https://ccrma.stanford.edu/~jos/pasp/Schroeder_Allpass_Sections.html).

To implement Freeverb differentiably, I will use plain IIR filters for each component. (Note that this is extremely inefficient compared with directly implementing each component's difference equation.) (See the issues with differentiably parameterizing the delay-line length above.)


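$H_{AP}(z) = \dfrac{Y(z)}{X(z)} = \dfrac{g^{*} + z^{-m}}{1 + g z^{-m}}$, where $m$ is the delay length in samples, $g$ is the feedback coefficient, and $g^{*}$ is its complex conjugate ($g^{*} = g$ for real $g$).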

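$$
\begin{aligned}
H_{LBCF}(z) &= \dfrac{z^{-N}}{1-f\frac{1-d}{1-dz^{-1}}z^{-N}}\\
&= \dfrac{z^{-N}}{\frac{1-dz^{-1}-f(1-d)z^{-N}}{1-dz^{-1}}}\\
&= \dfrac{z^{-N}(1-dz^{-1})}{1-dz^{-1}-f(1-d)z^{-N}}\\
&= \dfrac{-dz^{-N-1}+z^{-N}}{1-dz^{-1}-f(1-d)z^{-N}}
\end{aligned}
$$

Here $N$ is the delay length in samples, $f$ is the feedback gain, and $d$ is the damping coefficient of the one-pole lowpass $\frac{1-d}{1-dz^{-1}}$ in the feedback path. Read as an IIR filter (for $N > 1$), the numerator coefficients are $b_N = 1$ and $b_{N+1} = -d$, and the denominator coefficients are $a_0 = 1$, $a_1 = -d$ and $a_N = -f(1-d)$. All other coefficients are zero.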

**TODO:** verify that these IIR-filter implementations behave identically to the direct difference-equation implementations.