# Time Scattering in Kymatio

In this tutorial we will work through examples of `Scattering1D` (Time Scattering) in Kymatio.

The goal is to build intuition for the physical properties of modulated signals and of their scattering transforms.

### Part I
* understanding the scattering filterbank construction parameters (`J`, `Q`, `T`, `average`, `max_order`) and scattering paths
* plotting the wavelet filterbank
* visualizing the response to modulated sounds that appear in music and speech
    - amplitude modulation (tremolo)
    - frequency modulation (vibrato)
    - attacks (note onset)
    - amplitude modulation interference 
    - musical instrument playing techniques

### Part II
* Generate a dataset of synthetic signals with varying spectral shape and interference patterns
* Unsupervised manifold embedding of the nearest neighbour graph (Isomap) of the dataset under Scattering1D

Further documentation can be found here: [Kymatio Github](https://github.com/kymatio/kymatio)

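The first- and second-order scattering coefficients are defined below, where $\psi_\lambda$ is a wavelet with centre frequency $\lambda$, $\phi_T$ is a low-pass filter of temporal support $T$, and $*$ denotes convolution in time:

- $ U_1 x [\lambda, t] = |x * \psi_{\lambda}|(t) $
- $ S_1 x [\lambda, t] = (U_1 x [\lambda] * \phi_T)(t) $
- $ S_2 x [\lambda, \lambda_2, t] = (|U_1 x [\lambda] * \psi_{\lambda_2}| * \phi_T)(t) $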
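
The three definitions above can be made concrete with plain NumPy. This is a toy sketch under stated assumptions, not Kymatio's implementation: Gaussian band-pass filters defined directly in the Fourier domain stand in for Morlet wavelets, a Gaussian low-pass stands in for $\phi_T$, convolutions are circular (computed with the FFT), and nothing is subsampled. All helper names (`gaussian_bandpass`, `toy_scattering`, ...) are hypothetical.

```python
import numpy as np

# Toy sketch (assumption): Gaussian filters in the Fourier domain replace Kymatio's
# Morlet wavelets; convolutions are circular via the FFT; no subsampling.

def gaussian_bandpass(n, xi, sigma):
    """Fourier-domain Gaussian centred at xi (cycles/sample); ~zero at negative frequencies."""
    omega = np.fft.fftfreq(n)
    return np.exp(-((omega - xi) ** 2) / (2 * sigma ** 2))

def gaussian_lowpass(n, T):
    """Fourier-domain Gaussian low-pass; its time-domain width grows with T (samples)."""
    omega = np.fft.fftfreq(n)
    return np.exp(-(omega ** 2) / (2 * (1.0 / T) ** 2))

def fft_conv(x, h_hat):
    """Circular convolution of x with a filter given by its DFT h_hat."""
    return np.fft.ifft(np.fft.fft(x) * h_hat)

def toy_scattering(x, xis1, xis2, Q1=8, Q2=2, T=256):
    """Hypothetical helper returning U1, S1 and S2 for the given centre frequencies."""
    n = len(x)
    phi_hat = gaussian_lowpass(n, T)
    # U1 x[lambda, t] = |x * psi_lambda|(t)
    U1 = np.array([np.abs(fft_conv(x, gaussian_bandpass(n, xi, xi / (2 * Q1)))) for xi in xis1])
    # S1 x[lambda, t] = (U1 x[lambda] * phi_T)(t)
    S1 = np.array([fft_conv(u, phi_hat).real for u in U1])
    # S2 x[lambda, lambda2, t] = (|U1 x[lambda] * psi_lambda2| * phi_T)(t)
    S2 = np.array([[fft_conv(np.abs(fft_conv(u, gaussian_bandpass(n, xi2, xi2 / (2 * Q2)))), phi_hat).real
                    for xi2 in xis2] for u in U1])
    return U1, S1, S2

# Usage: a 512 Hz tone with an 8 Hz tremolo, sampled at 4096 Hz for one second
toy_fs = 4096
toy_t = np.arange(toy_fs) / toy_fs
toy_x = (1 + 0.8 * np.sin(2 * np.pi * 8 * toy_t)) * np.sin(2 * np.pi * 512 * toy_t)
toy_U1, toy_S1, toy_S2 = toy_scattering(
    toy_x, xis1=[512 / toy_fs, 1024 / toy_fs], xis2=[8 / toy_fs, 64 / toy_fs], T=2048)
print(toy_S1.mean(axis=1))     # energy in the 512 Hz band, almost none at 1024 Hz
print(toy_S2[0].mean(axis=1))  # the 8 Hz second-order filter responds to the tremolo
```

With a wide enough low-pass (`T=2048` here), $S_1$ keeps the average energy of the 512 Hz band but smooths out the 8 Hz tremolo, and the second-order filter tuned to 8 Hz picks it up again in $S_2$.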

# Installation
Let's install Kymatio and its dependencies (the commented-out line installs the development branch instead).

In [None]:
!pip install kymatio
# !pip install git+https://github.com/kymatio/kymatio.git@dev 



# Import
- Let's import the torch frontend
- The frontend takes a few parameters and handles filterbank construction and the organisation of the coefficients
- Several frontends are available, including NumPy and TensorFlow

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import Audio
import scipy.signal  # provides scipy.signal.resample, used below

from kymatio.torch import Scattering1D
from kymatio.scattering1d.filter_bank import scattering_filter_factory
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

# Generate a Scattering Filterbank

First, we will use a core function `scattering_filter_factory` to inspect the construction of a scattering filterbank. This step is usually handled by the frontend.

To construct the filterbank, we must define the following parameters: 
* $N$: the temporal support of the filters 
  - this corresponds to the size of the input signal and must be a power of two
* $J$: the maximum scale
  - The largest scale (widest filter in time, lowest central frequency) has a time support of about $2^J$ samples. $J$ also determines the total number of octaves.
* $Q$: the number of wavelet filters per octave in the first-order filterbank
  - $Q$ determines how well we can localize a signal in frequency. When $Q = 1$, we get a dyadic wavelet filterbank, i.e. the central frequencies of successive wavelet filters are spaced by a factor of two and there is exactly one filter per octave. Smaller values of $Q$ result in filters that are wider in the frequency domain and narrower in the time domain. The choice of $Q$ depends on the application. Typically for audio, higher values of $Q$ (between 4 and 16) enable better frequency localization of highly oscillatory signals.
* $T$: the temporal support of the lowpass filter
  - this controls the amount of imposed invariance to time-shifts. Set to $2^J$ by default.
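
To build intuition for how `J` and `Q` interact, the sketch below lays out an idealized constant-Q filterbank: `J * Q` centre frequencies spaced by a factor $2^{1/Q}$, with bandwidths proportional to their centre frequencies. It is a simplified mental model under stated assumptions (the highest centre frequency `xi_max` and the bandwidth rule are made up); Kymatio's actual filterbank differs in its details, for instance in how it covers the lowest frequencies, so its filter count will not match exactly.

```python
import numpy as np

# Sketch only: an idealized geometric (constant-Q) filterbank, NOT Kymatio's exact
# construction. xi_max = 0.4 and the bandwidth rule below are assumptions.

def constant_q_centres(J, Q, xi_max=0.4):
    """J * Q centre frequencies (cycles/sample), Q per octave, decreasing from xi_max."""
    n = np.arange(J * Q)
    return xi_max * 2.0 ** (-n / Q)

def constant_q_bandwidths(xis, Q):
    """Bandwidth proportional to centre frequency: larger Q gives relatively narrower filters."""
    return xis / Q

xis = constant_q_centres(J=8, Q=8)
widths = constant_q_bandwidths(xis, Q=8)

print(len(xis))                          # J * Q filters
print(xis[0] / xis[-1])                  # highest / lowest centre frequency = 2 ** ((J * Q - 1) / Q)
print(np.allclose(widths / xis, 1 / 8))  # constant ratio of bandwidth to centre frequency
```

Under this model, increasing `Q` adds filters and narrows each one relative to its centre frequency, while increasing `J` adds octaves toward low frequencies.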

We create the filters by calling `scattering_filter_factory`. It returns the lowpass filter (`phi_f`), the first-order wavelet filters (`psi1_f`) and the second-order filters (`psi2_f`), plus a fourth value that we discard.

In [None]:
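N = 4096  # filter support (signal length), a power of two
J = 8     # maximum scale: the widest wavelet spans about 2**J samples
Q = 8     # first-order wavelets per octave

phi_f, psi1_f, psi2_f, _ = scattering_filter_factory(int(np.log2(N)), J, Q, T=2**J)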

`phi_f` is a dictionary that stores the low-pass filter at several resolutions, one per integer key (it also holds metadata under string keys). For example, `phi_f[0]` is sampled at full resolution (length `N`), while `phi_f[1]` is subsampled by a factor of two (length `N / 2`).

`psi1_f` (order 1) and `psi2_f` (order 2) are lists of dictionaries that contain the specification of each wavelet bandpass filter. 

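Each dictionary contains:
- the filter itself, sampled in the Fourier domain at full resolution, under the integer key `0`
- the filter's normalized centre frequency `xi` (in cycles per sample)
- the filter's width `sigma`
- the filter's scale `j`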
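
To see how `xi` and `sigma` relate to the curves plotted below, this sketch draws a Gaussian bump on the same normalized frequency grid (`np.arange(N) / N`, in cycles per sample) and converts a centre frequency to Hz. It assumes a plain Gaussian $\hat\psi(\omega) = \exp(-(\omega - \xi)^2 / (2\sigma^2))$; Kymatio's Morlet filters add a small correction so that each wavelet has zero mean, so this is only an approximation. The helper names and the values of `xi`, `sigma` and the sampling rate are hypothetical.

```python
import numpy as np

# Sketch only: a plain Gaussian approximation of a first-order filter (assumption).

def gaussian_filter_hat(n_bins, xi, sigma):
    """Gaussian bump on the normalized frequency grid np.arange(n_bins) / n_bins."""
    omega = np.arange(n_bins) / n_bins
    return np.exp(-((omega - xi) ** 2) / (2 * sigma ** 2))

def xi_to_hz(xi, sr):
    """Normalized frequency (cycles/sample) times sampling rate (samples/s) gives Hz."""
    return xi * sr

psi_demo = gaussian_filter_hat(4096, xi=0.05, sigma=0.005)  # hypothetical xi and sigma
peak_demo = np.argmax(psi_demo) / 4096
print(peak_demo)              # close to xi = 0.05, within one frequency bin
print(xi_to_hz(0.05, 4096))   # centre frequency in Hz for a 4096 Hz sampling rate
```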

In [None]:
# length of the lowpass filter at each resolution (integer keys only)
print([len(phi) for k, phi in phi_f.items() if isinstance(k, int)])

# a first-order filter (Fourier domain, full resolution)
print(f'filter 0: {psi1_f[0][0]}')

# its scale
print(f'filter 0 scale: {psi1_f[0]["j"]}')

# its normalized centre frequency (cycles per sample)
print(f'filter 0 centre frequency: {psi1_f[0]["xi"]}')

# its centre frequency in Hz, assuming a sampling rate of 4096 Hz
sr = 4096
print(f'filter 0 centre frequency (Hz): {psi1_f[0]["xi"] * sr}')

# its characteristic width
print(f'filter 0 width: {psi1_f[0]["sigma"]}')

# Plotting the filters

We can now plot the first-order filters. 

First, we display the lowpass filter (at full resolution) in red. We then plot each of the bandpass filters in blue, in the Fourier domain. Since we are only interested in positive frequencies, we limit the plot to the normalized frequency interval [0, 0.5], where 0.5 is the Nyquist frequency.

* what properties can you observe from the filters?
* how does the bandwidth change with frequency?
* how many filters are there?
  - how does this change with `J` and `Q`?
* what's the ratio between the highest and lowest centre frequency?

In [None]:
_ = plt.figure(figsize=(10, 5))

# FIRST-ORDER FILTERBANK
_ = plt.plot(np.arange(N) / N, phi_f[0], 'r')

for psi_f in psi1_f:
    plt.plot(np.arange(N)/N, psi_f[0], 'b')

plt.xlim(0, 0.5)

plt.xlabel(r'$\omega$', fontsize=18)
plt.ylabel(r'$\hat\psi_j(\omega)$', fontsize=18)
_ = plt.title('First-order filters (Q = {})'.format(Q), fontsize=18)

In [None]:
_ = plt.figure(figsize=(10, 5))

# SECOND-ORDER FILTERBANK
for psi_f in psi2_f:
    plt.plot(np.arange(N)/N, psi_f[0], 'b')

plt.xlim(0, 0.5)

plt.xlabel(r'$\omega$', fontsize=18)
plt.ylabel(r'$\hat\psi_j(\omega)$', fontsize=18)
_ = plt.title('Second-order filters (Q = 1)', fontsize=18)

# Scattering1D frontend 

- Now we will create the torch frontend object for the 1D scattering transform. It constructs the filterbank as we did in the previous section and provides convenient methods to compute the scattering transform (`__call__`) and to retrieve metadata about each output coefficient, such as its order (`meta`).
- The constructor requires the maximum scale `J` (number of octaves) and the input signal length `shape`.
- We can also specify the number of filters per octave `Q` (default = 1), the lowpass temporal support `T` (default = $2^J$), `max_order` (default = 2), and whether to return the unaveraged ($U$) or averaged ($S$) coefficients with the keyword argument `average`.

In [None]:
from kymatio.torch import Scattering1D

Fs = 2**14    # sampling rate (Hz)
duration = 2  # signal duration (seconds)
J = 12        # maximum wavelet scattering scale
Q = 8         # first-order wavelets per octave

scat1d = Scattering1D(J=J, Q=Q, shape=duration * Fs, average=True)

* To inspect the output shape, let's compute the scattering transform of a random input.
* The scattering transform of a signal is computed by calling the `Scattering1D` object, i.e. through its `__call__` method.
* The output is an array of shape (C, T'), where C is the number of scattering coefficients (channels) and T' is the number of samples along the time axis.
* T' is typically much smaller than the number of input samples, since the scattering transform averages over time with the lowpass filter $\phi_T$ and then subsamples.

- Try changing the frontend's kwargs. What do you observe?
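
A back-of-the-envelope check of the time axis: if the averaged coefficients are subsampled by $2^J$, the number of frames is the signal length divided by $2^J$. The helper below is a hypothetical sketch that assumes exactly this (it ignores Kymatio's `oversampling` option, which keeps more time samples) and an input length divisible by $2^J$.

```python
# Sketch only (hypothetical helper): assumes the averaged output is subsampled by
# exactly 2**J and that the number of samples is divisible by 2**J.
def expected_num_frames(num_samples, J):
    """Number of time frames left after averaging and subsampling by 2**J."""
    return num_samples // 2 ** J

print(expected_num_frames(2 * 2**14, 12))   # 32768 // 4096 = 8
```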

In [None]:
print(scat1d(torch.randn(duration * Fs)).shape)

# Getting first and second order coefficients
* `Scattering1D` returns a vector per timestep, concatenating zeroth, first and second order coefficients. 
* To display the scattering coefficients, we have to identify the indices for each order. 
* The `meta()` method of a `Scattering1D` object returns metadata about each output channel, including its order
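
The `np.where` pattern used in the next cell can be tried on a small, made-up metadata dictionary first. `meta_demo` is a hypothetical stand-in for `scat1d.meta()` that only models the `'order'` entry (one value per output channel); the real dictionary has more keys and many more channels.

```python
import numpy as np

# Hypothetical stand-in for scat1d.meta(): the order of each of 8 output channels
meta_demo = {'order': np.array([0, 1, 1, 1, 2, 2, 2, 2])}

# Channel indices of each order
indices_by_order = {k: np.where(meta_demo['order'] == k)[0] for k in (0, 1, 2)}
print(indices_by_order[1])                   # [1 2 3]

# Selecting the second-order rows of a (channels, time) array
Sx_demo = np.arange(8 * 3).reshape(8, 3)
print(Sx_demo[indices_by_order[2]].shape)    # (4, 3)
```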

In [None]:
meta = scat1d.meta()
order0 = np.where(meta['order'] == 0)[0]
order1 = np.where(meta['order'] == 1)[0]
order2 = np.where(meta['order'] == 2)[0]

# Visualizing the scattering transform of modulated signals
- Now we are going to construct an amplitude-modulated harmonic signal.
  * it is the sum of four harmonics ($f_0$, $2f_0$, $3f_0$, $4f_0$) with sinusoidal amplitude modulation (tremolo)
- To facilitate visualization of the resulting properties of the signal, we will use the filterbank's metadata (`meta`)
- We will set the fundamental frequency to the centre frequency of a first-order wavelet
- We will set the modulation frequency to the centre frequency of a second-order wavelet

- Try using different combinations of first- and second-order frequencies to synthesize the signal
  * be careful here: not every first-order filter is the parent of every second-order filter
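
Before looking at scattering coefficients, recall what amplitude modulation does in the Fourier domain: multiplying a carrier at $f_c$ by $1 + m \sin(2\pi f_m t)$ adds two sidebands at $f_c \pm f_m$, so on a spectrogram the modulation only appears as a small spacing between peaks. The sketch below checks this with NumPy; the frequencies are hypothetical and chosen to fall exactly on FFT bins, and `am_spectrum_peaks` is an illustrative helper.

```python
import numpy as np

def am_spectrum_peaks(fc=200, fm=10, depth=0.5, fs=1024):
    """Return the three strongest FFT bins (in Hz) of a 1 s amplitude-modulated tone."""
    t = np.arange(fs) / fs   # 1 s of signal, so each FFT bin is 1 Hz wide
    x = (1 + depth * np.sin(2 * np.pi * fm * t)) * np.sin(2 * np.pi * fc * t)
    spectrum = np.abs(np.fft.rfft(x))
    return np.sort(np.argsort(spectrum)[-3:])

print(am_spectrum_peaks())   # [190 200 210]: fc - fm, fc and fc + fm
```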

### Scattering plot helper functions

In [None]:
def compute_scattering(x, scat1d_u, scat1d, lambda1_idx=None):
    """ Compute U1, S1 and S2

        Parameters:
            x -- signal
            scat1d_u -- Scattering1D instance with average=False
            scat1d -- Scattering1D instance
            lambda1_idx -- target lambda1 index for S2 visualization

        Returns:
            u1 -- unaveraged scalogram of x
            s1 -- first order scattering transform of x
            s2 -- second order scattering transform of x
    """
    ''' compute the unaveraged scattering transform Ux'''
    Ux = scat1d_u(x)
    ''' get the first-order coefficients '''
    Ux = [u for key, u in Ux.items() if len(key) == 1]
    max_samples = max([u.shape[-1] for u in Ux]) # largest length unaveraged transform
    ''' resample the first order coefficient to same temporal shape'''
    Ux = [scipy.signal.resample(u.numpy(), max_samples, axis=-1) for u in Ux]
    Ux = torch.tensor(np.concatenate(Ux))

    meta = scat1d.meta()
    order1 = np.where(meta['order'] == 1)[0]

    # select the indices of the second-order coefficients whose first-order
    # parent is lambda1_idx (second-order keys have the form (n1, n2))
    if lambda1_idx is not None:
        order2 = [i for i, key in enumerate(meta['key'])
                  if len(key) == 2 and key[0] == lambda1_idx]
    else:
        order2 = np.where(meta['order'] == 2)[0]

    '''Compute the scattering transform'''
    Sx = scat1d(x)
    S1 = Sx[order1]
    S2 = Sx[order2]

    return Ux, S1, S2
  
  
def plot_scat1d(u1, s1, s2):
    """ Plot u1, s1 and s2

      Parameters:
          u1 -- unaveraged scalogram
          s1 -- first order scattering transform
          s2 -- second order scattering transform
      Returns:
  """
    plt.figure(figsize=(10, 15))
    # plt.tick_params(
    #     axis='x',          # changes apply to the x-axis
    #     which='both',      # both major and minor ticks are affected
    #     bottom=False,      # ticks along the bottom edge are off
    #     top=False,         # ticks along the top edge are off
    #     left=False,
    #     right=False,
    #     labelbottom=False)

    ''' plot U1, S1, S2 '''
    for idx, s in enumerate([u1, s1, s2]):
        ax = plt.subplot(3, 1, idx + 1)
        plt.imshow(s, aspect='auto', origin='lower', cmap=plt.get_cmap('jet'))
        ax.set_xticklabels([]); ax.set_yticklabels([])

    plt.subplots_adjust(wspace=0, hspace=0.1)


def plot_signals(signals, scat1d_u, scat1d):
    u1_concat = []
    s1_concat = []
    s2_concat = []

    def append_silence():
        u1_concat.append(torch.zeros(u1.shape[0], u1.shape[1] // 4))
        s1_concat.append(torch.zeros(s1.shape[0], s1.shape[1] // 4))
        s2_concat.append(torch.zeros(s2.shape[0], s2.shape[1] // 4))

    for x in signals:
        u1, s1, s2 = compute_scattering(x, scat1d_u, scat1d, lambda1_idx=None)
        append_silence()
        u1_concat.append(u1)
        s1_concat.append(s1)
        s2_concat.append(s2)
        append_silence()

    u1 = np.concatenate(u1_concat, axis=-1)
    s1 = np.concatenate(s1_concat, axis=-1)
    s2 = np.concatenate(s2_concat, axis=-1)
    plot_scat1d(u1, s1, s2)

### Synthesize a harmonic signal with tremolo

In [None]:
f0_xi_idx = 30  # index of the first-order filter whose centre frequency sets f0
fm_xi_idx = 8   # index of the second-order coefficient whose filter sets fm

f0 = meta['xi'][order1[f0_xi_idx], 0] * Fs  # centre frequency of a first-order filter (Hz)
fm = meta['xi'][order2[fm_xi_idx], 1] * Fs  # centre frequency of a second-order filter (Hz)
print(f'fundamental frequency: {f0} Hz')
print(f'modulation frequency: {fm} Hz')

# synthesize a harmonic signal with partials at f0, 2 f0, 3 f0 and 4 f0
num_harmonics = 4
t = torch.arange(duration * Fs) / Fs  # time axis in seconds
# tremolo: the envelope oscillates between 0 and 1 at rate fm
modulator = 0.5 * (1.0 + torch.sin(2.0 * np.pi * fm * t))
harmonics = torch.zeros(num_harmonics, duration * Fs)
for i in range(num_harmonics):
    harmonics[i] = torch.sin(2.0 * np.pi * f0 * (i + 1) * t) * modulator

x_am = torch.sum(harmonics, dim=0)  # sum the harmonics
x_am /= torch.max(torch.abs(x_am))  # normalize the amplitude

Audio(x_am, rate=Fs)

### Plot its Fourier spectrogram

In [None]:
_ = plt.figure(figsize=(10, 5))
_ = plt.specgram(x_am, Fs=Fs)
plt.xlabel(r'$t$ (seconds)', fontsize=18)
plt.ylabel(r'$f$ (Hz)', fontsize=18)
_ = plt.title('Fourier Spectrogram of the AM harmonic signal', fontsize=18)

* The second-order scattering transform has the form $S_2x[\lambda, \lambda_2, t]$
* We want to visualize the response of the second-order wavelet filterbank for a particular first-order parent $\lambda$
* Let's use the $\lambda$ that corresponds to the fundamental frequency of the synthesized signal
* Note that the scattering transform does not compute every possible pairing of $\lambda$ and $\lambda_2$. For efficiency, a second-order wavelet $\psi_{\lambda_2}$ is only applied to a first-order envelope $U_1x[\lambda]$ when its scale $j_2$ is larger than the scale $j_1$ of $\psi_\lambda$: finer second-order wavelets would capture little energy, since $U_1x[\lambda]$ is concentrated at frequencies below the bandwidth of $\psi_\lambda$.
* We will also visualize $U_1$, to show the effects of averaging on $S_1$
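
The pairing rule can be made concrete by enumerating admissible $(\lambda, \lambda_2)$ pairs. The sketch assumes a simplified, hypothetical scale assignment (first-order filter `n1` has scale `n1 // Q`, second-order filter `n2` has scale `n2`, one per octave) and keeps a pair only when $j_2 > j_1$; Kymatio derives the actual scales from its filterbank, so real counts differ.

```python
def admissible_pairs(J, Q):
    """Hypothetical idealized scales: j1 = n1 // Q (Q filters per octave), j2 = n2 (one per octave)."""
    pairs = []
    for n1 in range(J * Q):
        j1 = n1 // Q
        for n2 in range(J):
            j2 = n2
            if j2 > j1:  # keep only second-order wavelets coarser than their parent
                pairs.append((n1, n2))
    return pairs

pairs = admissible_pairs(J=3, Q=2)
print(len(pairs), 'of', 3 * 2 * 3, 'possible pairs')   # 6 of 18 possible pairs
print(pairs)
```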

### Compute Scattering of the AM Harmonic Signal

In [None]:
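# unaveraged scattering (returns a dict of coefficients keyed by path)
scat1d_u = Scattering1D(J=J, Q=Q, shape=duration * Fs, average=False, vectorize=False)

# averaged scattering
scat1d = Scattering1D(J=J, Q=Q, shape=duration * Fs)

# plot U_1, S_1, and the S_2 coefficients whose first-order parent is the filter centred on f0
u1, s1, _ = compute_scattering(x_am, scat1d_u, scat1d)
children = [i for i, key in enumerate(scat1d.meta()['key'])
            if len(key) == 2 and key[0] == f0_xi_idx]
s2 = scat1d(x_am)[children]
plot_scat1d(u1, s1, s2)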

* What differences are there between $U_1$, $S_1$ and $S_2$?
* What information is lost? How is it recovered?

## Interference Patterns
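- Now we are going to construct a chord of two notes
  * each note is a sum of four sinusoidal partials at $f$, $2f$, $4f$ and $8f$, where $f$ is the note's fundamental frequency
  * the fundamental frequencies of the notes are $f_0$ and $f_0'$
  * we pick $f_0$ as the centre frequency $\xi$ of a first-order filter, and set $f_0' = f_0 + 2 f_0 / 12 = \frac{7}{6} f_0$, so that the partials of the two notes lie close together in frequency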
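
The interference idea in one experiment: two partials at $f_1$ and $f_2$ that pass through the same band-pass filter produce an envelope (the modulus, as in $U_1$) that beats at $|f_2 - f_1|$. The NumPy sketch below computes the envelope with an FFT-based analytic signal, assumed here to stand in for one filter wide enough to pass both partials, and returns the dominant envelope frequency. The frequencies are hypothetical and `beat_frequency` is an illustrative helper.

```python
import numpy as np

def analytic_envelope(x):
    """Magnitude of the analytic signal: zero negative frequencies, double positive ones."""
    n = len(x)
    h = np.zeros(n)
    h[0] = 1.0
    h[1:(n + 1) // 2] = 2.0
    if n % 2 == 0:
        h[n // 2] = 1.0
    return np.abs(np.fft.ifft(np.fft.fft(x) * h))

def beat_frequency(f1=100, f2=110, fs=1000):
    """Dominant frequency (Hz) of the envelope of two summed sinusoids over 1 s."""
    t = np.arange(fs) / fs   # 1 s of signal, so each FFT bin is 1 Hz wide
    x = np.sin(2 * np.pi * f1 * t) + np.sin(2 * np.pi * f2 * t)
    env = analytic_envelope(x)
    env_spectrum = np.abs(np.fft.rfft(env - env.mean()))
    return int(np.argmax(env_spectrum))

print(beat_frequency())   # 10, i.e. |f2 - f1|
```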


In [None]:
f0_idx = 30  # index of the first-order filter whose centre frequency sets the first note

f0 = meta['xi'][order1[f0_idx], 0] * Fs  # centre frequency of a first-order filter (Hz)
df = f0 / 12
f0_2 = f0 + 2 * df                       # second note: 7/6 of f0

print(f'first note fundamental frequency: {f0} Hz')
print(f'second note fundamental frequency: {f0_2} Hz')

num_harmonics = 4
t = torch.arange(duration * Fs) / Fs     # time axis in seconds
x_chord = torch.zeros(duration * Fs)
for f in [f0, f0_2]:
    # four partials at f, 2f, 4f and 8f
    for i in range(num_harmonics):
        x_chord += torch.sin(2.0 * np.pi * (f * 2 ** i) * t)
x_chord /= torch.max(torch.abs(x_chord))  # normalize the amplitude
Audio(x_chord, rate=Fs)

- The same two notes, played one after the other (an arpeggio) instead of simultaneously (a chord)

In [None]:
num_harmonics = 4
half = duration * Fs // 2              # each note lasts half of the signal
t_half = torch.arange(half) / Fs       # time axis of one note (seconds)
notes = []
for f in [f0, f0_2]:
    # four partials at f, 2f, 4f and 8f
    notes.append(sum(torch.sin(2.0 * np.pi * (f * 2 ** i) * t_half) for i in range(num_harmonics)))
x_arp = torch.cat(notes)               # the two notes in sequence
x_arp /= torch.max(torch.abs(x_arp))   # normalize the amplitude
Audio(x_arp, rate=Fs)

In [None]:
scat1d = Scattering1D(J=8, Q=12, shape=duration * Fs, average=True)
scat1d_u = Scattering1D(J=8, Q=12, shape=duration * Fs, average=False, vectorize=False)

plot_signals([x_chord, x_arp], scat1d_u, scat1d)

- Compare the $U_1$, $S_1$ and $S_2$ of the chord and arpeggio signals: when are they more distinguishable, and when are they not?

## Frequency modulated sounds
- We now construct a harmonic signal with sinusoidal frequency modulation (vibrato)
  * the carrier frequency is $f_0$
  * the modulation frequency is $f_m$, chosen to be the same as the amplitude modulation frequency earlier
  * the depth of frequency modulation is adjustable via `depth`, the peak phase deviation in radians
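
`depth` in the code below acts as a modulation index $\beta$: a partial $\sin(2\pi f_0 t + \beta \sin(2\pi f_m t))$ has instantaneous frequency $f_0 + \beta f_m \cos(2\pi f_m t)$, i.e. it sweeps $\pm \beta f_m$ Hz around $f_0$ at rate $f_m$. The sketch below checks this numerically by differentiating the phase; the numbers are hypothetical, and `fm_instantaneous_frequency` is an illustrative helper, not part of Kymatio.

```python
import numpy as np

def fm_instantaneous_frequency(f0=500.0, fm=5.0, beta=3.0, fs=16384, duration=2.0):
    """Instantaneous frequency (Hz) of sin(2 pi f0 t + beta sin(2 pi fm t)), by numerical differentiation."""
    t = np.arange(int(duration * fs)) / fs
    phase = 2 * np.pi * f0 * t + beta * np.sin(2 * np.pi * fm * t)
    return np.gradient(phase, t) / (2 * np.pi)

inst = fm_instantaneous_frequency()
print(inst.max(), inst.min())   # close to f0 + beta * fm = 515 and f0 - beta * fm = 485
```

So with `depth = 30`, each partial sweeps 30 times the modulation frequency above and below its centre.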

In [None]:
f0_idx = 30
f0 = meta['xi'][order1[f0_idx], 0] * Fs  # carrier: centre frequency of a first-order filter (Hz)
fm = meta['xi'][order2[8], 1] * Fs       # modulation rate: centre frequency of a second-order filter (Hz)
depth = 30                               # modulation index (peak phase deviation, radians)
print(f'fundamental frequency: {f0} Hz')
print(f'modulation frequency: {fm} Hz')
print(f'depth of modulation: {depth}')

num_harmonics = 4
harmonics = torch.zeros(num_harmonics, duration * Fs)
t = torch.arange(duration * Fs) / Fs     # time axis in seconds

modulator = torch.sin(2.0 * np.pi * fm * t)
for i in range(num_harmonics):
    # partials at f0, 2 f0, 3 f0 and 4 f0, all sharing the same phase modulator
    harmonics[i] = torch.sin(2.0 * np.pi * (f0 * (i + 1)) * t + depth * modulator)

x_vibrato = torch.sum(harmonics, dim=0)       # sum the harmonics
x_vibrato /= torch.max(torch.abs(x_vibrato))  # normalize the amplitude
Audio(x_vibrato, rate=Fs)

In [None]:
scat1d = Scattering1D(J=J, Q=12, shape=duration * Fs, average=True)
scat1d_u = Scattering1D(J=J, Q=12, shape=duration * Fs, average=False, vectorize=False)

plot_signals([x_am, x_vibrato], scat1d_u, scat1d)

- Try adjusting `Q`: how does it affect the difference between the $S_2$ coefficients of the two signals?

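## Noisy attack
A burst of white noise under an exponentially decaying envelope, a crude model of a percussive note onset.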

In [None]:
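decay_r = -6                            # exponential decay rate of the envelope (1/s)
t = torch.arange(duration * Fs) / Fs    # time axis in seconds
# uniform noise in [-1, 1) multiplied by a decaying envelope
x_attack = (torch.rand(len(t)) * 2 - 1) * torch.exp(decay_r * t)
x_attack /= torch.max(torch.abs(x_attack))
Audio(x_attack, rate=Fs)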
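
A noisy attack like the one above is broadband noise under an exponentially decaying envelope $e^{r t}$. As a sanity check that does not depend on Kymatio, the sketch below regenerates a similar burst in NumPy and recovers $r$ from the slope of the log RMS computed over short frames. The seed and frame length are arbitrary choices, and `estimate_decay_rate` is a hypothetical helper.

```python
import numpy as np

def estimate_decay_rate(decay_r=-6.0, fs=2**14, duration=2, frame=1024, seed=0):
    """Regenerate a decaying noise burst and fit the slope (1/s) of its log frame RMS."""
    rng = np.random.default_rng(seed)
    t = np.arange(duration * fs) / fs
    burst = rng.uniform(-1.0, 1.0, t.size) * np.exp(decay_r * t)
    rms = np.sqrt(np.mean(burst.reshape(-1, frame) ** 2, axis=1))  # one value per frame
    frame_times = t[::frame]                                        # start time of each frame
    slope, _ = np.polyfit(frame_times, np.log(rms), 1)
    return slope

print(estimate_decay_rate())   # close to -6
```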

In [None]:
# plot U_1, S_1 and S_2 of the noisy attack
u1, s1, s2 = compute_scattering(x_attack, scat1d_u, scat1d, lambda1_idx=None)
plot_scat1d(u1, s1, s2)

Finally, let's put all the signals together and plot them side by side:

In [None]:
scat1d = Scattering1D(J=J, Q=12, shape=duration * Fs, average=True)
scat1d_u = Scattering1D(J=J, Q=12, shape=duration * Fs, average=False, vectorize=False)
plot_signals([x_am,x_chord,x_arp,x_attack,x_vibrato], scat1d_u, scat1d)

# Part II: One or Two Frequencies? The scattering transform answers.

* This part of the tutorial will explore similarity retrieval of harmonic sounds with the scattering transform
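* Scattering transforms can detect interference patterns in the time-frequency domain: when two partials fall within the bandwidth of the same first-order wavelet, they beat, producing an amplitude modulation of $U_1$ that is captured by the second-order coefficients.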

In [None]:
import random
import scipy.signal
from matplotlib import pyplot as plt
import tqdm
from sklearn.manifold import Isomap
import kymatio

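This function builds one complex tone according to the following additive synthesis model:


$  \boldsymbol{y}_{\alpha,r}(t) =
    \sum_{n=1}^{P}
    \dfrac{
    1 - \frac{1 + (-1)^{n}}{2} \, r
    }{
    n^{\alpha}
    }
    \cos(2 \pi n f_0 t)
    \boldsymbol{w}(t)
$

where $\alpha$ sets how fast the partial amplitudes decay with $n$, $r$ attenuates the even-numbered partials, $P = \lfloor N / (2 f_0) \rfloor - 1$ is the number of partials kept below the Nyquist frequency, and $\boldsymbol{w}$ is a Hann window spanning the whole signal.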
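
The amplitude of the $n$-th partial, as implemented in `generate` below, can be restated as a small vectorized helper. `partial_amplitudes` is a hypothetical restatement for illustration: with $r = 1$ only the odd partials remain, and $\alpha$ controls how fast the amplitudes decay with $n$.

```python
import numpy as np

def partial_amplitudes(alpha, r, n_partials):
    """Amplitudes of partials n = 1..n_partials: (1 - r * [n even]) / n**alpha."""
    n = np.arange(1, n_partials + 1)
    even = (n % 2 == 0).astype(float)
    return (1 - r * even) / n ** alpha

print(partial_amplitudes(alpha=0.0, r=0.0, n_partials=6))  # flat spectrum
print(partial_amplitudes(alpha=1.0, r=1.0, n_partials=6))  # 1/n decay, odd partials only
```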



In [None]:
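def generate(fourier_decay, odd_to_even, f0=[16], N=2**10):
    """Synthesize one windowed complex tone; returns the signal and the chosen f0."""
    f0_choice = random.choice(f0)
    n_partials = int(N / (2 * f0_choice)) - 1  # keep partials below Nyquist
    t = np.linspace(0, 1, N, endpoint=False)
    partials = np.zeros((n_partials, N))
    for partial_id in range(n_partials):
        n = partial_id + 1  # harmonic number
        frequency = n * f0_choice
        # even-numbered partials are attenuated by odd_to_even (r); all decay as 1 / n**alpha
        amplitude = (1 - odd_to_even * (n % 2 == 0)) / n ** fourier_decay
        partials[partial_id, :] = amplitude * np.cos(2 * np.pi * frequency * t)

    # scipy.signal.hann was removed from recent SciPy; the window lives in scipy.signal.windows
    x = np.sum(partials, axis=0) * scipy.signal.windows.hann(N)

    return (x, f0_choice)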

Let's now plot and see the dataset (in time and frequency), with different $\alpha$ and $r$:

In [None]:
N = 2**10
f0_list = range(12,24)

alphas = np.ravel(np.tile(np.array([0, 0.5, 1.0, 1.5, 2.0]), (1, 5)))
rs = np.ravel(np.tile(np.array([0, 0.25, 0.5, 0.75, 1.0]), (5, 1)).T)
fig, axs = plt.subplots(5, 5, figsize=(10, 10), sharex=True, sharey=True)
axs = axs.flatten()

signals = []

 
for i in range(len(alphas)):
    a = alphas[i]
    r = rs[i]
    x, _ = generate(a, r, f0=f0_list, N=N)
    x = x/np.max(x)
    signals.append(x)
    axs[i].plot(0*x, 'k')
    axs[i].plot(x)
    axs[i].set_xlim(0, N)
    axs[i].set_ylim(-1.1, 1.1)
    axs[i].set_title("α = {:.2f} ; r = {:.2f}".format(a, r), fontsize=10)
    axs[i].set_yticks([])
    axs[i].set_xticks([])
    axs[i].spines["top"].set_visible(False)
    axs[i].spines["right"].set_visible(False)
    axs[i].spines["bottom"].set_visible(False)
    axs[i].spines["left"].set_visible(False)
    axs[i].grid(True, linestyle='--', alpha=0.5)
 

In [None]:

fig, axs = plt.subplots(5, 5, figsize=(10, 10), sharex=True, sharey=True)
axs = axs.flatten()

for i in range(len(alphas)):
    x = signals[i]
    t = np.linspace(0, 1, N, endpoint=False)
    axs[i].plot(np.log2(N*t[1:(1+N//2)]), np.log10(np.abs(np.fft.rfft(x)))[1:])
    axs[i].set_ylim(-5, 3)
    axs[i].set_title("α = {:.2f} ; r = {:.2f}".format(alphas[i], rs[i]), fontsize=10)
    axs[i].set_xlim(3, 9)
    axs[i].set_xticks([])
    axs[i].set_yticks([])
    axs[i].spines["top"].set_visible(False)
    axs[i].spines["right"].set_visible(False)
    axs[i].spines["bottom"].set_visible(False)
    axs[i].spines["left"].set_visible(False)

What do $\alpha$ and $r$ control?

You can also listen to the synthetic sound and play with different values of $\alpha$ and $r$, with a fixed f0.

In [None]:

from ipywidgets import interactive
from IPython.display import Audio, display
import numpy as np

def render(alpha, r):
    signal, _ = generate(alpha, r, f0=[256], N=2**15)
    rate = 22050
    display(Audio(data=signal, rate=rate, autoplay=True))
    return signal * 0.5
    
v = interactive(render, alpha=(0.0, 2.0), r=(0.0, 1.0))
display(v)

Finally, we generate the dataset for all chosen values of $\alpha$ and $r$, with an f0 drawn at random for each sample. We refer to this dataset as the complex dataset.

In [None]:
n_alpha = 75
n_r = 75

min_alpha = 0.0
max_alpha = 2.0
min_r = 0.0
max_r = 1.0

r_range = np.linspace(min_r, max_r, n_r, endpoint=True)
alpha_range = np.linspace(min_alpha, max_alpha, n_alpha, endpoint=True)
frequencies = np.zeros((n_r, n_alpha))


X = np.zeros((n_r, n_alpha, N))
for i, r in tqdm.tqdm(enumerate(r_range), total=len(r_range)):
    for j, a in enumerate(alpha_range):
        X[i, j, :], frequencies[i, j]  = generate(a, r, f0=f0_list, N=N)

We also create a simpler dataset with a fixed value of f0. We refer to this dataset as the simple dataset.


In [None]:
n_alpha = 75
n_r = 75

min_alpha = 0.0
max_alpha = 2.0
min_r = 0.0
max_r = 1.0

r_range = np.linspace(min_r, max_r, n_r, endpoint=True)
alpha_range = np.linspace(min_alpha, max_alpha, n_alpha, endpoint=True)

X_simple = np.zeros((n_r, n_alpha, N))
for i, r in tqdm.tqdm(enumerate(r_range), total=len(r_range)):
    for j, a in enumerate(alpha_range):
        x, _ = generate(a, r, f0=[16], N=N)
        x = x/np.max(x)
        X_simple[i, j, :] = x

### 2. Representations

The goal here is to see whether, depending on the representation of the audio signals, we can learn an embedding that captures the parameters of our data ($\alpha$ and $r$).
To do so, we will use the Isomap algorithm ([link](https://www.science.org/doi/pdf/10.1126/science.290.5500.2319?casa_token=3jO92IQhP-oAAAAA:BsWxsyAUddTwP8NUO0GY7YZ-L2CiuSX4iTfevpySjfsJmcokM8SqaFN1v3gigDWH02fmlrFbd6mmEfbN)) for unsupervised manifold learning.

We will compare applying it to the raw waveforms and to their scattering transforms, on both the simple dataset (fixed f0) and the complex dataset (randomly sampled f0).
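
For intuition about what Isomap does, here is a from-scratch NumPy sketch of its three steps: build a k-nearest-neighbour graph, approximate geodesic distances by shortest paths on that graph, then embed with classical multidimensional scaling. It is a simplified illustration (dense Floyd-Warshall, no handling of disconnected graphs), not scikit-learn's `Isomap` implementation, and `isomap_sketch` is a hypothetical name. The usage example unrolls points on three quarters of a circle, which a straight-line projection cannot do without folding.

```python
import numpy as np

def isomap_sketch(X, n_neighbors=5, n_components=2):
    """Toy Isomap: kNN graph -> all-pairs shortest paths -> classical MDS."""
    n = X.shape[0]
    # pairwise Euclidean distances between samples
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    # k-nearest-neighbour graph (column 0 of argsort is the point itself), symmetrized
    G = np.full((n, n), np.inf)
    nbrs = np.argsort(D, axis=1)[:, 1:n_neighbors + 1]
    rows = np.repeat(np.arange(n), n_neighbors)
    cols = nbrs.ravel()
    G[rows, cols] = D[rows, cols]
    G = np.minimum(G, G.T)
    np.fill_diagonal(G, 0.0)
    # Floyd-Warshall: geodesic distance estimates along the graph (assumes it is connected)
    for k in range(n):
        G = np.minimum(G, G[:, [k]] + G[[k], :])
    # classical MDS on the squared geodesic distances
    C = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * C @ (G ** 2) @ C
    vals, vecs = np.linalg.eigh(B)
    top = np.argsort(vals)[::-1][:n_components]
    return vecs[:, top] * np.sqrt(np.maximum(vals[top], 0.0))

# Usage: 60 points on three quarters of a circle (a curved 1-D manifold in 2-D)
s_arc = np.linspace(0, 1.5 * np.pi, 60)
X_arc = np.stack([np.cos(s_arc), np.sin(s_arc)], axis=1)
emb_arc = isomap_sketch(X_arc, n_neighbors=4, n_components=1)[:, 0]
steps = np.diff(emb_arc)
print(np.all(steps > 0) or np.all(steps < 0))   # checks that the order along the arc is preserved

# For comparison: projecting onto the top principal axis folds the arc
centred = X_arc - X_arc.mean(axis=0)
pca_coord = centred @ np.linalg.svd(centred)[2][0]
```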

In [None]:
def make_manifold(Xmat, n_components=2, n_neighbors=50, n_r=50, n_alpha=50, N=2**10):

    raw_isomap = Isomap(n_neighbors=n_neighbors, n_components=n_components)
    raw_iso = raw_isomap.fit_transform(Xmat)
    raw_isoReshape = raw_iso.reshape((n_r, n_alpha, n_components))
    
    return raw_isoReshape


In [None]:
n_components=2
n_neighbors=50
Xmat_simple = X_simple.reshape((n_r*n_alpha, N))

raw_isoReshape = make_manifold(Xmat_simple, n_components, n_neighbors, n_r, n_alpha, N)

Now, let's visualize the manifold learned on the simple dataset. This manifold is learned directly on the waveforms, without transforming the generated audio signals.

In [None]:
alpha_tiled = np.tile(alpha_range, (n_r, 1));
r_tiled = np.tile(r_range, (n_alpha, 1)).T;

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=False, sharey=True)
axs[0].scatter(
    np.ravel(raw_isoReshape[:, :, 0]),
    np.ravel(raw_isoReshape[:, :, 1]),
    c = np.ravel(r_tiled), cmap='coolwarm');
axs[1].scatter(
    np.ravel(raw_isoReshape[:, :, 0]),
    np.ravel(raw_isoReshape[:, :, 1]),
    c = np.ravel(alpha_tiled), cmap='coolwarm');

Let's see how things go with the more complex dataset:

In [None]:
Xmat = X.reshape((n_r*n_alpha, N))


raw_isoReshape = make_manifold(Xmat,  n_components, n_neighbors, n_r, n_alpha, N)


alpha_tiled = np.tile(alpha_range, (n_r, 1));
r_tiled = np.tile(r_range, (n_alpha, 1)).T;

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=False, sharey=True)
axs[0].scatter(
    np.ravel(raw_isoReshape[:, :, 0]),
    np.ravel(raw_isoReshape[:, :, 1]),
    c = np.ravel(r_tiled), cmap='coolwarm');
axs[1].scatter(
    np.ravel(raw_isoReshape[:, :, 0]),
    np.ravel(raw_isoReshape[:, :, 1]),
    c = np.ravel(alpha_tiled), cmap='coolwarm');

What differences can you see between the two manifolds?


Let's try it with the scattering transform.

In [None]:
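def make_scattering(X, n_r=50, n_alpha=50, N=2**10):
    """Scattering transform of every signal in X (shape (n_r, n_alpha, N)); returns (n_r * n_alpha, C)."""
    Xmat = X.reshape((n_r*n_alpha, N))

    # torch frontend (imported above); with J = log2(N) the averaging spans the whole signal
    scattering = Scattering1D(J=int(np.log2(N)), Q=1, shape=(N,))
    Xmat_torch = torch.from_numpy(Xmat).float()
    Smat = np.maximum(0, scattering(Xmat_torch).numpy()[:, :, 0])  # keep the first time frame
    return Smat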

In [None]:
Smat = make_scattering(X_simple, n_r, n_alpha)

Now, we learn a manifold from this matrix of scattering coefficients.

In [None]:

n_components = 3
n_neighbors = 100

manifold_scattering = make_manifold(Smat, n_components, n_neighbors,  n_r, n_alpha, N)

alpha_tiled = np.tile(alpha_range, (n_r, 1));
r_tiled = np.tile(r_range, (n_alpha, 1)).T;


fig, axs = plt.subplots(1, 2, figsize=(15, 5), sharex=False, sharey=True)
axs = axs.flatten()

axs[0].scatter(
    np.ravel(manifold_scattering[:, :, 0]),
    np.ravel(manifold_scattering[:, :, 1]),
    c = np.ravel(r_tiled), cmap='coolwarm');
axs[1].scatter(
    np.ravel(manifold_scattering[:, :, 0]),
    np.ravel(manifold_scattering[:, :, 1]),
    c = np.ravel(alpha_tiled), cmap='coolwarm');


Do you see a difference between the manifolds learned on the waveforms and on the scattering transforms for the simple dataset?

Now, we do the same but with the complex dataset.

In [None]:
Smat = make_scattering(X, n_r, n_alpha)

n_components = 3
n_neighbors = 100

manifold_scattering = make_manifold(Smat, n_components, n_neighbors,   n_r, n_alpha, N)

In [None]:
alpha_tiled = np.tile(alpha_range, (n_r, 1));
r_tiled = np.tile(r_range, (n_alpha, 1)).T;


fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=False, sharey=True)
axs = axs.flatten()

axs[0].scatter(
    np.ravel(manifold_scattering[:, :, 0]),
    np.ravel(manifold_scattering[:, :, 1]),
    c = np.ravel(r_tiled), cmap='coolwarm');
axs[1].scatter(
    np.ravel(manifold_scattering[:, :, 0]),
    np.ravel(manifold_scattering[:, :, 1]),
    c = np.ravel(alpha_tiled), cmap='coolwarm');
axs[2].scatter(
    np.ravel(manifold_scattering[:, :, 0]),
    np.ravel(manifold_scattering[:, :, 1]),
    c = np.ravel(frequencies), cmap='coolwarm');

What differences do you see between the two manifolds for the complex dataset?
Can you retrieve the parameters of our signals?


We can also visualize the manifold in 3D:


In [None]:
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import


fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(
    manifold_scattering[:, :, 0],
    manifold_scattering[:, :, 1],
    manifold_scattering[:, :, 2],
    c=np.ravel(r_tiled),
    s=6.0,
    alpha=0.5, cmap='coolwarm')

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_zlabel('')
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.gca().set_zticklabels([])
plt.gca().grid(color='g')


fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(
    manifold_scattering[:, :, 0],
    manifold_scattering[:, :, 1],
    manifold_scattering[:, :, 2],
    c=np.ravel(alpha_tiled),
    s=20.0,
    alpha=0.5, cmap='coolwarm')

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_zlabel('')
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.gca().set_zticklabels([])



fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(
    manifold_scattering[:, :, 0],
    manifold_scattering[:, :, 1],
    manifold_scattering[:, :, 2],
    c=-np.ravel(frequencies),
    s=6.0,
    alpha=0.5, cmap='coolwarm')

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_zlabel('')
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.gca().set_zticklabels([])
