In [1]:
import platform

# `!python --version` runs whichever `python` is first on the shell's PATH, which is not
# necessarily the interpreter backing this kernel. Ask the kernel's own interpreter so the
# recorded version matches the environment that actually executes the cells below.
print(f"Python {platform.python_version()}")

Python 3.8.6


### TODO

* ~Try a built-in optimizer (e.g. gradient descent with momentum) instead of a constant learning rate, and see how convergence time changes~
* ~Plot loss~
* ~Refactor (processors are modules)~
* ~Define `init_params` methods for each processor, initializing to params that have no effect on the signal (no-op prior)~
  - Note this in the writeup: small changes to params can have a big effect in processors with feedback, and starting from parameters that have no effect is good for live-performance settings.
* ~Add a simple clipping nonlinearity effect to see how `grad` does with that~
* ~Refactoring~
  - ~Use dicts for labeled params instead of n-d arrays (jax handles this)~
  - ~Move parameter label creation into processor fns~
* ~Estimate params for multiple serially-connected filters (create a general `serial_processors` processor)~
* Animate changes in output signal over time, compared with target
* Run on GPU & measure performance differences
* Performance viability of realistic-length signals
* Apply to audio to hear perceptual differences
* More than one training pair (more than just impulse response)
* Use a perceptual loss function instead of mse (steal DDSP's multi-scale spectral loss fn)
* Improve performance
  - ~Speed things up with [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions)~
* ~I don't think functions truly need to be pure. I tried passing in a plain np array and changing it in place and it seemed to work fine. Maybe we can use shared buffers to improve memory usage?~
  - ~Answer: this works because you can create and use state _inside_ functions transformed by JAX. I found a decent middle ground that lets me use processor classes with internal state by instantiating them inside the transformed function.~
* End-goal (ready for blog post): Implement freeverb and perform dereverbing

In [2]:
import jax.numpy as jnp
from scipy import signal
import numpy as np

In [3]:
import sys
sys.path.append('./')
sys.path.append('./processors')
from fir_filter import FirFilter
from iir_filter import IirFilter
from clip import Clip
from plotters import plot_filter, plot_loss, plot_params
from train import train, train_serial

In [4]:
%%time
X = jnp.concatenate([jnp.array([1.0]), jnp.zeros(30)])
    
params_estimated, params_target, Y_estimated, Y_target, params_history, loss_history = train(FirFilter, X)



TypeError: Argument '<class 'fir_filter.FirFilter'>' of type <class 'type'> is not a valid JAX type

In [5]:
# Training curve and parameter trajectories for the FIR fit.
plot_loss(loss_history)
plot_params(params_target, params_history)
# Independent reference: the target taps run through scipy's lfilter (FIR => denominator [1.0]).
Y_reference = signal.lfilter(b=params_target['B'], a=[1.0], x=X)
plot_filter(X, Y_target, Y_reference, Y_estimated, 'FIR Filter')

NameError: name 'loss_history' is not defined

In [None]:
%%time
params_estimated, params_target, Y_estimated, Y_target, params_history, loss_history = train(IirFilter, X)

In [None]:
# Training curve and parameter trajectories for the IIR fit.
plot_loss(loss_history)
plot_params(params_target, params_history)
# Independent reference: the target coefficients (B as lfilter's b, A as its a) run through scipy.
Y_reference = signal.lfilter(b=params_target['B'], a=params_target['A'], x=X)
plot_filter(X, Y_target, Y_reference, Y_estimated, 'IIR Filter')

In [None]:
%%time
X = np.random.uniform(-2, 2, X.size)
params_estimated, params_target, Y_estimated, Y_target, params_history, loss_history = train(Clip, X)

In [None]:
# Training curve and parameter trajectories for the clipping fit.
plot_loss(loss_history)
plot_params(params_target, params_history)
# Independent reference: NumPy's clip using the target lower/upper bounds.
clip_lo, clip_hi = params_target['min'], params_target['max']
Y_reference = np.clip(X, clip_lo, clip_hi)
plot_filter(X, Y_target, Y_reference, Y_estimated, 'Clipping nonlinearity')

In [None]:
%%time
X = np.random.uniform(-2, 2, X.size)
params_estimated, params_target, Y_estimated, Y_target, params_history, loss_history = train_serial([IirFilter, Clip], X)

In [None]:
# Training curve, then one parameter plot per stage of the serial chain.
plot_loss(loss_history)
for stage_name in ('iir_filter', 'clip'):
    plot_params(params_target[stage_name], params_history[stage_name])

# No closed-form reference for the combined chain, hence None.
plot_filter(X, Y_target, None, Y_estimated, 'IIR Filter + Clip')

In [None]:
# Disabled scaffold: load a speech recording and play it back inline.
# Needs 'speech-male.wav' in the working directory; uncomment to use.
# import IPython
# from scipy.io.wavfile import read as wavread

# wav_fs, audio = wavread('speech-male.wav')
# NOTE(review): dividing by audio.max() only bounds the positive peak; if the most negative
# sample has a larger magnitude, values go below -1. np.abs(audio).max() would scale into [-1, 1].
# audio = audio.astype(float) / audio.max()
# print('Original audio:')
# IPython.display.Audio(audio, rate=wav_fs)