In [None]:
# Print the interpreter version and install/upgrade JAX into this kernel's environment.
# NOTE(review): `--upgrade` without version pins means a re-run may pull a different
# jax/jaxlib; pin versions (e.g. `%pip install jax==<ver> jaxlib==<ver>`) for reproducibility.
# jaxdsp, scipy and matplotlib are imported below but not installed here -- presumably
# they are already present in the environment.
!python --version
%pip install jax jaxlib --upgrade

In [None]:
import jax.numpy as jnp
from scipy import signal
import numpy as np
import time

from jaxdsp.training import train, evaluate
from jaxdsp.processors import fir_filter, iir_filter, clip, delay_line, lowpass_feedback_comb_filter as lbcf, allpass_filter, freeverb, sine_wave, serial_processors
from jaxdsp.plotting import plot_filter, plot_loss, plot_params, plot_optimization
from jaxdsp.training import train, evaluate
import jaxdsp.loss

In [None]:
def array_with_one_at(i, size):
    """Build a unit impulse of length ``size``.

    Returns a float64 vector that is zero everywhere except for a 1.0 at
    index ``i``. ``i`` follows normal NumPy indexing rules: a negative index
    counts from the end, and an out-of-range index raises ``IndexError``.
    """
    impulse = np.zeros(size)
    impulse[i] = 1.0
    return impulse

In [None]:
# Training hyperparameters (read as globals by evaluate_processors).
step_size = 0.1
n_batches = 1000
batch_size = 2

# Dataset shape: n_train examples of n_samples samples each (1 s at 44.1 kHz).
n_train = 100
samples_per_second = 44_100
n_samples = samples_per_second  # e.g. 4 * samples_per_second for 4 s per example

# Seed NumPy's global RNG so Xs_random and later np.random draws
# (e.g. np.random.randint in the iterative-training cell) are reproducible.
random_seed = 0
np.random.seed(random_seed)

# White-noise inputs.
Xs_random = np.random.randn(n_train, n_samples)
# Unit impulses: row i has a single 1.0 at sample i
# (equivalent to stacking array_with_one_at(i, n_samples) for i in range(n_train)).
Xs_unit = np.eye(n_train, n_samples)
# One long chirp (300 Hz -> 1 Hz over t = 0..10) cut into n_train consecutive segments.
# NOTE(review): the chirp's time axis spans 10 "seconds" across all n_train * n_samples
# points, i.e. an implied rate of n_train * n_samples / 10 Hz rather than
# samples_per_second -- confirm this is intended.
Xs_chirp = signal.chirp(
    np.linspace(0, 10, n_train * n_samples), f0=300, f1=1, t1=10
).reshape(n_train, n_samples)

In [None]:
def evaluate_processors(processors, loss_fn=jaxdsp.loss.mse, Xs=None, reference_fn=None,
                        plot_loss_history=True, plot_params_history=True):
    """Train `processors` toward their target params, then report and plot the result.

    Runs jaxdsp `train` on `Xs` using the global `step_size`, `n_batches` and
    `batch_size`. It prints the training time, the final loss and the estimated
    params. It then runs the estimated and target chains (via `serial_processors`)
    on `Xs[0]` and plots them.

    Args:
        processors: list of jaxdsp processors, chained via `serial_processors`.
        loss_fn: loss passed to `train` (default `jaxdsp.loss.mse`).
        Xs: training inputs, one example per row; `Xs[0]` is also used for the
            evaluation plot. Defaults to the global `Xs_chirp`, looked up at call time.
        reference_fn: optional `(X, params_target) -> Y`. It computes an independent
            reference output (e.g. a scipy implementation) that is passed to
            `plot_filter` alongside the target and estimated outputs.
        plot_loss_history: whether to plot the loss curve.
        plot_params_history: whether to plot the parameter trajectories.
    """
    if Xs is None:
        # Resolve at call time rather than def time, so re-running the data cell
        # is picked up without re-running this definition cell.
        Xs = Xs_chirp

    start = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
    params_estimated, params_target, params_history, loss_history = train(
        processors, Xs, loss_fn=loss_fn, step_size=step_size,
        num_batches=n_batches, batch_size=batch_size)
    print('Train time: {:.3E} s'.format(time.perf_counter() - start))
    print('Loss: {:.3E}'.format(loss_history[-1]))

    X_eval = Xs[0]
    carry_estimated = {'params': params_estimated, 'state': serial_processors.init_state(processors)}
    carry_target = {'params': params_target, 'state': serial_processors.init_state(processors)}
    Y_estimated, Y_target = evaluate(carry_estimated, carry_target, serial_processors, X_eval)
    Y_reference = reference_fn(X_eval, params_target) if reference_fn is not None else None

    print(params_estimated)
    if plot_loss_history:
        plot_loss(loss_history)
    if plot_params_history:
        plot_params(params_target, params_history)
    title = ' + '.join(processor.NAME for processor in processors)
    plot_filter(X_eval, Y_target, Y_reference, Y_estimated, title)

In [None]:
# TODO: training sine_wave with the spectral loss does not work yet -- find out why.
evaluate_processors([sine_wave], loss_fn=jaxdsp.loss.spectral)

In [None]:
# Disabled: the freeverb training run is commented out and the reason is not recorded here
# (presumably it is slow or not working yet -- see the "Implementing Freeverb" notes). TODO confirm.
# evaluate_processors([freeverb])

In [None]:
# Lowpass feedback comb filter (imported as `lbcf`) with the default loss and training data.
evaluate_processors(processors=[lbcf])

In [None]:
# Allpass filter with the default loss and training data.
evaluate_processors(processors=[allpass_filter])

In [None]:
# Delay line with the default loss and training data.
evaluate_processors(processors=[delay_line])

In [None]:
def fir_reference(X, params):
    """scipy reference for the FIR filter: lfilter with numerator B and denominator [1.0]."""
    return signal.lfilter(params[fir_filter.NAME]['B'], [1.0], X)


evaluate_processors([fir_filter], reference_fn=fir_reference)

In [None]:
def iir_reference(X, params):
    """scipy reference for the IIR filter: lfilter with numerator B and denominator A."""
    iir_params = params[iir_filter.NAME]
    return signal.lfilter(iir_params['B'], iir_params['A'], X)


evaluate_processors([iir_filter], reference_fn=iir_reference, plot_params_history=False)

In [None]:
def clip_reference(X, params):
    """NumPy reference for the clip processor: clamp X to [min, max]."""
    clip_params = params[clip.NAME]
    return np.clip(X, clip_params['min'], clip_params['max'])


evaluate_processors([clip], reference_fn=clip_reference)

In [None]:
# Two-processor chain (IIR filter, then clip); no reference output is plotted for the combination.
evaluate_processors(processors=[iir_filter, clip])

## Non-batch training example

In [None]:
from jaxdsp.training import IterativeTrainer

# Train one example at a time instead of in batches. Each step draws a random
# row of Xs, runs it through the target processor to get Y_target, and hands
# the (X, Y_target) pair to the trainer.
n_steps = 200  # number of single-example training steps
processor = lbcf
Xs = Xs_chirp
carry_target = {'params': processor.default_target_params(), 'state': processor.init_state()}
trainer = IterativeTrainer(processor)
# Use a named loop variable: assigning `_` at notebook top level shadows
# IPython's last-output variable and stops it from updating.
for step_index in range(n_steps):
    X = Xs[np.random.randint(Xs.shape[0])]
    # The target carry (params + state) is threaded through successive buffers.
    carry_target, Y_target = processor.tick_buffer(carry_target, X)
    trainer.step(X, Y_target)

trainer.params_and_loss()

In [None]:
# Start sine_wave from 50 initial frequencies between 100 and 140 Hz (phase 0) and
# compare against a 120 Hz target on the chirp data.
# NOTE(review): plot_optimization presumably plots how optimization of 'frequency_hz'
# proceeds from each start; the meaning of the final argument (10) is not visible here -- TODO document.
param_inits = [dict(frequency_hz=freq, phase=0.0) for freq in np.linspace(100, 140, 50)]
params_target = dict(frequency_hz=120.0, phase=0.0)
plot_optimization(sine_wave, Xs_chirp, param_inits, params_target, 'frequency_hz', 10)

In [None]:
# TODO: verify Lagrange interpolation converges over a 4-sample range

## Implementing Freeverb

From [Physical Audio Signal Processing](https://ccrma.stanford.edu/~jos/pasp/Freeverb.html):

![](https://ccrma.stanford.edu/~jos/pasp/img728_2x.png)

It is composed of [lowpass feedback comb filters](https://ccrma.stanford.edu/~jos/pasp/Lowpass_Feedback_Comb_Filter.html) and [Schroeder allpass sections](https://ccrma.stanford.edu/~jos/pasp/Schroeder_Allpass_Sections.html).

In order to implement Freeverb in a differentiable way, I will use plain IIR filters for each component. (Note that this is extremely inefficient compared with implementing each component's difference equation directly.) (See the issues with differentiably parameterizing the delay-line length, noted above.)


$$H_{AP}(z) = \dfrac{Y(z)}{X(z)} = \dfrac{g^* + z^{-m}}{1 + g z^{-m}}$$

$$\begin{aligned}
H_{LBCF}(z) &= \dfrac{z^{-N}}{1-f\frac{1-d}{1-dz^{-1}}z^{-N}}\\
&= \dfrac{z^{-N}}{\frac{1-dz^{-1}-f(1-d)z^{-N}}{1-dz^{-1}}}\\
&= \dfrac{z^{-N}(1-dz^{-1})}{1-dz^{-1}-f(1-d)z^{-N}}\\
&= \dfrac{-dz^{-N-1}+z^{-N}}{1-dz^{-1}-f(1-d)z^{-N}}
\end{aligned}$$

TODO: verify that this IIR-filter implementation behaves identically to the direct difference-equation implementation.

In [None]:
# Parabolic sine approximation with an extra precision pass, after
# http://devmaster.net/forums/topic/4648-fast-and-accurate-sinecosine/
def sine_approx(x):
    """Fast polynomial approximation of sin(x), intended for x in [-pi, pi].

    A parabola through the zeros at -pi, 0, pi and the peaks at +-pi/2 gives a
    first estimate y. That estimate is refined to y + 0.225 * (y*|y| - y).
    Outside [-pi, pi] the result does not track sin(x), so wrap the input first.
    Works element-wise on scalars and NumPy arrays.
    """
    coarse = 4 / np.pi * x - 4 / (np.pi ** 2) * x * np.abs(x)
    correction = 0.225 * (np.abs(coarse) - 1) + 1
    return coarse * correction

In [None]:
from matplotlib import pyplot as plt

# Compare sine_approx against np.sin on 100 samples of an input that sweeps 0..20
# and is mapped into [-pi, pi) by `% (2*pi) - pi`, the range sine_approx is meant for.
x = np.linspace(0, 20, 100) % (2 * np.pi) - np.pi
fig, ax = plt.subplots()
ax.plot(sine_approx(x), label='sine_approx(x)')
ax.plot(np.sin(x), label='np.sin(x)')
ax.set(title='sine_approx vs. np.sin', xlabel='sample index', ylabel='value')
ax.legend();

In [None]:
# Magnitude spectrum of 50 samples of sin over [-pi, pi], normalized by the sample count.
# NOTE(review): endpoint=True samples both -pi and pi, so the 50 samples cover slightly more
# than one period and energy leaks out of bin 1; use endpoint=False if a clean single-bin
# spectrum is intended.
plt.plot(
    np.abs(jnp.fft.rfft(jnp.sin(jnp.linspace(-np.pi, np.pi, 50, endpoint=True))))
    / 50  # normalize by the number of samples
)