Synthesise a signal as a sum of decaying sinusoids that covers the whole frequency range. Then filter it with the subband filters — that will give you the signals which the network is trying to learn. Then add white noise to these subband components, and synthesise the input signal using the three approaches below:

- Just add the noisy subband components
- Add the noisy subband components filtered by the used subband filters
- Add the noisy subband components filtered by the time-reversed versions of the subband filters

In [None]:
from dataclasses import dataclass
import numpy as np
import pyfar as pf
import soundfile as sf
from numpy.typing import ArrayLike, NDArray
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.fft import rfft, rfftfreq, irfft
from scipy.signal import fftconvolve
from typing import List, Tuple, Optional
import IPython
import os
# Move one level up so that the relative paths and the local package imports
# below resolve from the DiffGFDN directory.
# NOTE: the target is relative, so every re-run of this cell moves the working
# directory up by one more level.
os.chdir('..')


from slope2noise.utils import octave_filtering, decay_kernel
from diff_gfdn.utils import ms_to_samps, db, db2lin, get_time_reversed_fir_filterbank

# Absolute output directories for audio files and figures, resolved against the
# current working directory.
# NOTE(review): nothing in this cell creates these directories — presumably
# they must already exist before files are written to them; confirm.
audio_path = Path('audio/filterbank_test/').resolve()
fig_path = Path('figures/test_plots/').resolve()

In [None]:
def synthesise_modes(fs: float,
                     freq_bands: List,
                     num_modes_per_band: List,
                     per_band_decay_ms: List,
                     sig_len_ms: float = 2000,
                     plot: bool = True):
    """Synthesise a sum of decaying cosine modes spread over `freq_bands`.

    For every band, `num_modes_per_band[b]` mode frequencies are drawn
    uniformly from ``[fc / sqrt(2), fc * sqrt(2)]`` around the band frequency
    ``fc``. Each mode is a cosine weighted by a random amplitude and by the
    envelope returned by `decay_kernel`.

    Args:
        fs: sampling rate in Hz.
        freq_bands: band frequencies in Hz, one entry per band.
        num_modes_per_band: number of modes to draw in each band.
        per_band_decay_ms: per-band decay value in milliseconds; it is
            doubled and converted to seconds before being handed to
            `decay_kernel` (the exact meaning is defined by that helper).
        sig_len_ms: length of the synthesised signal in milliseconds.
        plot: if True, plot the per-band responses (in dB) and the summed
            response, and save the figure as ``og_signal.png`` in the
            module-level `fig_path` directory.

    Returns:
        Tuple ``(time, signal)``: the time axis in seconds and the sum of all
        bands, both of length ``ms_to_samps(sig_len_ms, fs)``.
    """
    sig_len_samp = ms_to_samps(sig_len_ms, fs)
    num_bands = len(freq_bands)
    # Build the time axis from integer sample indices. np.arange with a
    # floating-point step can return one sample more or less than expected,
    # which would no longer match the pre-allocated `modes` array.
    time = np.arange(sig_len_samp) / fs
    modes = np.zeros((num_bands, sig_len_samp))
    # Define the range for radius and angle
    radius_min, radius_max = db2lin(-100), db2lin(0)  # Range for magnitude (r)
    angle_min, angle_max = 0, np.pi    # Range for angle (theta) in radians

    if plot:
        fig, ax = plt.subplots(2, 1, figsize=(6, 4))

    for b_idx in range(num_bands):
        num_modes = num_modes_per_band[b_idx]
        mode_freqs = np.random.uniform(low=freq_bands[b_idx] / np.sqrt(2),
                                       high=freq_bands[b_idx] * np.sqrt(2),
                                       size=num_modes)

        # All modes of one band share the same decay value.
        mode_decays = 2 * np.ones((num_modes, 1)) * per_band_decay_ms[b_idx] * 1e-3
        mode_decay_envelope = decay_kernel(mode_decays, time, fs, normalize_envelope=True).squeeze()

        # Generate random radius and angle values
        radii = np.random.uniform(radius_min, radius_max, num_modes)
        angles = np.random.uniform(angle_min, angle_max, num_modes)

        # Real part of r * exp(j * theta), i.e. a signed amplitude r * cos(theta)
        mode_amps = np.real(radii * np.exp(1j * angles))
        modes[b_idx, :] = np.sum(
            mode_amps[:, None] * mode_decay_envelope * np.cos(2 * np.pi * np.outer(mode_freqs, time)),
            axis=0)

        if plot:
            ax[0].plot(time, db(modes[b_idx, :]), label=f'{freq_bands[b_idx]} Hz')

    # Sum over bands once; used for both the plot and the return value.
    overall_response = np.sum(modes, axis=0)

    if plot:
        ax[1].plot(time, overall_response, label='Overall response')
        ax[0].set_xlabel('Time (s)')
        ax[0].set_ylabel('Amplitude (dB)')
        ax[0].legend()
        ax[1].set_xlabel('Time (s)')
        ax[1].set_ylabel('Amplitude')
        ax[1].legend()
        fig.savefig(f'{fig_path}/og_signal.png')

    return time, overall_response


def get_fft_size(ir_len_samps: int):
    """Return the smallest power of two that is >= `ir_len_samps`.

    The result is computed with integer arithmetic, so it cannot overflow a
    fixed-width integer type and does not depend on floating-point log2
    rounding.

    Args:
        ir_len_samps: signal length in samples; must be at least 1.

    Returns:
        The FFT size as a Python int.

    Raises:
        ValueError: if `ir_len_samps` is smaller than 1.
    """
    # Round up first so that a non-integer length still gets a large enough FFT.
    num_samps = int(np.ceil(ir_len_samps))
    if num_samps < 1:
        raise ValueError(f"ir_len_samps must be at least 1, got {ir_len_samps}")
    # (n - 1).bit_length() is the exponent of the next power of two >= n.
    return 1 << (num_samps - 1).bit_length()
    
def get_multichannel_noise(input_signal: NDArray, noise_rms_db: float):
    """Generate Gaussian noise with the shape of `input_signal` and a given RMS.

    Only the 2-D shape of `input_signal` is used; the signal itself is neither
    modified nor added to the noise. Every column is scaled independently so
    that its RMS equals ``db2lin(noise_rms_db)``.

    Args:
        input_signal: 2-D array whose shape determines the shape of the noise.
        noise_rms_db: desired RMS of each noise column, in dB.

    Returns:
        Array of scaled noise with the same shape as `input_signal`.
    """
    num_samples, num_channels = input_signal.shape
    # Standard normal samples (zero mean, unit variance)
    white = np.random.randn(num_samples, num_channels)
    target_rms = db2lin(noise_rms_db)
    # Measured RMS of each column of the raw noise
    per_channel_rms = np.sqrt(np.mean(white**2, axis=0))
    return white * (target_rms / per_channel_rms)
    
def plot_mag_response(signals: NDArray, fs: float, labels: List, ls: List, save_path:Optional[str]=None):
    """Plot the spectrum of every column of `signals` on a log-frequency axis.

    The signals are transformed along axis 0 with a real FFT whose size comes
    from `get_fft_size`, converted with the `db` helper, and drawn in a single
    figure. Column k is shifted upwards by ``k * 50`` so the curves do not
    overlap.

    Args:
        signals: 2-D array with one signal per column.
        fs: sampling rate in Hz.
        labels: legend entry for each column.
        ls: matplotlib line style for each column.
        save_path: if given, the figure is saved to this path before it is shown.
    """
    num_samples, num_chans = signals.shape
    n_fft = get_fft_size(num_samples)

    spectra = rfft(signals, n=n_fft, axis=0)
    freqs_hz = rfftfreq(n_fft, d=1.0 / fs)

    plt.figure(figsize=(8, 6))
    vertical_shift = 50

    for chan in range(num_chans):
        shifted_mag_db = db(spectra[:, chan]) + chan * vertical_shift
        plt.semilogx(freqs_hz, shifted_mag_db, linestyle=ls[chan], label=labels[chan], alpha=0.6)

    plt.legend()
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude (dB)')
    # Show from 20 Hz up to the Nyquist frequency
    plt.xlim([20, fs/2])
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()


def filter_multiband_signal(signal: NDArray, subband_filters: NDArray) -> NDArray:
    """Convolve each column of `signal` with the matching row of `subband_filters`.

    Args:
        signal: array of size N x num_bands, one subband signal per column.
        subband_filters: array of size num_bands x num_coeffs, one FIR filter
            per row.

    Returns:
        Array of size N x num_bands. Each column is the 'same'-mode FFT
        convolution of the input column with its filter, so the output keeps
        the input length N.
    """
    assert signal.shape[-1] == subband_filters.shape[0]
    band_count = subband_filters.shape[0]
    filtered = np.zeros((signal.shape[0], band_count))
    for band in range(band_count):
        band_signal = signal[:, band]
        band_filter = subband_filters[band, :]
        filtered[:, band] = fftconvolve(band_signal, band_filter, mode='same')
    return filtered


In [None]:
# Synthesis settings: sampling rate in Hz, band frequencies in Hz and one
# decay value (in ms) per band.
fs = 32000
freq_bands = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
per_band_decay_ms = [1000, 800, 750, 500, 400, 300, 100, 50]
# Define the range of number of modes per band
min_num_modes, max_num_modes = 10, 2000
# Generate N log-spaced values and round to the nearest integer, so the number
# of modes grows with the band index (10 in the first band, 2000 in the last)
num_modes_per_band = np.round(np.logspace(np.log10(min_num_modes), np.log10(max_num_modes), len(freq_bands))).astype(int)
# The mode parameters are drawn with np.random without a fixed seed, so the
# synthesised signal differs from run to run.
time_axis, modal_output = synthesise_modes(fs, freq_bands, num_modes_per_band, per_band_decay_ms)
ir_len = len(time_axis)

# save the audio
save_path = f'{audio_path}/clean_signal.wav'
sf.write(save_path, modal_output, fs)
# Audio widget for the saved file; as the last expression of the cell it is
# presumably meant to be rendered as the cell output.
IPython.display.Audio(save_path)


### Filter the modal signal with FIR filterbank

In [None]:
# Split the clean signal into subbands; get_filter_ir=True additionally returns
# the filter impulse responses, which are reused by the cells below.
filtered_modes, subband_filters = octave_filtering(modal_output, fs, freq_bands, get_filter_ir=True, use_amp_preserving_filterbank = True)
# Reconstruction 1: plain sum of the subband signals (bands are on the last axis)
modal_recons = np.sum(filtered_modes, axis=-1)
num_bands = filtered_modes.shape[-1]
# Reconstruction 2: filter every subband signal once more with its own filter,
# then sum over bands
twice_filtered_modes = filter_multiband_signal(filtered_modes, subband_filters)
twice_filtered_modal_recons = np.sum(twice_filtered_modes, axis=-1)

#### Filter each band with the time reversed filters and sum them to see if the output matches the input

In [None]:
# Build the time-reversed filterbank: get_time_reversed_fir_filterbank returns
# frequency responses, which are converted back to impulse responses with the
# same number of taps as the original filters.
fft_size = subband_filters.shape[-1]
# Normalised angular frequency (radians per sample) of each rfft bin
freq_bins_rad = rfftfreq(fft_size) * 2 * np.pi
time_rev_filter_subband_resp = get_time_reversed_fir_filterbank(subband_filters, freq_bins_rad, fft_size, plot=True, freq_labels=freq_bands)
time_rev_subband_filters = irfft(time_rev_filter_subband_resp, n=fft_size, axis=-1)

# Reconstruction 3: filter every subband signal with the time-reversed filter
# of its band, then sum over bands
time_rev_twice_filtered_modes = filter_multiband_signal(filtered_modes, time_rev_subband_filters)
time_rev_modal_recons = np.sum(time_rev_twice_filtered_modes, axis=-1)
# NOTE(review): the exact-reconstruction check below is disabled in the
# original notebook — presumably because it does not hold; confirm.
# assert np.allclose(modal_output, time_rev_modal_recons, atol=1e-6)

# Compare the clean signal with the three reconstructions in the time domain.
all_signals = np.column_stack((modal_output, modal_recons, twice_filtered_modal_recons, time_rev_modal_recons))
# A single figure call: an additional bare plt.figure() before this one only
# created an extra empty figure.
plt.figure(figsize=(8, 6))
labels = ['OG', 'FIR recons.', 'Twice filtered', 'Twice filtered with time reversed version']
ls = ['-','--','-.', ':']
offset = 0  # no vertical shift: the curves are overlaid for direct comparison
for k in range(all_signals.shape[-1]):
    plt.plot(time_axis, db(all_signals[:, k]) + k*offset, linestyle=ls[k], label=labels[k], alpha=0.6)
plt.legend(bbox_to_anchor=(1.1, 1.1))
plt.xlabel('Time(s)')
plt.ylabel('Amplitude (dB)')
plt.show()


In [None]:
# Error between the clean signal and the time-reversed reconstruction, in dB.
# NOTE(review): only sample 0 of each signal is compared here, not the whole
# signal — confirm that this is intended.
print(db(np.abs(modal_output[0] - time_rev_modal_recons[0])))

### Add noise to the filtered signals

In [None]:
# Independent Gaussian noise for every subband, scaled to an RMS of -60 dB
scaled_noise = get_multichannel_noise(filtered_modes, -60)
noisy_filtered_modes = filtered_modes + scaled_noise
# Approach 1: just add the noisy subband components
modal_recons_w_noise = np.sum(noisy_filtered_modes, axis=-1)

### Pass noisy signal through the filterbank

In [None]:
# Approach 2: filter each noisy subband with its own subband filter, then sum
filtered_modes_w_noise = filter_multiband_signal(noisy_filtered_modes, subband_filters)

filtered_modal_recons_w_noise = np.sum(filtered_modes_w_noise, axis=-1)

### Pass noisy signal through time-reversed filterbank

In [None]:
# Approach 3: filter each noisy subband with the time-reversed version of its
# subband filter, then sum
time_rev_filtered_modes_w_noise = filter_multiband_signal(noisy_filtered_modes, time_rev_subband_filters)

time_rev_filtered_modal_recons_w_noise = np.sum(time_rev_filtered_modes_w_noise, axis=-1)

### Time and frequency domain plots

In [None]:
# Noise-free reconstruction next to the three noisy reconstructions, one
# signal per column.
all_signals = np.column_stack((modal_recons, modal_recons_w_noise, filtered_modal_recons_w_noise, time_rev_filtered_modal_recons_w_noise))

# A single figure call: an additional bare plt.figure() before this one only
# created an extra empty figure.
plt.figure(figsize=(8, 6))
labels = ['FIR recons.', 'FIR recons. + noise', 'FIR recons + noise + filtered', 'FIR recons + noise + time rev filtered']
ls = ['-','--','-.', ':']
offset = 80  # vertical shift (dB) between successive curves so they do not overlap
for k in range(all_signals.shape[-1]):
    plt.plot(time_axis, db(all_signals[:, k]) + k*offset, linestyle=ls[k], label=labels[k], alpha=0.6)
plt.legend(bbox_to_anchor=(1.1, 1.1))
plt.xlabel('Time(s)')
plt.ylabel('Amplitude (dB)')
plt.savefig(f'{fig_path}/sum_cosines_impulse_response.png')
plt.show()

# Frequency-domain view of the same four signals
plot_mag_response(all_signals, fs, labels, ls,
                  save_path=f'{fig_path}/sum_cosines_freq_response.png')

### Save the IRs

In [None]:
# Write the reconstructions to disk at the synthesis sampling rate.
# NOTE(review): 'noisy_recons.wav' holds the noisy signal after the second pass
# through the subband filters; the plain noisy sum (modal_recons_w_noise) is
# not saved — confirm that this is intended.
sf.write(f'{audio_path}/recons.wav', modal_recons, fs)
sf.write(f'{audio_path}/noisy_recons.wav', filtered_modal_recons_w_noise, fs)
sf.write(f'{audio_path}/time_rev_noisy_recons.wav', time_rev_filtered_modal_recons_w_noise, fs)