### In this notebook, we want to investigate the colouration when we filter and sum multiple subband GFDNs

In [None]:
import numpy as np
import torch
import torchaudio.functional as F
import re
import pyfar as pf
import os
from pathlib import Path
import matplotlib.pyplot as plt
# Global multiplier applied to every matplotlib font size set below.
scale = 2
plt.rcParams.update({
    rc_key: scale * base_points
    for rc_key, base_points in (
        ('font.size', 8),  # base font size
        ('axes.labelsize', 9),  # x/y label
        ('xtick.labelsize', 8),
        ('ytick.labelsize', 8),
        ('legend.fontsize', 8),
        ('axes.titlesize', 10),  # usually unused in journal figures
    )
})


from scipy.io import loadmat
from scipy.signal import sosfreqz, sosfiltfilt
from scipy.fft import rfft, irfft, rfftfreq
from torch import nn
from torchaudio.functional import lfilter
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List, Tuple
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, ThreeRoomDataset, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType, SubbandProcessingConfig
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import db, ms_to_samps,is_unitary, get_response, spectral_flatness, normalised_echo_density, time_reversed_filtering
from diff_gfdn.plot import plot_magnitude_response
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.model import ColorlessFDN
from slope2noise.utils import schroeder_backward_int
from src.run_model import load_and_validate_config

In [None]:
freqs = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
num_fir_taps = 2**12

# Config and figure locations are the same for every band, so set them once.
config_path = 'data/config/'
fig_path = 'figures/'

# One pass per entry of `freqs`: load that band's config and dataset, build the
# model, evaluate the sub-FDN responses with the initial and the final
# checkpoint weights, then filter both with the band's subband filter. Results
# go into slice k of the H_fdn_* tensors allocated on the first pass.
for k, band_freq in enumerate(freqs):
    config_name = f'treble_data_grid_training_{band_freq}Hz_colorless_loss_diff_delays'
    config_file = config_path + config_name
    config_dict = load_and_validate_config(config_file + '.yml',
                                           DiffGFDNConfig)
    room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)
    dataset_has_cs_params = True
    # one group per room in the dataset
    config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})
    trainer_config = config_dict.trainer_config

    # FFT length shared by every subband-filtered response and by freq_bins_hz.
    # NOTE(review): presumably the length of the linear convolution of a
    # num_freq_bins-sample response with a num_fir_taps-tap filter; this only
    # lines up with the buffer sizes below when num_freq_bins is even - confirm.
    subband_fft_len = trainer_config.num_freq_bins + num_fir_taps - 1

    ## initialise variables
    if k == 0:
        # unfiltered sub-FDN responses: (frequency bins, groups, bands)
        H_fdn_init = torch.zeros((trainer_config.num_freq_bins // 2 + 1, config_dict.num_groups, len(freqs)),
                                 dtype=torch.complex64, requires_grad=False)
        # subband-filtered responses: more bins because filtering lengthens the response
        H_fdn_subband_init = torch.zeros(((trainer_config.num_freq_bins + num_fir_taps) // 2, config_dict.num_groups, len(freqs)),
                                         dtype=torch.complex64, requires_grad=False)

        H_fdn_final = torch.zeros_like(H_fdn_init)
        H_fdn_subband_final = torch.zeros_like(H_fdn_subband_init)

    # force the trainer config device to be CPU
    if trainer_config.device != 'cpu':
        trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

    # get the colorless FDN params
    if config_dict.colorless_fdn_config.use_colorless_prototype:
        colorless_fdn_params = get_colorless_fdn_params(config_dict)
    else:
        colorless_fdn_params = None

    # initialise the model
    model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                     config_dict.num_groups,
                     config_dict.delay_length_samps,
                     trainer_config.device,
                     config_dict.feedback_loop_config,
                     config_dict.output_filter_config,
                     config_dict.decay_filter_config.use_absorption_filters,
                     common_decay_times=room_data.common_decay_times if config_dict.decay_filter_config.initialise_with_opt_values else None,
                     learn_common_decay_times=config_dict.decay_filter_config.learn_common_decay_times,
                     colorless_fdn_params=colorless_fdn_params,
                     use_colorless_loss=trainer_config.use_colorless_loss
                     )

    ### Get each FDN's magnitude response
    freq_bins_rad = torch.tensor(room_data.freq_bins_rad)
    # frequency axis (Hz) for the subband-filtered responses
    freq_bins_hz = torch.fft.rfftfreq(subband_fft_len, d=1.0/config_dict.sample_rate)
    # unit-magnitude complex numbers with angles freq_bins_rad
    z_values = torch.polar(torch.ones_like(freq_bins_rad),
                           freq_bins_rad)

    # load the trained weights for the particular epoch
    max_epochs = trainer_config.max_epochs
    checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()

    # checkpoint saved as model_e-1.pt is used as the initial state
    init_checkpoint = torch.load(f'{checkpoint_dir}/model_e-1.pt',
                                 weights_only=True,
                                 map_location=torch.device('cpu'))
    model.load_state_dict(init_checkpoint, strict=False)
    model.eval()
    H_fdn_init[..., k], _ = model.sub_fdn_output(z_values)

    # checkpoint of the last epoch is used as the optimised state
    final_checkpoint = torch.load(f'{checkpoint_dir}/model_e{max_epochs-1}.pt',
                                  weights_only=True,
                                  map_location=torch.device('cpu'))
    # Load the trained model state
    model.load_state_dict(final_checkpoint, strict=False)
    model.eval()

    # check for losslessness
    param_dict = model.get_param_dict()
    coupled_feedback_matrix = torch.tensor(param_dict['coupled_feedback_matrix'])
    feedback_matrix = torch.tensor(param_dict['individual_mixing_matrix'])
    # all(...) is required: asserting the list itself is always true because a
    # non-empty list is truthy regardless of its contents.
    # NOTE(review): assumes is_unitary returns a plain truthy/falsy value - confirm.
    assert all(is_unitary(feedback_matrix[n]) for n in range(config_dict.num_groups))
    assert is_unitary(coupled_feedback_matrix)

    H_fdn_final[..., k], _ = model.sub_fdn_output(z_values)

    ### Filter the sub-FDN outputs with the pyfar filters, sum them and observe colouration
    if trainer_config.subband_process_config.use_amp_preserving_filterbank:
        if k == 0:
            # the filterbank is built once and reused for every band
            subband_filters, subband_freqs = pf.dsp.filter.reconstructing_fractional_octave_bands(
                        None,
                        num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                        frequency_range=trainer_config.subband_process_config.frequency_range,
                        sampling_rate=config_dict.sample_rate,
            )

        # pick the filter whose centre frequency is closest to this config's band
        subband_filter_idx = np.argmin(
            np.abs(subband_freqs -
                   trainer_config.subband_process_config.centre_frequency))
        subband_filter = torch.tensor(
            subband_filters.coefficients[subband_filter_idx])

        # proper filtering with linear convolution
        H_fdn_subband_init[..., k] = torch.fft.rfft(F.fftconvolve(torch.fft.irfft(H_fdn_init[..., k], dim=0).T,
                                                                  subband_filter.unsqueeze(0)), dim=-1).T
        H_fdn_subband_final[..., k] = torch.fft.rfft(F.fftconvolve(torch.fft.irfft(H_fdn_final[..., k], dim=0).T,
                                                                   subband_filter.unsqueeze(0)), dim=-1).T

    else:
        if k == 0:
            # the filterbank is built once and reused for every band
            subband_filters = pf.dsp.filter.fractional_octave_bands(
                    None,
                    num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                    frequency_range=trainer_config.subband_process_config.frequency_range,
                    sampling_rate=config_dict.sample_rate,
                )
            subband_freqs, _ = pf.dsp.filter.fractional_octave_frequencies(
                num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                frequency_range=trainer_config.subband_process_config.frequency_range,
            )

        # pick the filter whose centre frequency is closest to this config's band
        subband_filter_idx = np.argmin(
            np.abs(subband_freqs -
                   trainer_config.subband_process_config.centre_frequency))

        # safest to filter in time domain and then take transform
        # The transform is zero-padded to subband_fft_len so it has the bin
        # count of the pre-allocated buffers and of freq_bins_hz; with
        # n=num_freq_bins the result has num_freq_bins // 2 + 1 bins and cannot
        # be assigned into the larger buffer.
        h_fdn_subband_init = sosfiltfilt(subband_filters.coefficients[subband_filter_idx, ...],
                                         torch.fft.irfft(H_fdn_init[..., k], dim=0).detach().numpy(),
                                         axis=0)
        H_fdn_subband_init[..., k] = torch.fft.rfft(torch.from_numpy(h_fdn_subband_init.copy()), n=subband_fft_len, dim=0)

        h_fdn_subband_final = sosfiltfilt(subband_filters.coefficients[subband_filter_idx, ...],
                                          torch.fft.irfft(H_fdn_final[..., k], dim=0).detach().numpy(),
                                          axis=0)
        H_fdn_subband_final[..., k] = torch.fft.rfft(torch.from_numpy(h_fdn_subband_final.copy()), n=subband_fft_len, dim=0)

    # plot colouration of individual FDNs
    save_path = f'{fig_path}/{config_name}_mag_spectrum.png'
    plot_magnitude_response(room_data, config_dict, model, save_path)

#### Plot colouration of filtered and summed FDNs

In [None]:
def mag_response_plot(room_data:RoomDataset, H_sub_fdn_filtered: NDArray, freq_bins_hz:torch.tensor, freq_labels:List[int], 
                      title:Optional[str]=None, save_path:Optional[str] = None, subband_filter_coeffs: Optional[NDArray]=None):
    """Plot magnitude response of filtered sub-FDN responses, one subplot per room.

    Each subplot shows the per-band responses of one FDN, their sum over bands,
    and a vertical marker at every frequency in ``freq_labels``. The spectral
    flatness of the summed response is logged for each FDN.

    Args:
        room_data: dataset object; only ``num_rooms`` is read, to set the number of subplots
        H_sub_fdn_filtered: frequency responses of shape num_bins x num_groups x num_freq_bands
        freq_bins_hz: frequency axis in Hz used as x values for every curve
        freq_labels: band frequencies in Hz, used for legend entries and vertical markers
        title: optional figure title
        save_path: optional path the figure is saved to
        subband_filter_coeffs: optional filter coefficients whose responses are overlaid;
            the transform is taken along the last axis
            (presumably one row of FIR taps per band - TODO confirm)
    """
    # squeeze=False keeps `axes` an array even for a single room; by default
    # plt.subplots returns a bare Axes for a 1x1 grid, which cannot be indexed.
    fig, axes = plt.subplots(room_data.num_rooms,
                             1,
                             figsize=(8, 10),
                             sharex=True,
                             squeeze=False)
    axes = axes[:, 0]

    if subband_filter_coeffs is not None:
        subband_filter_response = rfft(subband_filter_coeffs, n=2 * H_sub_fdn_filtered.shape[0]-1, axis=-1)

    band_labels = [f"{freq} Hz" for freq in freq_labels]

    for i in range(room_data.num_rooms):
        ax = axes[i]
        # response of this FDN summed over all bands
        H_summed = np.sum(H_sub_fdn_filtered[:, i, :], axis=-1)

        # one curve per band
        ax.semilogx(
            freq_bins_hz,
            db(H_sub_fdn_filtered[:, i, :]),
            label=band_labels,
            linestyle="-",
            alpha=0.6,
        )

        ax.semilogx(freq_bins_hz, db(H_summed),
                    label='summed', linestyle=':')

        if subband_filter_coeffs is not None:
            ax.semilogx(
                freq_bins_hz,
                db(subband_filter_response.T),
                label=[f"OG filter response at {freq} Hz" for freq in freq_labels],
                linestyle="-.",
                alpha=0.8,
            )

        # ax.semilogx(freq_bins_hz, db(np.sum(subband_filter_response, axis=0)),
        #             label='summed filter response', linestyle='-')

        # draw octave bands
        for freq in freq_labels:
            ax.axvline(freq, ymin=0, ymax=1, color="red", linestyle="--")

        ax.set_ylabel("Magnitude (dB)")
        ax.set_xlabel('Frequencies (Hz)')
        ax.set_title(f"FDN {i+1}")
        ax.grid(True)
        ax.set_xlim([20, 16000])
        ax.set_ylim([-60, 20])

        logger.info(
            f'FDN {i+1} spectral flatness is {spectral_flatness(db(H_summed)):.3f}'
        )

    # a single legend, attached to the last subplot and placed outside the axes
    axes[-1].legend(loc="upper left", bbox_to_anchor=(1, 1))
    fig.subplots_adjust(hspace=0.5)
    # plt.tight_layout()

    if title is not None:
        fig.suptitle(title)
    if save_path is not None:
        # bbox_inches='tight' so the legend outside the axes is included
        fig.savefig(save_path, bbox_inches='tight')


def impulse_response_plot(room_data:RoomDataset, H_sub_fdn_filtered: NDArray, freq_labels:List[int], 
                      title:Optional[str]=None, save_path:Optional[str] = None):
    """Plot the time-domain responses of the sub-FDNs, one subplot per room.

    The frequency responses are inverse-transformed along axis 0 and each band
    is drawn in dB against time in seconds.

    Args:
        room_data: dataset object; ``num_rooms`` sets the number of subplots and
            ``sample_rate`` the time axis
        H_sub_fdn_filtered: frequency responses indexed as [bin, room, band]
        freq_labels: band frequencies in Hz, used for the legend entries
        title: optional figure title
        save_path: optional path the figure is saved to
    """
    # squeeze=False keeps `axes` an array even for a single room; by default
    # plt.subplots returns a bare Axes for a 1x1 grid, which cannot be indexed.
    fig, axes = plt.subplots(room_data.num_rooms,
                             1,
                             figsize=(8, 10),
                             sharex=True,
                             squeeze=False)
    axes = axes[:, 0]

    h_sub_fdn_filtered = irfft(H_sub_fdn_filtered, axis=0)
    # Build the time axis from an integer sample index. np.arange with a
    # fractional step can return one sample too many because of rounding,
    # which would break the plot calls below.
    num_samples = h_sub_fdn_filtered.shape[0]
    time = np.arange(num_samples) / room_data.sample_rate

    for i in range(room_data.num_rooms):
        ax = axes[i]
        for k in range(H_sub_fdn_filtered.shape[-1]):
            ax.plot(
                time,
                db(h_sub_fdn_filtered[:, i, k]),
                label=f"{freq_labels[k]} Hz",
                linestyle="-",
                alpha=0.8,
            )

        # ax.plot(time, db(np.sum(h_sub_fdn_filtered[:, i, :], axis=-1)),
        #         label='summed', linestyle='--', alpha=0.8)

        ax.set_ylabel("Amplitude (dB)")
        ax.set_xlabel('Time (s)')
        ax.set_title(f"FDN {i+1}")
        ax.grid(True)
        ax.set_ylim([-80, 0])

    # a single legend, attached to the last subplot and placed outside the axes
    axes[-1].legend(loc="upper left", bbox_to_anchor=(1, 1))
    fig.subplots_adjust(hspace=0.5)
    plt.tight_layout()
    if title is not None:
        fig.suptitle(title)
    if save_path is not None:
        fig.savefig(save_path)


### Pre-optimisation response

In [None]:
# Spectrum and impulse-response figures for the subband-filtered responses
# obtained with the initial checkpoint.
fig_path = Path('figures/').resolve()
save_path = f'{fig_path}/test_plots/{config_name}_init_gfdn_summed_spectrum.png'
H_init_filtered_np = H_fdn_subband_init.detach().numpy()

mag_response_plot(
    room_data,
    H_init_filtered_np,
    freq_bins_hz,
    freqs,
    title="Initialisation",
    save_path=save_path,
)

# the impulse-response figure reuses the spectrum path, with the trailing
# '.png' replaced by '_ir.png'
impulse_response_plot(
    room_data,
    H_init_filtered_np,
    freqs,
    title="Initialisation",
    save_path=f'{save_path[:-4]}_ir.png',
)

### Post-optimisation response, no filterbank

In [None]:
# Spectrum and impulse-response figures for the unfiltered responses obtained
# with the final checkpoint.
fig_path = Path('figures/').resolve()
save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_summed_spectrum.png'
H_final_unfiltered_np = H_fdn_final.detach().numpy()
# the unfiltered responses have their own frequency axis, built from num_freq_bins
unfiltered_freq_bins_hz = torch.fft.rfftfreq(trainer_config.num_freq_bins,
                                             d=1.0/config_dict.sample_rate)

mag_response_plot(
    room_data,
    H_final_unfiltered_np,
    unfiltered_freq_bins_hz,
    freqs,
    title="Post optimisation, no filtering",
    save_path=save_path,
)
impulse_response_plot(
    room_data,
    H_final_unfiltered_np,
    freqs,
    title="Post optimisation, no filtering",
    save_path=f'{save_path[:-4]}_ir.png',
)

### Post optimisation response with filterbank

In [None]:
# Spectrum and impulse-response figures for the subband-filtered responses
# obtained with the final checkpoint.
save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_filtered_spectrum.png'
H_final_filtered_np = H_fdn_subband_final.detach().numpy()

# optionally pass subband_filter_coeffs=subband_filters.coefficients to overlay
# the filter responses
mag_response_plot(
    room_data,
    H_final_filtered_np,
    freq_bins_hz,
    freqs,
    title="Post optimisation with filtering",
    save_path=save_path,
)
impulse_response_plot(
    room_data,
    H_final_filtered_np,
    freqs,
    title="Post optimisation with filtering",
    save_path=f'{save_path[:-4]}_ir.png',
)

### Post optimisation filtering with time-reversed filterbank

In [None]:
# Reload diff_gfdn.utils so the latest time_reversed_filtering is picked up
# without restarting the kernel.
from importlib import reload
import diff_gfdn
reload(diff_gfdn.utils)
from diff_gfdn.utils import time_reversed_filtering

# Only run when the amplitude-preserving filterbank is in use.
if trainer_config.subband_process_config.use_amp_preserving_filterbank:
    time_rev_fft_len = trainer_config.num_freq_bins + num_fir_taps - 1

    # back to the time domain, filter, then transform again
    h_fdn_final = torch.fft.irfft(
        H_fdn_final, n=trainer_config.num_freq_bins, dim=0).detach().numpy()
    h_fdn_subband_time_rev_final = time_reversed_filtering(
        h_fdn_final,
        subband_filters.coefficients,
        time_axis=0,
        freq_labels=subband_freqs,
    )
    H_fdn_subband_time_rev_final = rfft(
        h_fdn_subband_time_rev_final, n=time_rev_fft_len, axis=0)

    save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_time_reversed_filtered_spectrum.png'
    mag_response_plot(
        room_data,
        H_fdn_subband_time_rev_final,
        freq_bins_hz,
        freqs,
        title="Post optimisation with time-reversed filtering",
        save_path=save_path,
    )
    impulse_response_plot(
        room_data,
        H_fdn_subband_time_rev_final,
        freqs,
        title="Post optimisation with time-reversed filtering",
        save_path=f'{save_path[:-4]}_ir.png',
    )

#### Get time domain impulse responses

In [None]:
# Per-FDN impulse responses: sum over the band axis, then inverse transform
h_init_fdn = torch.fft.irfft(H_fdn_subband_init.sum(dim=-1), dim=0)
h_final_fdn_no_filter = torch.fft.irfft(H_fdn_final.sum(dim=-1), dim=0)
h_final_fdn = torch.fft.irfft(H_fdn_subband_final.sum(dim=-1), dim=0)

# GFDN impulse responses: sum the per-FDN responses over the last axis
audio_path = Path('audio/filterbank_test/').resolve()
h_final_no_filter = h_final_fdn_no_filter.sum(dim=-1)
h_init = h_init_fdn.sum(dim=-1)
h_final = h_final_fdn.sum(dim=-1)

# the time-reversed variant exists only for the amplitude-preserving filterbank
if trainer_config.subband_process_config.use_amp_preserving_filterbank:
    h_final_time_rev_fdn = irfft(H_fdn_subband_time_rev_final.sum(axis=-1), axis=0).sum(axis=-1)
    sf.write(f'{audio_path}/{config_name}_sum_time_rev_filtered_colorless_fdn.wav',
             h_final_time_rev_fdn,
             room_data.sample_rate)

# write the unfiltered and the filtered post-optimisation responses
sf.write(f'{audio_path}/{config_name}_sum_colorless_gfdn.wav',
         h_final_no_filter.detach().numpy(),
         room_data.sample_rate)
sf.write(f'{audio_path}/{config_name}_sum_filtered_colorless_gfdn.wav',
         h_final.detach().numpy(),
         room_data.sample_rate)


#### Plot normalised echo density (NED)

In [None]:
# initial group delay because of FIR filtering with pyfar
init_fir_delay = subband_filters.coefficients[0].shape[0] // 2

fs = room_data.sample_rate
mixing_time_samp = ms_to_samps(50.0, fs)
crop_end_samp = ms_to_samps(2000.0, fs)

# The two analysis windows have the same length; the post-optimisation window
# starts init_fir_delay samples later and ends init_fir_delay samples later.
init_span = slice(mixing_time_samp, -crop_end_samp - init_fir_delay)
final_span = slice(init_fir_delay + mixing_time_samp, -crop_end_samp)

h_init_trunc = h_init[init_span]
h_final_trunc = h_final[final_span]
len_ir = len(h_init_trunc)
time = np.linspace(0, (len_ir-1)/ fs, len_ir-1)

ned_init = normalised_echo_density(h_init_trunc, fs, window_length_ms=50)
ned_final = normalised_echo_density(h_final_trunc, fs, window_length_ms=50)

# first subplot: whole GFDN; one further subplot per group
fig, ax = plt.subplots(config_dict.num_groups + 1, 1, figsize=(6, 10))
ax[0].plot(time, ned_init, label='Initial')
ax[0].plot(time, ned_final, label='Post optimisation')
for k in range(config_dict.num_groups):
    group_ax = ax[k+1]
    ned_init_fdn = normalised_echo_density(h_init_fdn[init_span, k],
                                           fs, window_length_ms=50)
    ned_final_fdn = normalised_echo_density(h_final_fdn[final_span, k],
                                            fs, window_length_ms=50)
    group_ax.plot(time, ned_init_fdn, label=f'Initial, group={k+1}')
    group_ax.plot(time, ned_final_fdn, label=f'Final, group={k+1}')

# common axis decoration
for axis in ax:
    axis.set_xlabel('Time (s)')
    axis.set_ylabel('NED')
    axis.legend(loc="upper left", bbox_to_anchor=(1, 1))
    axis.set_xlim([0.001, max(time)])

fig.savefig(f'{fig_path}/test_plots/{config_name}_ned.png')
plt.show()
