### In this notebook, we want to investigate the colouration when we filter and sum multiple subband GFDNs

In [None]:
import numpy as np
import torch
import re
import pyfar as pf
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.signal import sosfreqz, sosfiltfilt
from scipy.fft import rfft
from torch import nn
from torchaudio.functional import lfilter
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List, Tuple
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, ThreeRoomDataset, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType, SubbandProcessingConfig
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import db, ms_to_samps, get_response, spectral_flatness, normalised_echo_density, get_time_reversed_fir_filterbank
from diff_gfdn.plot import plot_magnitude_response
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.model import ColorlessFDN
from slope2noise.utils import schroeder_backward_int
from src.run_model import load_and_validate_config

In [None]:
def time_reversed_filtering(num_poly: NDArray, den_poly: NDArray, signal_freq_domain: torch.Tensor, num_fft_bins: int) -> torch.Tensor:
    """
    Filter a frequency-domain signal in the time domain with the time reversed filterbank.

    The input spectrum is transformed to the time domain, passed through
    torchaudio's ``lfilter`` and transformed back to the frequency domain.

    Args:
        num_poly, den_poly: numerator and denominator coefficients of the time reversed filterbank.
            num_poly is extended with len(num_poly) - 1 trailing zeros before filtering,
            den_poly is used as given.
        signal_freq_domain: input signal of shape num_freq_bins x num_bands
        num_fft_bins: FFT size used when transforming the filtered signal back
    Returns:
        torch.Tensor: one-sided spectrum of the filtered signal (num_fft_bins // 2 + 1 rows)
    """
    # back to the time domain; time runs along dim 0
    signal = torch.fft.irfft(signal_freq_domain, dim=0)
    num_coeffs = len(num_poly)
    den_poly_pad = torch.tensor(den_poly, dtype=torch.float32)
    # .copy() gives torch a freshly allocated array to convert
    # (presumably because a reversed view has negative strides - TODO confirm)
    num_poly_pad = torch.cat([torch.tensor(num_poly.copy(), dtype=torch.float32),
                              torch.zeros(num_coeffs - 1, dtype=torch.float32)])
    # NOTE(review): torchaudio documents lfilter as (waveform, a_coeffs, b_coeffs) with
    # a_coeffs the denominator, equally sized coefficient vectors, filtering along the
    # LAST dimension and clamp=True by default. Here the numerator is passed in the
    # a_coeffs slot and time is on dim 0 - confirm that this is intended.
    output_signal = lfilter(signal, num_poly_pad, den_poly_pad)
    return torch.fft.rfft(output_signal, n=num_fft_bins, dim=0)

In [None]:
# Centre frequencies (Hz) of the subband GFDNs; one trained model / config file per band
freqs = [63, 125, 250, 500, 1000, 2000, 4000, 8000]

# locations of the config files and of the output figures are the same for every band
config_path = 'data/config/'
fig_path = 'figures/'

for k, centre_freq in enumerate(freqs):
    config_name = f'treble_data_grid_training_{centre_freq}Hz_colorless_loss_diff_delays'
    config_file = config_path + config_name
    config_dict = load_and_validate_config(config_file + '.yml',
                                           DiffGFDNConfig)
    room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)
    dataset_has_cs_params = True
    config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})
    trainer_config = config_dict.trainer_config

    ## initialise variables
    # All accumulators have shape (num_freq_bins // 2 + 1, num_groups, num_bands);
    # they are sized from the first band's config and filled one band per iteration.
    if k == 0:
        H_fdn_init = torch.zeros((trainer_config.num_freq_bins // 2 + 1, config_dict.num_groups, len(freqs)),
                                 dtype=torch.complex64, requires_grad=False)
        H_fdn_subband_init = torch.zeros_like(H_fdn_init)
        H_fdn_subband_time_rev_init = torch.zeros_like(H_fdn_init)

        H_fdn_final = torch.zeros_like(H_fdn_init)
        H_fdn_subband_final = torch.zeros_like(H_fdn_final)
        H_fdn_subband_time_rev_final = torch.zeros_like(H_fdn_final)

    # force the trainer config device to be CPU
    if trainer_config.device != 'cpu':
        trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

    # get the colorless FDN params
    if config_dict.colorless_fdn_config.use_colorless_prototype:
        colorless_fdn_params = get_colorless_fdn_params(config_dict)
    else:
        colorless_fdn_params = None

    # initialise the model
    model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                     config_dict.num_groups,
                     config_dict.delay_length_samps,
                     trainer_config.device,
                     config_dict.feedback_loop_config,
                     config_dict.output_filter_config,
                     config_dict.decay_filter_config.use_absorption_filters,
                     common_decay_times=room_data.common_decay_times if config_dict.decay_filter_config.initialise_with_opt_values else None,
                     learn_common_decay_times=config_dict.decay_filter_config.learn_common_decay_times,
                     colorless_fdn_params=colorless_fdn_params,
                     use_colorless_loss=trainer_config.use_colorless_loss
                     )

    ### Get each FDN's magnitude response
    freq_bins_rad = torch.tensor(room_data.freq_bins_rad)
    freq_bins_hz = room_data.freq_bins_hz
    # unit-magnitude points at which the model is evaluated
    # NOTE(review): freq_bins_rad is additionally scaled by 2*pi here - presumably it
    # holds normalised frequencies rather than radians; confirm against the dataset.
    z_values = torch.polar(torch.ones_like(freq_bins_rad),
                           freq_bins_rad * 2 * np.pi)

    # load the trained weights for the particular epoch
    max_epochs = trainer_config.max_epochs
    # NOTE(review): plain string concatenation - assumes train_dir ends with a path separator
    checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()

    # checkpoint 'model_e-1.pt' is treated as the state before optimisation
    init_checkpoint = torch.load(f'{checkpoint_dir}/model_e-1.pt',
                                 weights_only=True,
                                 map_location=torch.device('cpu'))
    model.load_state_dict(init_checkpoint, strict=False)
    model.eval()
    # pure inference: no_grad keeps the accumulators from holding on to every
    # band's autograd graph
    with torch.no_grad():
        H_fdn_init[..., k], _ = model.sub_fdn_output(z_values)

    final_checkpoint = torch.load(f'{checkpoint_dir}/model_e{max_epochs-1}.pt',
                                  weights_only=True,
                                  map_location=torch.device('cpu'))
    # Load the trained model state
    model.load_state_dict(final_checkpoint, strict=False)
    model.eval()
    with torch.no_grad():
        H_fdn_final[..., k], _ = model.sub_fdn_output(z_values)

    ### Filter the sub-FDN outputs with the pyfar filters, sum them and observe colouration
    if trainer_config.subband_process_config.use_amp_preserving_filterbank:
        # the filterbank is identical for all bands, so it is only designed once
        if k == 0:
            subband_filters, subband_freqs = pf.dsp.filter.reconstructing_fractional_octave_bands(
                        None,
                        num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                        frequency_range=trainer_config.subband_process_config.frequency_range,
                        sampling_rate=config_dict.sample_rate,
                    )
            h_time_rev = get_time_reversed_fir_filterbank(subband_filters.coefficients,
                                                      freq_bins_rad.detach().numpy(),
                                                      trainer_config.num_freq_bins,
                                                      plot=True,
                                                      freq_labels=freqs)

        # pick the filter whose centre frequency is closest to this config's band
        subband_filter_idx = np.argmin(
            np.abs(subband_freqs -
                   trainer_config.subband_process_config.centre_frequency))
        subband_filter = torch.tensor(
            subband_filters.coefficients[subband_filter_idx])
        subband_filter_freq_resp = torch.fft.rfft(
            subband_filter, n=trainer_config.num_freq_bins)
        subband_rev_filter_req_response = torch.tensor(h_time_rev[subband_filter_idx])

        # filtering is a multiplication in the frequency domain; unsqueeze(1)
        # broadcasts the filter response over the groups
        H_fdn_subband_init[..., k] = H_fdn_init[..., k] * subband_filter_freq_resp.unsqueeze(1)
        H_fdn_subband_final[..., k] = H_fdn_final[..., k] * subband_filter_freq_resp.unsqueeze(1)

        # get time reversed filtered output
        H_fdn_subband_time_rev_init[..., k] = H_fdn_init[...,k] * subband_rev_filter_req_response.unsqueeze(1)
        H_fdn_subband_time_rev_final[..., k] = H_fdn_final[...,k] * subband_rev_filter_req_response.unsqueeze(1)


    else:
        # SOS fractional octave filterbank; the time-reversed accumulators stay zero in this branch
        if k == 0:
            subband_filters = pf.dsp.filter.fractional_octave_bands(
                    None,
                    num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                    frequency_range=trainer_config.subband_process_config.frequency_range,
                    sampling_rate=config_dict.sample_rate,
                )
            subband_freqs, _ = pf.dsp.filter.fractional_octave_frequencies(
                num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                frequency_range=trainer_config.subband_process_config.frequency_range,
            )

        # pick the filter whose centre frequency is closest to this config's band
        subband_filter_idx = np.argmin(
            np.abs(subband_freqs -
                   trainer_config.subband_process_config.centre_frequency))

        # safest to filter in time domain and then take transform
        h_fdn_subband_init = sosfiltfilt(subband_filters.coefficients[subband_filter_idx, ...],
                                         torch.fft.irfft(H_fdn_init[...,k], dim=0).detach().numpy(),
                                         axis=0)
        # .copy() hands torch.from_numpy a freshly allocated array
        H_fdn_subband_init[...,k] = torch.fft.rfft(torch.from_numpy(h_fdn_subband_init.copy()), n=trainer_config.num_freq_bins, dim=0)

        h_fdn_subband_final = sosfiltfilt(subband_filters.coefficients[subband_filter_idx, ...],
                                          torch.fft.irfft(H_fdn_final[...,k], dim=0).detach().numpy(),
                                          axis=0)
        H_fdn_subband_final[...,k] = torch.fft.rfft(torch.from_numpy(h_fdn_subband_final.copy()), n=trainer_config.num_freq_bins, dim=0)


    # plot colouration of individual FDNs
    save_path = f'{fig_path}/{config_name}_mag_spectrum.png'
    plot_magnitude_response(room_data, config_dict, model, save_path)

#### Plot colouration of filtered and summed FDNs

In [None]:
def mag_response_plot(room_data:RoomDataset, H_sub_fdn_filtered: torch.Tensor, freq_bins_hz: ArrayLike, freq_labels:List[int],
                      title:Optional[str]=None, save_path:Optional[str] = None):
    """Plot magnitude response of filtered sub-FDN responses, one subplot per room.

    Each subplot shows the response of every frequency band and, dashed, the
    response of the sum over all bands. The spectral flatness of the summed
    response is logged for every room.

    Args:
        room_data: dataset object; only room_data.num_rooms is read here
        H_sub_fdn_filtered: is of shape num_bins x num_groups x num_freq_bands
        freq_bins_hz: frequency axis used as x values of the plots
        freq_labels: centre frequency of each band, used for the legend entries
        title: optional figure title
        save_path: if given, the figure is saved to this path
    """
    # squeeze=False always returns a 2-D array of axes, so a single room
    # (where plt.subplots would otherwise return a bare Axes) is handled too
    fig, axes = plt.subplots(room_data.num_rooms,
                             1,
                             figsize=(8, 10),
                             sharex=True,
                             squeeze=False)
    axes = axes[:, 0]
    band_labels = [f"{freq} Hz" for freq in freq_labels]

    for i, ax in enumerate(axes):
        H_room = H_sub_fdn_filtered[:, i, :]
        # response of the summed bands, reused for the plot and the flatness log
        H_summed_db = db(torch.sum(H_room, dim=-1).detach().numpy())

        ax.semilogx(
            freq_bins_hz,
            db(H_room.detach().numpy()),
            label=band_labels,
            linestyle="-",
            alpha=0.8,
        )

        ax.semilogx(freq_bins_hz, H_summed_db, label='summed', linestyle='--')

        ax.set_ylabel("Magnitude (dB)")
        ax.set_xlabel('Frequencies (Hz)')
        ax.set_title(f"FDN {i+1}")
        ax.grid(True)
        ax.set_xlim([20, 16000])
        ax.set_ylim([-60, 20])

        logger.info(
            f'FDN {i+1} spectral flatness is {spectral_flatness(H_summed_db):.3f}'
        )

    # one legend, attached to the last subplot and placed outside of the axes
    axes[-1].legend(loc="upper left", bbox_to_anchor=(1, 1))
    if title is not None:
        fig.suptitle(title)
    if save_path is not None:
        fig.savefig(save_path)

mag_response_plot(room_data, H_fdn_subband_init, freq_bins_hz, freqs, title="Initialisation")

In [None]:
# Absolute figure directory; the plotting cells that follow reuse `fig_path`
fig_path = Path('figures/').resolve()
# fig.savefig does not create missing directories, so make sure the target exists
(fig_path / 'test_plots').mkdir(parents=True, exist_ok=True)
# Optimised sub-FDN responses without any subband filtering applied
save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_summed_spectrum.png'
mag_response_plot(room_data, H_fdn_final, freq_bins_hz, freqs, title="Post optimisation, no filtering", save_path=save_path)

In [None]:
# Optimised sub-FDN responses after subband filtering, per band and summed.
# Uses `fig_path` and `config_name` left over from earlier cells.
save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_filtered_spectrum.png'
mag_response_plot(room_data, H_fdn_subband_final, freq_bins_hz, freqs, title="Post optimisation with filtering", save_path=save_path)

In [None]:
# The time-reversed responses are only filled in when the amplitude preserving
# filterbank is used (they stay all-zero otherwise), so only plot them in that case.
if trainer_config.subband_process_config.use_amp_preserving_filterbank:
    save_path = f'{fig_path}/test_plots/{config_name}_opt_gfdn_time_reversed_filtered_spectrum.png'
    mag_response_plot(room_data, H_fdn_subband_time_rev_final, freq_bins_hz, freqs, 
                      title="Post optimisation with time-reversed filtering", save_path=save_path)

#### Plot NED

In [None]:
# Impulse response of each FDN
h_init_fdn = torch.fft.irfft(torch.sum(H_fdn_subband_init, dim=-1), dim=0)
h_final_fdn = torch.fft.irfft(torch.sum(H_fdn_subband_final, dim=-1), dim=0)

# initial group delay because of FIR filtering with pyfar
init_fir_delay = subband_filters.coefficients[0].shape[0] // 2

# impulse response of the GFDN
h_init = torch.sum(h_init_fdn, dim=-1)
h_final = torch.sum(h_final_fdn, dim=-1)

fs = room_data.sample_rate
mixing_time_samp = ms_to_samps(0.0, fs)
crop_end_samp = ms_to_samps(5.0, fs)
h_init_trunc = h_init[mixing_time_samp:-crop_end_samp-init_fir_delay]
h_final_trunc = h_final[init_fir_delay+mixing_time_samp:-crop_end_samp]
len_ir = len(h_init_trunc)
time = np.linspace(0, (len_ir-1)/ fs, len_ir-1)

ned_init = normalised_echo_density(h_init_trunc, fs, window_length_ms=50)
ned_final = normalised_echo_density(h_final_trunc, fs, window_length_ms=50)

fig, ax = plt.subplots(2, 1, figsize=(6, 8))
ax[0].plot(time, ned_init, label='Initial')
ax[0].plot(time, ned_final, label='Post optimisation')
for k in range(config_dict.num_groups):
    ned_init_fdn = normalised_echo_density(h_init_fdn[mixing_time_samp:-crop_end_samp-init_fir_delay, k], fs, window_length_ms=50)
    ned_final_fdn = normalised_echo_density(h_final_fdn[init_fir_delay + mixing_time_samp:-crop_end_samp, k], 
                                            fs, window_length_ms=50)
    ax[1].plot(time, ned_init_fdn, label=f'Initial, group={k+1}')
    ax[1].plot(time, ned_final_fdn, label=f'Final, group={k+1}')

for i in range(2):
    ax[i].set_xlabel('Time (s)')
    ax[i].set_ylabel('NED')  
    ax[i].legend(loc="upper left", bbox_to_anchor=(1, 1))
    ax[i].set_xlim([0.001, max(time)])
plt.show()
plt.savefig(f'{fig_path}/test_plots/{config_name}_ned.png')