<a href="https://colab.research.google.com/github/cyrusvahidi/timbre2023/blob/main/Timbre_2023_Neurophysiological_Simulation_for_Digital_Audio_Effects.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Timbre 2023: Neurophysiological Simulation for Digital Audio Effects


Welcome to Neurophysiological Simulation for Digital Audio Effects at Timbre 2023

This is a tutorial about scattering transforms in Kymatio.


## **What is Kymatio?**

[Kymatio](https://kymat.io) [(Andreux et al. 2020)](https://scholar.google.com/scholar_url?url=https://www.jmlr.org/papers/volume21/19-047/19-047.pdf&hl=en&sa=T&oi=gsb-ggp&ct=res&cd=0&d=13496502122111863300&ei=gZamZKPcB9KsmgHR5qyoDg&scisig=ABFrs3xtOuMXLkkilckRcYmjWbZA) is an open-source Python package for applications at the intersection of deep learning and wavelet scattering.

Its forthcoming stable release (v0.4) provides an implementation of the joint time–frequency scattering transform (JTFS).


## **Why scattering at Timbre 2023?**

JTFS is an idealisation of a neurophysiological model that is commonly known in timbre perception research: the spectrotemporal receptive field (STRF) [Patil et al., 2012](https://scholar.google.com/scholar_url?url=https://journals.plos.org/ploscompbiol/article%3Fid%3D10.1371/journal.pcbi.1002759&hl=en&sa=T&oi=gsb-ggp&ct=res&cd=0&d=13643030775703850815&ei=4JamZO6rEd2Sy9YP2fyhwAg&scisig=ABFrs3xGck11FtfqtyU_-VGjR1TE).

The STRF simulates processing in the mammalian primary auditory cortex, based on empirical measurements of how neurons respond to moving ripple spectra.

JTFS offers benefits over the STRF in five respects: (i) differentiability, (ii) efficiency, (iii) numerical accuracy, (iv) GPU-compatibility and (v) portability across scientific computing and machine learning frameworks in Python.

Of relevance to timbre research, we previously demonstrated that JTFS accurately represents similarity between spectrotemporal modulations and serves as a state-of-the-art feature extractor for musical instrument classification (Muradeli et al., 2022). Spectrotemporal modulations in the STRF have been shown to predict timbre dissimilarity judgements (Patil et al., 2012; Thoret et al., 2021).

Likewise, Euclidean distances in JTFS can predict human judgments of similarity between musical instrument playing techniques (Lostanlen et al., 2021).

Recently, JTFS has been used as a loss function to control sound synthesis parameters that generate spectrotemporal modulations (Vahidi et al., 2023).


## **Scattering tools in this notebook**:

In this notebook, we will gain intuition on:

- Understanding Kymatio's frontend parameters and how the scattering transform works
- How the scattering transform represents the physical properties of modulated signals
- The scattering transform's time and frequency transposition invariance
- Acoustic modelling of sound synthesis parameters

We will tour Kymatio for audio signals, via:

- Filterbank construction and visualization
- `Scattering1D` and `TimeFrequencyScattering` frontends
- Visualizing first-order ($ S_1 x [\lambda, t]$) and second-order ($S_2 x [\lambda, \lambda_2, t]$) scattering coefficients
- Similarity retrieval of spectrotemporal modulation signals
- Differentiable modelling of spectrotemporal modulations
- Comparison to short-time Fourier representation of music signals

## **Authors**

Cyrus Vahidi (https://github.com/cyrusvahidi/, https://twitter.com/cyrusasfa)

Vincent Lostanlen (https://github.com/lostanlen/, https://twitter.com/lostanlen)

Kymatio (https://github.com/kymatio/kymatio)


## **References**

- [Andreux, M. et al. (2020). **Kymatio: Scattering transforms in Python.** The Journal of Machine Learning Research, 21(1), 2256-2261.](https://scholar.google.com/scholar_url?url=https://www.jmlr.org/papers/volume21/19-047/19-047.pdf&hl=en&sa=T&oi=gsb-ggp&ct=res&cd=0&d=13496502122111863300&ei=gZamZKPcB9KsmgHR5qyoDg&scisig=ABFrs3xtOuMXLkkilckRcYmjWbZA)
- [Lostanlen, V., El-Hajj, C., Rossignol, M., Lafay, G., Andén, J., & Lagrange, M. (2021). **Time–frequency scattering accurately models auditory similarities between instrumental playing techniques.** EURASIP Journal on Audio, Speech, and Music Processing, 2021(1), 1-21.](https://scholar.google.com/scholar_url?url=https://asmp-eurasipjournals.springeropen.com/articles/10.1186/s13636-020-00187-z&hl=en&sa=T&oi=gsb-ggp&ct=res&cd=0&d=14039807909065615562&ei=vpamZJDaNcr2mgH68LmoAg&scisig=ABFrs3yvkUpho3swu_2K3TZKa7LO)
- [Muradeli, J., Vahidi, C., Wang, C., Han, H., Lostanlen, V., Lagrange, M., & Fazekas, G. (2022, September). **Differentiable Time-Frequency Scattering On GPU. In Digital Audio Effects Conference (DAFx).**](https://scholar.google.com/scholar_url?url=https://arxiv.org/abs/2204.08269&hl=en&sa=T&oi=gsb&ct=res&cd=1&d=13536706587774301644&ei=ypamZL3ZHPqDy9YP-7iN-Ak&scisig=ABFrs3zg-sPJPLN_fVnXOZax2p31)
- [Patil, K., Pressnitzer, D., Shamma, S., & Elhilali, M. (2012). **Music in our ears: the biological bases of musical timbre perception.** PLoS computational biology, 8(11), e1002759.](https://scholar.google.com/scholar_url?url=https://journals.plos.org/ploscompbiol/article%3Fid%3D10.1371/journal.pcbi.1002759&hl=en&sa=T&oi=gsb-ggp&ct=res&cd=0&d=13643030775703850815&ei=4JamZO6rEd2Sy9YP2fyhwAg&scisig=ABFrs3xGck11FtfqtyU_-VGjR1TE)
- [Thoret, E., Caramiaux, B., Depalle, P., & Mcadams, S. (2021). **Learning metrics on spectrotemporal modulations reveals the perception of musical instrument timbre.** Nature Human Behaviour, 5(3), 369-377.](https://scholar.google.com/scholar_url?url=https://www.nature.com/articles/s41562-020-00987-5&hl=en&sa=T&oi=gsb&ct=res&cd=0&d=2045098833587514380&ei=7ZamZLeBIcSlmAG0wbfwDg&scisig=ABFrs3yyaKxrp3DiNPUS8C7mf3jj)
- [Vahidi, C., Han, H., Wang, C., Lagrange, M., Fazekas, G., & Lostanlen, V. (2023). **Mesostructures: Beyond Spectrogram Loss in Differentiable Time-Frequency Analysis.** arXiv preprint arXiv:2301.10183.](https://scholar.google.com/scholar_url?url=https://arxiv.org/abs/2301.10183&hl=en&sa=T&oi=gsb&ct=res&cd=0&d=9636633872609712841&ei=-5amZJy0MYekmwHAipuwDA&scisig=ABFrs3x0p-dWnaZd4WVxBEr5jB3j)

## **Setup**
- Run the `installation` cells below
- Ensure that imports load and GPU runtime is enabled


In [None]:
#@title Install Dependencies
#@markdown Let's install Kymatio from source and some utilities
!pip install git+https://github.com/kymatio/kymatio.git
!pip install git+https://github.com/cyrusvahidi/meso-dtfa.git#egg=meso_dtfa

In [None]:
#@title Import Python libraries

#@markdown We will use `Scattering1D` and `TimeFrequencyScattering` PyTorch frontends from Kymatio,
#@markdown numpy and some custom signal generators and distances.
import numpy as np
import torch
import librosa
import tqdm
import scipy
import random

import matplotlib.pyplot as plt
from IPython.display import Audio

from kymatio.torch import Scattering1D, TimeFrequencyScattering
from kymatio.scattering1d.filter_bank import scattering_filter_factory

from meso_dtfa.core import generate_am_chirp, grid2d
from meso_dtfa.loss import MultiScaleSpectralLoss, TimeFrequencyScatteringLoss
from meso_dtfa.plot import plot_cqt, plot_contour_gradient, mesh_plot_3d

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

In [None]:
#@title Check GPU Status
import subprocess

USE_GPU = True #@param {type:"boolean"}

def to_device(obj):
    """ Move a torch tensor or module (e.g. a scattering frontend) to the selected device
    """
    return obj.cuda() if USE_GPU else obj.cpu()

#!nvidia-smi
nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(nvidiasmi_output)
print(f"Using device: {'GPU' if USE_GPU else 'CPU'}")

## `Scattering1D` (Time Scattering)

- $ U_1 x [\lambda, t] = |x * \psi_{\lambda}|(t) $
  - constant-Q wavelet scalogram
  - convolves the signal $x$ with a constant-Q wavelet filterbank $\psi$
  - Each filter is indexed by the geometrically spaced log-frequency variable $\lambda$
- $ S_1 x [\lambda, t] = (U_1 x * \phi_T)(\lambda, t) $
  - The **first-order scattering transform** is the result of locally averaging $U_1$ with a lowpass filter $\phi_T$
  - It is a spectral descriptor that offers time-shift invariance up to a support of $T$
- $U_2 x [\lambda, \lambda_2, t] = |U_1 x * \psi_{\lambda_2}|(\lambda, t) $
  - High-frequency oscillations in the scalogram are lost in $S_1$ due to averaging
  - We recover them by convolving $U_1$ with another wavelet filterbank along time at every frequency bin $\lambda$
- $ S_2 x [\lambda, \lambda_2, t] = (U_2 x * \phi_T)(\lambda, \lambda_2, t) $
  - **Second-order scattering coefficients**
  - To get a locally time-invariant descriptor of temporal modulations, we average $U_2$ with a lowpass filter $\phi_T$


* $t$: time
* $\lambda$: first-order log-frequency variable
* $\lambda_2$: second-order temporal modulation rate variable
* $\psi_{\lambda}$: first-order wavelet filterbank
* $\psi_{\lambda_2}$: second-order wavelet filterbank
* $\phi_T$: Gaussian lowpass filter of scale $T$
* $U_1$: wavelet scalogram
* $S_1$: first-order scattering transform
* $S_2$: second-order scattering transform

![](https://freight.cargo.site/t/original/i/e44c49dbc48f6fedde0f1e11dfddbe659e0197d518c143db2c23c0843449369a/scattering-diagram.png)
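
The cascade above can be sketched in plain NumPy on a toy amplitude-modulated tone. This is only an illustration under stated assumptions: the filters are single Gaussian bumps in the Fourier domain (stand-ins for Kymatio's Morlet wavelets), and the helper names `gaussian_bandpass`, `gaussian_lowpass` and `conv` are hypothetical, not part of Kymatio.

```python
import numpy as np

def gaussian_bandpass(n, xi, sigma):
    """Gaussian bump centred at normalized frequency xi (cycles/sample), in the Fourier domain."""
    freqs = np.fft.fftfreq(n)
    return np.exp(-0.5 * ((freqs - xi) / sigma) ** 2)

def gaussian_lowpass(n, sigma):
    """Gaussian lowpass of bandwidth sigma, in the Fourier domain."""
    freqs = np.fft.fftfreq(n)
    return np.exp(-0.5 * (freqs / sigma) ** 2)

def conv(x, filt_f):
    """Circular convolution of x with a filter given by its Fourier response."""
    return np.fft.ifft(np.fft.fft(x) * filt_f)

n = 2 ** 12
t = np.arange(n)
carrier, mod = 0.25, 1 / 64  # carrier and modulation rates, in cycles per sample
x = (1 + np.cos(2 * np.pi * mod * t)) * np.cos(2 * np.pi * carrier * t)

psi1 = gaussian_bandpass(n, xi=carrier, sigma=0.02)  # first-order filter at the carrier
psi2 = gaussian_bandpass(n, xi=mod, sigma=mod / 4)   # second-order filter at the modulation rate
phi = gaussian_lowpass(n, sigma=1 / 1024)            # lowpass, much slower than the modulation

u1 = np.abs(conv(x, psi1))   # U1: envelope in the carrier band, still oscillating at `mod`
s1 = np.real(conv(u1, phi))  # S1: averaging removes that oscillation
u2 = np.abs(conv(u1, psi2))  # U2: magnitude of the envelope oscillation at rate `mod`
s2 = np.real(conv(u2, phi))  # S2: locally averaged modulation energy

print(f"std of U1: {u1.std():.3f}, std of S1: {s1.std():.1e}, mean of S2: {s2.mean():.3f}")
```

In this toy setting, $S_1$ is flat over time because the lowpass filter is narrower than the modulation rate, while $S_2$ stays well above zero: the tremolo that averaging removed from the first order reappears in the second order.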

In [None]:
#@title Construct a `Scattering1D` frontend
#@markdown Let's use Kymatio's `torch.Scattering1D` frontend to construct Time Scattering filterbanks


#@markdown $J$ defines the maximal wavelet scale and the number of octaves. The widest filter, which has the lowest centre frequency, has a temporal support of $2^J$ samples.
#@markdown At larger $J$, we can cover longer time scales and lower frequencies
J = 8 #@param {type:"slider", min:1, max:13, step:1}
#@markdown $Q$ determines how well we can localize a signal in frequency

#@markdown number of filters per octave (first-order):
Q1 = 8 #@param {type:"slider", min:1, max:16, step:1}
#@markdown number of filters per octave (second-order):
Q2 = 1 #@param {type:"slider", min:1, max:16, step:1}

#@markdown $T$ controls the amount of imposed invariance to time-shifts. This defines the scale and stride of the lowpass filters, which is set to $2^J$ by default.
T = 256 #@param {type:"slider", min:1, max:8192, step:1}
Q = (Q1, Q2)

#@markdown We must specify the signal length $N$, which defines the temporal support of the filters
log2_n = 13 #@param {type:"slider", min:4, max:16, step:1}
N = 2**log2_n
x = torch.randn((N, ))

scat1d = to_device(Scattering1D(shape=x.shape, J=J, Q=Q, T=T))

phi_f = scat1d.phi_f
psi1_f = scat1d.psi1_f
psi2_f = scat1d.psi2_f
N_padded = scat1d._N_padded

In [None]:
#@title Plotting the wavelet filterbanks
#@markdown Let's plot the frequency response of first and second-order filterbanks `psi1_f` and `psi2_f`

#@markdown Try changing $J$, $Q_1$ and $Q_2$. What do you observe?

def plot_filterbank(phi, psi, N, order=1):
    """Plot the frequency responses of the lowpass filter (red) and the wavelets (blue)."""
    omega = np.arange(N) / N  # normalized frequency axis; N is the padded signal length

    plt.figure(figsize=(10, 5))
    plt.plot(omega, phi['levels'][0], 'r')

    for psi_f in psi:
        plt.plot(omega, psi_f['levels'][0], 'b')

    plt.xlabel(r'$\omega$', fontsize=18)
    plt.ylabel(r'$\hat\psi_j(\omega)$', fontsize=18)
    plt.title(f'order {order} filters', fontsize=18)

plot_filterbank(phi_f, psi1_f, N_padded, order=1)
plot_filterbank(phi_f, psi2_f, N_padded, order=2)

## Filterbank details


We create the filters by constructing an instance of the `kymatio.torch.Scattering1D` frontend.

- `phi_f` contains lowpass filters at various resolutions. For example, `phi_f['levels'][0]` is at resolution `T`.
- `psi1_f` contains the first-order Morlet wavelet filters
- `psi2_f` contains the second-order Morlet wavelet filters
- `psi1_f` and `psi2_f` differ in choice of `Q`


Each filter is structured as a dictionary containing:
- the filter coefficients (in Fourier)
- `xi`: the filter's normalized centre frequency
- `sigma`: the filter's width
- `j`: the filter's scale
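
To make the dictionary layout concrete, here is a toy constant-Q filterbank in NumPy. It is a sketch under assumptions, not Kymatio's construction: the function `toy_filterbank`, the top frequency `xi_max=0.4` and the bandwidth ratio `rel_bw=0.1` are all illustrative, the filters are plain Gaussians, and the `j` field is omitted.

```python
import numpy as np

def toy_filterbank(n, J, Q, xi_max=0.4, rel_bw=0.1):
    """Build J * Q Gaussian bandpass filters with geometrically spaced centre frequencies."""
    freqs = np.fft.fftfreq(n)
    bank = []
    for k in range(J * Q):
        xi = xi_max * 2.0 ** (-k / Q)  # Q filters per octave
        sigma = rel_bw * xi            # bandwidth proportional to xi: constant Q
        response = np.exp(-0.5 * ((freqs - xi) / sigma) ** 2)
        bank.append({"xi": xi, "sigma": sigma, "levels": [response]})
    return bank

bank = toy_filterbank(n=2 ** 13, J=8, Q=8)
xis = np.array([f["xi"] for f in bank])

print(f"number of filters: {len(bank)}")
print(f"ratio between filters Q apart: {xis[0] / xis[8]:.2f}")
print(f"centre frequency of filter 0 at sr=4096 Hz: {xis[0] * 4096:.1f} Hz")
```

Multiplying the normalized `xi` by the sampling rate converts it to Hz, which is the same conversion applied to the Kymatio filters in the next cell.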

In [None]:
#@title Let's take a look at a filter's metadata

print(f"Temporal support of the lowpass filter at each resolution: {[len(phi) for phi in phi_f['levels']]}")

print(f"Coefficients of the first filter: {psi1_f[0]['levels'][0]}")

print(f"First filter scale, j: {psi1_f[0]['j']}")

print(f"A first order filter's central frequency $xi$ in Hz, assuming a sampling rate of 4096 Hz:  {psi1_f[0]['xi'] * 4096}")

print(f"A first order filter's characteristic width, sigma: {psi1_f[0]['sigma']}")

In [None]:
#@title Plotting the first-order coefficients, $S_1$
#@markdown This is like a time-averaged CQT
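import librosa
import librosa.display

def plot_scalogram(Sx, S, duration, sr=22050):
    """
    Args:
      Sx: first-order scattering coefficients, of shape (paths, time)
      S: scattering transform frontend
      duration: signal duration in seconds (currently unused)
      sr: sampling rate in Hz
    """
    # time axis in seconds: one frame every S.T samples at sampling rate sr
    x_coords = librosa.times_like(Sx, sr=sr, hop_length=S.T)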
    y_coords = [psi["xi"]*sr for psi in S.psi1_f] # convert xi to Hz

    plt.figure(figsize=(8, 4))
    librosa.display.specshow(Sx[1:,:].numpy(), sr=sr,
        x_coords=x_coords, x_axis="time",
        y_coords=y_coords, y_axis="cqt_hz",
        cmap="magma")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Frequency (Hz)")
    plt.ylim(0, sr // 2)
    plt.minorticks_off()

N = 2**15
SR = 2**15
kwargs = {'shape': N, 'Q': Q, 'J': J, 'T': T}

#@markdown we set `max_order=1` to only compute first-order coefficients
S = to_device(Scattering1D(**kwargs, max_order=1))

#@markdown Let's plot the $S_1$ of an impulse
x = to_device(torch.zeros(N))
x[N // 2] = 1
Sx = S(x)
plot_scalogram(Sx.cpu(), S, duration=1, sr=SR)

In [None]:
#@title Plotting $S_1$ of an AM/FM signal

#@markdown Next, an amplitude-modulated chirp signal, or chirplet arpeggiator

#@markdown $    \boldsymbol{g_\theta}: t\longmapsto \boldsymbol{\phi}_{w}(\gamma t) \sin(2\pi f_\mathrm{m} t) \sin\left( \dfrac{2\pi f_{\mathrm{c}}}{\gamma \log 2} 2^{\gamma t}\right)$

#@markdown amplitude modulation (AM) is parameterised by modulator frequency $f_m$ and frequency modulation (FM) by chirp rate $\gamma$

#@markdown The chirp's instantaneous frequency increases exponentially with time, while its amplitude is modulated by a sinusoid.

#@markdown Perceived pitch therefore grows linearly, and we see a linear change along our log-frequency axis

#@markdown see `help(generate_am_chirp)` for details on the synth
fc, fm, gamma = torch.tensor(512.0), torch.tensor(4.0), torch.tensor(1.0)
x = generate_am_chirp([fc, fm, gamma], sr=SR, duration = 1)
Sx = S(to_device(x))
plot_scalogram(Sx.cpu(), S, duration=1, sr=SR)

#@markdown What happens to the amplitude modulations in $S_1$ if you increase T or $f_m$?

In [None]:
help(generate_am_chirp)

## Getting first and second order coefficients
* `Scattering1D` returns an array of shape `(num_paths, timesteps)` that concatenates the zeroth-, first- and second-order coefficients along the path axis
* To display the scattering coefficients, we must identify the indices for each order.
* A `Scattering1D` object contains metadata that includes the indices for each scattering order

In [None]:
Fs = 2**14 # sampling rate
duration = 2 # signal duration
J = 12 # maximum wavelet scattering scale
Q = 8

scat1d = Scattering1D(J=J, Q=Q, shape=duration * Fs, T=2**J)

meta = scat1d.meta()
order0 = np.where(meta['order'] == 0)[0] # zeroth-order indices
order1 = np.where(meta['order'] == 1)[0] # first-order indices
order2 = np.where(meta['order'] == 2)[0] # second-order indices

## Visualizing $U_1$, $S_1$ and $S_2$ of an AM signal

In [None]:
#@markdown Let's create a harmonic tremolo signal.

#@markdown To help with visualisation, we set its fundamental frequency to the centre frequency of a first-order wavelet filter and the modulation frequency to the centre frequency of a second-order wavelet filter.

f0_xi_idx = 30 # first-order filter index for the fundamental frequency
fm_xi_idx = 8 # second-order filter index for the modulation frequency

f0 = meta['xi'][order1[f0_xi_idx], 0] * Fs # get the fundamental frequency
fm = meta['xi'][order2[fm_xi_idx], 1] * Fs # get the modulation frequency
print(f'fundamental frequency: {f0} Hz')
print(f'modulation frequency: {fm} Hz')

#@markdown We synthesize a harmonic signal with 5 harmonics that are amplitude-modulated by a sinusoid
num_harmonics = 5
harmonics = torch.zeros(num_harmonics, duration * Fs)
t = torch.arange(0, duration, float(1/Fs))
modulator = torch.sin(2.0 * np.pi * fm * t)
harmonics[0] = torch.sin(2.0 * np.pi * f0 * t) * modulator

for i in range(1, num_harmonics):
    # partial i sits at (i + 1) * f0, so that the 5 partials are distinct harmonics
    harmonics[i] = torch.cos(2.0 * np.pi * (f0 * (i + 1)) * t) * modulator

x_am = torch.sum(harmonics, dim=0) # sum the harmonics
x_am *= torch.hann_window(x_am.shape[0]) # fade in and out; both operands stay torch tensors
x_am /= torch.max(torch.abs(x_am)) # normalize the amplitude

Audio(x_am, rate=Fs)

In [None]:
#@title $U_1$ vs $S_1$ vs $S_2$
def compute_scattering(x, scat1d_u, scat1d, lambda1_idx=None):
    """Compute U1, S1, and S2 for the given signal.

    Parameters:
        x (Tensor): Input signal.
        scat1d_u (Scattering1D): Scattering1D instance without averaging (T=0, out_type="list").
        scat1d (Scattering1D): Scattering1D instance.
        lambda1_idx (int or None): Target lambda1 index (first-order parent) for S2 visualization.

    Returns:
        tuple: A tuple containing the following elements:
            u1 (Tensor): Unaveraged scalogram of x.
            s1 (ndarray): First-order scattering transform of x.
            s2 (ndarray): Second-order scattering transform of x.
    """
    # compute the unaveraged scattering transform Ux
    Ux = scat1d_u(x)
    # get the first-order coefficients
    Ux = [u['coef'] for u in Ux if u['order'] == 1]
    max_resolution = max([u.shape[-1] for u in Ux]) # largest length of unaveraged transform
    # resample the first order coefficient to same temporal shape
    Ux = [scipy.signal.resample(u.numpy(), max_resolution, axis=-1) for u in Ux]
    Ux = torch.tensor(np.stack(Ux))

    meta = scat1d.meta()
    order1 = np.where(meta['order'] == 1)[0]

    # Select the indices of the second-order coefficients whose first-order
    # parent is lambda1_idx (all second-order coefficients if it is None)
    if lambda1_idx is not None:
        order2 = [
            i for i, key in enumerate(meta['key'])
            if meta['order'][i] == 2 and key[0] == lambda1_idx
        ]
    else:
        order2 = np.where(meta['order'] == 2)[0]

    '''Compute the scattering transform'''
    Sx = scat1d(x)
    S1 = Sx[order1]
    S2 = Sx[order2]

    return Ux, S1, S2

scat1d = Scattering1D(J=10, Q=12, shape=duration * Fs) # scattering transform
scat1d_u = Scattering1D(J=10, Q=12, shape=duration * Fs, T=0, out_type="list") # scattering transform without averaged coefficients

U1, S1, S2 = compute_scattering(x_am, scat1d_u, scat1d, lambda1_idx=f0_xi_idx)

fig, ax = plt.subplots(3, 1, figsize=(8, 15))
for i, (title, sx) in enumerate([("$U_1$", U1), ("$S_1$", S1), ("$S_2$", S2)]):
  ax[i].set_xticks([])
  ax[i].set_yticks([])
  ax[i].set_title(title)
  # plt.ylabel("$\lambda$ ($\log$-frequency)")
  # plt.xlabel("time")
  ax[i].imshow(sx, aspect="auto")
fig.tight_layout()

#@markdown Observe the difference in temporal resolution between $U_1$ and $S_1$/$S_2$
#@markdown Why has the amplitude modulation disappeared in $S_1$?
#@markdown We only plot $S_2$ around a particular first-order parent, i.e. the $\lambda_1$ that is centred around the $f_0$ of the signal
#@markdown Since our modulations are "slow", we see a response in the lowest octave of $S_2$

# `TimeFrequencyScattering` (JTFS)

Let's take a look at `TimeFrequencyScattering`. Unlike `Scattering1D`, JTFS analyses the time-frequency domain ($U_1$) with a cascade of two filterbanks defined along time and frequency: $\psi_{\alpha}^{(t)}$ and $\psi_{\beta}^{(f)}$

This results in a two-dimensional filterbank that analyses joint spectrotemporal modulation patterns.

- $ U_1 x [\lambda, t] = |x * \psi_{\lambda}|(t) $ constant-Q wavelet scalogram
- $ S_1 x [\lambda, \alpha = 0, \beta, t] = |U_1 x * \phi_T * \psi_{\beta}^{(f)}|(\lambda, t) $ first-order time-frequency scattering transform
- $ U_2 x [\lambda, \alpha, \beta, t] = |U_1 x * \psi_{\alpha}^{(t)} * \psi_{\beta}^{(f)}|(\lambda, t) $
- $ S_2 x [\lambda, \alpha, \beta, t] = (U_2 x * \phi_T * \phi_F)(\lambda, \alpha, \beta, t) $ second-order time-frequency scattering transform


* $t$: time
* $\lambda$: first-order log-frequency
* $\alpha$: second-order temporal modulation rate
* $\beta$: frequential modulation scale - positive and negative
* $\psi_{\lambda}^{(t)}$: first-order wavelet filterbank
* $\psi_{\alpha}^{(t)}$: second-order temporal wavelet filterbank
* $\psi_{\beta}^{(f)}$: frequential wavelet filterbank
* $\phi_T$: Gaussian lowpass filter of scale $T$
* $\phi_F$: Gaussian frequential lowpass filter of scale $F$
* $U_1$: wavelet scalogram
* $S_1$: first-order time-frequency scattering transform
* $S_2$: second-order time-frequency scattering transform


The image below illustrates the 2-dimensional shapes of the filters (left) and their response to a signal containing a sinusoid, impulse and a chirp:
![](https://freight.cargo.site/t/original/i/88e0cdfbc9fd77878caa7a628d7dbf337aad5c1608ec0a45e672b2450b42e44a/Screenshot-2023-07-07-at-14.15.19.png)
Each filter varies in temporal modulation rate, scale in frequency and orientation along the frequency axis.
The more compact filters focus on fast oscillations. Vertically oriented filters ignore frequential modulations, while horizontally oriented filters ignore amplitude modulations.
A sinusoid is like an impulse on the frequency axis, so we see a response across the entire frequential filterbank's spectrum.
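
The orientation selectivity described above can be sketched with a toy two-dimensional filter applied to a synthetic moving ripple. This rests on assumptions: the "scalogram" is an idealised ripple rather than the output of a wavelet transform, the filters are separable Gaussians in the 2D Fourier domain rather than Kymatio's wavelets, and `joint_filter` and `joint_response` are hypothetical helper names.

```python
import numpy as np

def joint_filter(n_freq, n_time, alpha, beta, rel_bw=0.25):
    """Separable Gaussian centred at temporal rate alpha and frequential scale beta (Fourier domain)."""
    f_t = np.fft.fftfreq(n_time)[None, :]  # modulation frequencies along time
    f_l = np.fft.fftfreq(n_freq)[:, None]  # modulation frequencies along log-frequency
    g_t = np.exp(-0.5 * ((f_t - alpha) / (rel_bw * abs(alpha))) ** 2)
    g_l = np.exp(-0.5 * ((f_l - beta) / (rel_bw * abs(beta))) ** 2)
    return g_l * g_t

def joint_response(u1, filt_f):
    """Modulus of the 2D circular convolution of u1 with the filter."""
    return np.abs(np.fft.ifft2(np.fft.fft2(u1) * filt_f))

n_freq, n_time = 64, 256
alpha0, beta0 = 8 / n_time, 4 / n_freq  # ripple rate (cycles/frame) and scale (cycles/bin)
lam, t = np.meshgrid(np.arange(n_freq), np.arange(n_time), indexing="ij")
u1 = 1 + np.cos(2 * np.pi * (alpha0 * t + beta0 * lam))  # toy moving-ripple scalogram

matched = joint_response(u1, joint_filter(n_freq, n_time, alpha0, +beta0)).mean()
mirrored = joint_response(u1, joint_filter(n_freq, n_time, alpha0, -beta0)).mean()
print(f"matched orientation: {matched:.3f}, mirrored orientation: {mirrored:.1e}")
```

Flipping the sign of $\beta$ mirrors the filter along the frequency axis, so only one of the two orientations responds to the ripple. This is why the frequential modulation scale takes both positive and negative values.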

## Manifold Embedding of Spectrotemporal Modulations
Let's return to the AM/FM chirp signal. We will compare how various acoustic features represent similarity between sounds generated by the synth as its 3 parameters vary.
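
As a reminder of what the three parameters do, here is a NumPy sketch of the chirp formula $g_\theta(t) = \phi_w(\gamma t) \sin(2\pi f_m t) \sin\left(\frac{2\pi f_c}{\gamma \log 2} 2^{\gamma t}\right)$. It is not the `generate_am_chirp` implementation: the Gaussian shape of the window $\phi_w$, the time axis centred on zero, and the function names `chirp_phase` and `am_chirp` are assumptions made for illustration.

```python
import numpy as np

def chirp_phase(t, f_c, gamma):
    """Phase (radians) whose derivative gives the instantaneous frequency f_c * 2**(gamma * t)."""
    return 2 * np.pi * f_c / (gamma * np.log(2)) * 2.0 ** (gamma * t)

def am_chirp(f_c, f_m, gamma, sr=2 ** 13, duration=4.0, bw=2.0):
    """NumPy sketch of the AM chirp; the window shape and time origin are assumptions."""
    n = int(sr * duration)
    t = (np.arange(n) - n // 2) / sr               # time axis centred on zero (assumption)
    window = np.exp(-0.5 * (gamma * t / bw) ** 2)  # Gaussian stand-in for phi_w (assumption)
    modulator = np.sin(2 * np.pi * f_m * t)        # amplitude modulation at f_m Hz
    carrier = np.sin(chirp_phase(t, f_c, gamma))   # exponential chirp
    return t, window * modulator * carrier

t, x = am_chirp(f_c=512.0, f_m=4.0, gamma=1.0)

# instantaneous frequency in Hz, estimated numerically from the phase
inst_freq = np.gradient(chirp_phase(t, 512.0, 1.0), t) / (2 * np.pi)
octaves = np.log2(inst_freq / 512.0)  # pitch relative to f_c
print(f"frequency at t=0 s: {inst_freq[t.size // 2]:.1f} Hz")
```

With these definitions, the carrier passes through $f_c$ at $t = 0$ and rises by $\gamma$ octaves per second, so `octaves` is a straight line in `t`. That is the linear trajectory seen on a log-frequency axis.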

In [None]:
SR = 2**13

fc = 362 #@param {type:"slider", min:128.0, max:1024.0, step:1.0}
am = 15 #@param {type:"slider", min:4.0, max:16.0, step:1.0}
fm = 1.5 #@param {type:"slider", min:0.5, max:4.0, step:0.5}

fc = torch.tensor(fc, dtype=torch.float32)
am = torch.tensor(am, dtype=torch.float32)
fm = torch.tensor(fm, dtype=torch.float32)
x = generate_am_chirp([fc, am, fm], sr=SR, bw=2, duration=4) # am is the AM rate f_m, fm is the chirp rate gamma


plot_cqt(x.numpy(), sr=SR)
Audio(x, rate=SR)

In [None]:
#@title Similarity Retrieval
import os, tqdm
from sklearn.manifold import Isomap
import numpy as np, matplotlib.pyplot as plt, scipy

def run_isomap(X, params, n_neighbors=40):
    model = Isomap(n_components=3, n_neighbors=n_neighbors)
    Y = model.fit_transform(X)

    plot_isomap(Y, params)


def plot_isomap(Y, params):
    fig = plt.figure(figsize=plt.figaspect(0.5))
    axs = []

    for i in range(3):
        ax = fig.add_subplot(1, 3, i + 1, projection='3d')
        ax.scatter3D(Y[:, 0], Y[:, 1], Y[:, 2], c=params[i], cmap='bwr');

        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        axs.append(ax)

    plt.subplots_adjust(wspace=0, hspace=0)

    # rotate the axes and update
    for angle in range(60, 360, 60):
        for ax in axs:
            ax.view_init(30, angle)
            plt.draw()

In [None]:
#@title Synthetic AM chirp dataset generation
#@markdown We generate 4096 signals lying on a 3D manifold in ($f_c$, $f_m$, $\gamma$)

#@markdown $f_c \in [512, 1024]$

#@markdown $f_m \in [4, 16]$

#@markdown $\gamma \in [0.5, 4]$
n_steps = 16
f0_min, f0_max = 512, 1024
am_min, am_max = 4, 16
fm_min, fm_max = 0.5, 4
bw = 2
duration = 4
sr = 2**13
f0s = np.logspace(np.log10(f0_min), np.log10(f0_max), n_steps)
AM = np.logspace(np.log10(am_min), np.log10(am_max), n_steps)
FM = np.logspace(np.log10(fm_min), np.log10(fm_max), n_steps)

audio = np.zeros((len(f0s), len(AM), len(FM), duration * sr))
params = np.zeros((3, len(f0s) * len(AM) * len(FM)))
c = 0

print('Generating Audio ...')
for i, f0 in tqdm.tqdm(enumerate(f0s)):
    for j, am in enumerate(AM):
        for k, fm in enumerate(FM):
            theta = torch.tensor([f0, am, fm])
            x = generate_am_chirp(theta, sr=sr, duration=duration).numpy()
            audio[i, j, k, :] = x / np.linalg.norm(x)
            params[0, c], params[1, c], params[2, c] = f0, am, fm
            c += 1

In [None]:
#@title Similarity Retrieval - `MFCC`
n_mfcc = 40

Sx = np.zeros((len(f0s), len(AM), len(FM), n_mfcc))

print('Extracting MFCCs ...')
for i, f0 in tqdm.tqdm(enumerate(f0s)):
    for j, fm in enumerate(AM):
        for k, gamma in enumerate(FM):
            Sx[i, j, k,:] = np.mean(librosa.feature.mfcc(y=audio[i,j,k], sr=sr, n_mfcc=n_mfcc), axis=-1)
mfccs = Sx.reshape(-1, Sx.shape[-1])

In [None]:
run_isomap(mfccs, params)

In [None]:
#@title Similarity Retrieval - `Scattering1D`
import math
batch_size = 128
N = duration * sr
Q = 1 #@param {type:"slider", min:1, max:12, step:1}
scat = Scattering1D(shape=(N, ), T=N, Q=Q, J=int(np.log2(N) - 1))
scat = scat.cuda()

X = torch.tensor(audio.reshape(-1, audio.shape[-1]), dtype=torch.float32).cuda()
n_samples = X.shape[0]
n_paths = scat(X[0]).shape[0]

Sx = torch.zeros(n_samples, n_paths).cuda()

for i in tqdm.tqdm(range(math.ceil(n_samples / batch_size))):
    start = i * batch_size
    end = (i + 1) * batch_size
    Sx[start:end, :] = scat(X[start:end, :])[:, :, 0]

In [None]:
run_isomap(Sx.cpu().numpy(), params)

In [None]:
#@title Similarity Retrieval - `TimeFrequencyScattering`
import math
batch_size = 16
N = duration * sr
Q = 8 #@param {type:"slider", min:1, max:12, step:1}

jtfs = TimeFrequencyScattering(shape=(N,),
                               T=N,
                               Q=(Q, 1), # first-order Q is set by the slider above
                               J=13, # int(np.log2(N) - 1),
                               Q_fr=2,
                               J_fr=5,
                               F=0,
                               format="time")
jtfs = jtfs.cuda()

X = torch.tensor(audio.reshape(-1, audio.shape[-1]),
                 dtype=torch.float32)
n_samples, n_paths = X.shape[0], jtfs(X[0].cuda()).shape[0]

Sx = torch.zeros(n_samples, n_paths)

for i in tqdm.tqdm(range(math.ceil(n_samples / batch_size))):
    start = i * batch_size
    end = (i + 1) * batch_size
    Sx[start:end, :] = jtfs(X[start:end, :].cuda())[:, :, 0]


In [None]:
run_isomap(Sx.cpu().numpy(), params)

## Differentiable Parametric Similarity Retrieval
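
Before the PyTorch version, the idea of measuring a loss in feature space and asking how it changes with a synthesis parameter can be sketched in NumPy. This sketch swaps in simpler ingredients, all of which are assumptions: the synthesizer is a plain tremolo tone, the feature extractor is a windowed magnitude spectrum rather than JTFS, and the gradient is approximated by central finite differences rather than by backpropagation.

```python
import numpy as np

SR, DURATION, F_C = 1024, 1.0, 128.0
t = np.arange(int(SR * DURATION)) / SR
window = np.hanning(t.size)

def synth(f_m):
    """g(theta): a sinusoidal carrier with tremolo at rate f_m (Hz)."""
    return (1 + np.cos(2 * np.pi * f_m * t)) * np.sin(2 * np.pi * F_C * t)

def features(x):
    """S: windowed magnitude spectrum, a simple stand-in for scattering."""
    return np.abs(np.fft.rfft(x * window))

def loss(f_m, f_m_target):
    """Squared Euclidean distance between the features of two synthesized sounds."""
    return float(np.sum((features(synth(f_m)) - features(synth(f_m_target))) ** 2))

def finite_difference_grad(f_m, f_m_target, eps=1e-3):
    """Central finite-difference approximation of d loss / d f_m."""
    return (loss(f_m + eps, f_m_target) - loss(f_m - eps, f_m_target)) / (2 * eps)

target = 8.0
for f_m in (6.0, 8.0, 10.0):
    print(f_m, loss(f_m, target), finite_difference_grad(f_m, target))
```

The cell below follows the same pattern, but with `generate_am_chirp` as the synthesizer, a JTFS or multi-scale spectral loss as the distance, and `loss.backward()` providing exact gradients with respect to both modulation parameters.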

In [None]:
#@markdown Back to our arpeggiator.
#@markdown The synthesizer is a differentiable function $g$ of its parameters $\theta$.
#@markdown We can compose it with a feature extractor $S$ to obtain $\Phi = S \circ g$,
#@markdown and then compute the gradient of $\Phi$ with respect to $\theta$.
#@markdown If $S$ represents $\theta$ well, we can backpropagate through $\Phi$ to model $\theta$ and solve optimisation problems with our artificial hearing device.
#@markdown Let's call $\Phi$ a differentiable mesostructural operator.

def run_gradient_viz(loss_type="jtfs", time_shift=None):
    f0 = torch.tensor([512.0], dtype=torch.float32, requires_grad=False).cuda()
    N = 20

    target_idx = N * (N // 2) + (N // 2)

    AM, FM = grid2d(x1=4, x2=16, y1=0.5, y2=4, n=N)
    X = AM.numpy().reshape((N, N))
    Y = FM.numpy().reshape((N, N))
    AM.requires_grad = True
    FM.requires_grad = True
    thetas = torch.stack([AM, FM], dim=-1).cuda()

    sr = 2**13
    duration = 4
    n_input = sr * duration

    theta_target = thetas[target_idx].clone().detach().requires_grad_(False)
    target = (
        generate_am_chirp(
            [f0, theta_target[0], theta_target[1]], sr=sr, duration=duration
        )
        .cuda()
        .detach()
    )

    if loss_type == "jtfs":
        loss_fn = TimeFrequencyScatteringLoss(
            shape=(n_input,),
            #T=2**13,
            Q=(8, 2),
            J=12,
            J_fr=5,
            F="global",
            Q_fr=2,
            format="time",
        )
        Sx_target = loss_fn.ops[0](target.cuda()).detach()
    elif loss_type == "mss":
        loss_fn = MultiScaleSpectralLoss(max_n_fft=1024)

    x, y, u, v = [], [], [], []
    losses, grads = [], []
    for theta in tqdm.tqdm(thetas):  # tqdm is imported as a module, so call tqdm.tqdm
        am = torch.tensor(theta[0], requires_grad=True, dtype=torch.float32)
        fm = torch.tensor(theta[1], requires_grad=True, dtype=torch.float32)
        audio = generate_am_chirp(
            [torch.tensor([768.0], dtype=torch.float32, requires_grad=False).cuda(), am, fm],
            sr=sr,
            duration=duration,
            delta=(2 ** random.randint(8, 12) if time_shift == "random" else 2**8)
            if time_shift
            else 0,
        )

        loss = (
            loss_fn(audio.cuda(), Sx_target.cuda(), transform_y=False)
            if loss_type == "jtfs"
            else loss_fn(audio, target)
        )
        loss.backward()
        losses.append(float(loss.detach().cpu().numpy()))
        x.append(float(am))
        y.append(float(fm))
        u.append(float(-am.grad))
        v.append(float(-fm.grad))

        grad = np.stack([float(-am.grad), float(-fm.grad)])
        grads.append(grad)

    zs = np.array(losses)
    Z = zs.reshape(X.shape)

    plot_contour_gradient(
        X,
        Y,
        Z,
        target_idx,
        grads,
    )
    mesh_plot_3d(
        X, Y, Z, target_idx,
    )