### Load dataset and save IRs

In [None]:
from pathlib import Path
from typing import Tuple, Optional, Dict, List
from numpy.typing import ArrayLike, NDArray
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchaudio
import soundfile as sf
from loguru import logger

os.chdir('..')  # This changes the working directory to DiffGFDN
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import get_response, db, ms_to_samps, is_unitary, is_paraunitary, normalised_echo_density
from diff_gfdn.losses import get_stft_torch, get_edr_from_stft
from diff_gfdn.plot import (plot_polynomial_matrix_magnitude_response, 
                            plot_edr, plot_subband_amplitudes, plot_subband_edc,
                            plot_learned_svf_response, plot_amps_in_space, plot_magnitude_response)
from diff_gfdn.analysis import get_amps_for_rir
from src.run_model import load_and_validate_config


### Read config files and dataset

In [None]:
config_path = 'data/config/'
config_name = 'treble_data_grid_training_full_band_colorless_loss'
config_file = f'{config_path}{config_name}.yml'
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

# Datasets whose path contains "3room_FDTD" are loaded with ThreeRoomDataset;
# any other dataset path is converted with the common-slopes converter.
# NOTE(review): `convert_common_slopes_rir_to_room_dataset` is not imported in
# any visible cell of this notebook, so this branch raises NameError as
# written — add the import from wherever the project defines it.
if "3room_FDTD" not in config_dict.room_dataset_path:
    room_data = convert_common_slopes_rir_to_room_dataset(
        config_dict.room_dataset_path,
        num_freq_bins=config_dict.trainer_config.num_freq_bins,
    )
else:
    room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)

In [None]:
# add number of groups to the config dictionary
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

# The dataset's sample rate is authoritative: if the config disagrees, replace
# the config's value with the dataset's (via model_copy, as above, rather than
# mutating the config object in place).
if config_dict.sample_rate != room_data.sample_rate:
    logger.warning("Config sample rate does not match data, altering it")
    config_dict = config_dict.model_copy(update={"sample_rate": room_data.sample_rate})

# get the training config
trainer_config = config_dict.trainer_config

# force the trainer config device to be CPU
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

# prepare the training and validation data for DiffGFDN
# NOTE(review): with train_valid_split_ratio=1.0 and shuffle=False this
# presumably puts every position into train_dataset in a fixed order — confirm
# against load_dataset.
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size, shuffle=False)

### Check output data and compare with true IR

In [None]:
# Build the variable-receiver-position DiffGFDN from the dataset properties and
# the model sections of the config; runs on trainer_config.device (forced to
# CPU in the previous cell).
model = DiffGFDNVarReceiverPos(
    room_data.sample_rate,
    room_data.num_rooms,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    config_dict.decay_filter_config.use_absorption_filters,
    band_centre_hz=room_data.band_centre_hz,
    common_decay_times=room_data.common_decay_times,
)

In [None]:
audio_directory = Path("audio/")
fig_path = Path("figures").resolve()
# Join with `/` rather than string concatenation so the path is correct
# whether or not `train_dir` ends with a separator.
checkpoint_dir = (Path(trainer_config.train_dir) / 'checkpoints').resolve()
max_epochs = trainer_config.max_epochs
save_ir_dir = Path(trainer_config.ir_dir).resolve()
plot_ir = False
# investigate outputs at this receiver position
pos_to_investigate = [9.3, 6.60, 1.50]  # alternative: [2.0, 6.8, 1.5]
desired_filename = f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}).wav'

# find the dataset index whose receiver position, rounded to 2 decimals,
# equals pos_to_investigate
rec_pos_idx = np.where(
    np.all(np.round(room_data.receiver_position, 2) == pos_to_investigate, axis=1))[0]
# Fail early: an empty match would otherwise silently give empty
# `amplitudes` and `h_true` for all later cells.
if rec_pos_idx.size == 0:
    raise ValueError(f'No receiver at position {pos_to_investigate} found in the dataset')
amplitudes = room_data.amplitudes[rec_pos_idx, ...]
h_true = np.squeeze(room_data.rirs[rec_pos_idx, :])

all_pos = []
all_rirs = []
h_approx_list = []

In [None]:
output_biquad_coeffs = []
for epoch in [-1, max_epochs - 1]:
    # load the trained weights saved for this epoch
    checkpoint = torch.load(Path(checkpoint_dir) / f'model_e{epoch}.pt',
                            weights_only=True,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    # eval() only switches layers to inference behaviour; it does not turn off
    # autograd. torch.no_grad() below does, so the RIRs collected in the lists
    # do not keep computation graphs alive (if the parameters require grad).
    model.eval()

    with torch.no_grad():
        for data in train_dataset:
            position = data['listener_position']
            H, h = get_response(data, model)

            for num_pos in range(position.shape[0]):
                filename = f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, {position[num_pos, 2]:.2f}).wav'

                # collate the RIRs at all positions for the final epoch
                if epoch == max_epochs - 1:
                    all_pos.append(position[num_pos])
                    all_rirs.append(h[num_pos, :])

                if filename != desired_filename:
                    continue

                # get parameter dictionary used in inferencing
                inf_param_dict = model.get_param_dict_inference(data)

                # keep the RIR and output biquad coefficients for this position
                h_approx_list.append(h[num_pos, :])
                output_biquad_coeffs.append(inf_param_dict['output_biquad_coeffs'][num_pos])

                if plot_ir:
                    # plot the EDRs of the true and estimated
                    plot_edr(torch.tensor(h_true), model.sample_rate, title=f'True RIR EDR, epoch={epoch}',
                             save_path=f'{fig_path}/true_edr_{filename}_{config_name}_epoch={epoch}.png')

                    plot_edr(h[num_pos, :], model.sample_rate, title=f'Estimated RIR EDR, epoch={epoch}',
                             save_path=f'{fig_path}/approx_edr_{filename}_{config_name}_epoch={epoch}.png')

                    # open a fresh figure so the IR comparison is not drawn
                    # onto whatever figure happens to be current
                    plt.figure()
                    plt.plot(torch.stack((torch.tensor(h_true), h[num_pos, :len(h_true)]), dim=-1))
                    plt.xlim([0, int(1.5 * model.sample_rate)])
                    plt.savefig(f'{fig_path}/ir_compare_{filename}_{config_name}_epoch={epoch}.png')
                    plt.show()


#### Plot subband EDC as a function of epoch number

In [None]:
from diff_gfdn.plot import plot_subband_edc

# Subband EDCs of the true RIR against the approximations collected for
# epochs -1 and max_epochs-1 (one entry of h_approx_list per epoch).
plot_subband_edc(
    h_true,
    h_approx_list,
    config_dict.sample_rate,
    room_data.band_centre_hz,
    pos_to_investigate,
    epoch_numbers=[-1, max_epochs - 1],
    save_path=f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png',
)

#### Plot the output filter response for the position under investigation

In [None]:
from diff_gfdn.plot import plot_learned_svf_response

# Response of the learned output filters at the investigated position,
# one set of biquad coefficients per evaluated epoch.
save_path = f'{fig_path}/{config_name}'
plot_learned_svf_response(
    room_data.num_rooms,
    int(room_data.sample_rate),
    output_biquad_coeffs,
    pos_to_investigate,
    epoch_numbers=[-1, max_epochs - 1],
    save_path=save_path,
)

#### Investigate the modes of the receiver filters to see how they affect the modes of the network

In [None]:
from numpy.typing import NDArray
from typing import Tuple, Optional, Dict, List
from scipy.signal import sos2zpk, zpk2tf, residue

def compute_modes_from_sos(sos: NDArray) -> Tuple[ArrayLike, ArrayLike]:
    """
    Compute the modes (poles and residues) of a digital filter given in SOS form.

    The transfer function is expanded in powers of z^-1,
        H(z) = sum_i r_i / (1 - p_i z^-1) + direct terms,
    so the impulse response is h[n] = sum_i r_i p_i^n (plus an FIR part), i.e.
    r_i is the initial amplitude of mode i.

    Args:
        sos (array-like): Second-order sections (SOS) filter representation in
            scipy.signal format, shape (n_sections, 6).

    Returns:
        poles (np.ndarray): Poles of the system.
        residues (np.ndarray): Residues corresponding to each pole.

    Note:
        The direct (polynomial) terms of the expansion are discarded.
    """
    # residuez expands in z^-1, which is the right expansion for a digital
    # filter; scipy's `residue` expands in s (Laplace domain) and, applied to
    # these coefficients, returns residues scaled by their pole (r_i * p_i).
    from scipy.signal import residuez

    # Convert SOS to zero-pole-gain form
    z, p, k = sos2zpk(sos)

    # Convert to transfer function form; sos2zpk pads zeros and poles to the
    # same count, so b and a have equal length and can be read as z^-1 polynomials
    b, a = zpk2tf(z, p, k)

    # Perform partial fraction expansion
    residues, poles, _ = residuez(b, a)

    return poles, residues
    
def zp_to_modes(residues: List, poles: List, fs: float) -> Dict:
    """
    Convert discrete-time poles and their residues into modal parameters.

    Args:
        residues: residue of each mode, aligned with `poles`.
        poles: discrete-time poles; all must lie strictly inside the unit circle.
        fs: sample rate in Hz.

    Returns:
        Dict with keys
            'freqs': mode frequency in Hz (pole angle * fs / 2pi),
            'decay': T60 in ms, the time for |p|^n to fall by 60 dB,
            'amps':  residue level in dB, computed with the project's `db` helper.

    Raises:
        ValueError: if any pole has magnitude >= 1 (a non-decaying mode).
    """
    radius = np.abs(poles)
    angles = np.angle(poles)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not np.all(radius < 1):
        raise ValueError(
            f'All poles must lie strictly inside the unit circle, got max |p| = {np.max(radius)}')
    modal_params = {}
    # in Hz
    modal_params['freqs'] = (angles * fs) / (2 * np.pi)
    # T60 in ms: |p|^n reaches 10^-3 (-60 dB) after n = ln(1000) / -ln|p| samples
    modal_params['decay'] = (np.log(1e3) / (-np.log(radius) * fs)) * 1e3
    # in db
    modal_params['amps'] = db(residues)
    return modal_params
    
# Analyse the receiver-filter modes for the final epoch only: index 1 of
# output_biquad_coeffs is the entry appended for epoch max_epochs-1.
for idx, epoch in zip([1], [max_epochs - 1]):
    cur_biquad_coeffs = output_biquad_coeffs[idx]
    for k in range(room_data.num_rooms):
        cur_sos = cur_biquad_coeffs[k]
        poles, residues = compute_modes_from_sos(cur_sos)
        modal_params = zp_to_modes(residues, poles, room_data.sample_rate)
        logger.info(f'Modal parameters for receiver filters in group {k+1} are: ')
        # Double quotes for the dict keys: reusing the f-string's own quote
        # character inside it is a SyntaxError before Python 3.12.
        logger.info(f'{np.round(modal_params["freqs"], 3)} Hz, \n {np.round(modal_params["decay"], 3)} ms, \n {np.round(modal_params["amps"], 3)} dB')

#### Plot actual and achieved subband amplitudes for the position under investigation

In [None]:
from importlib import reload

import diff_gfdn.plot

# Reload the plotting module so edits to it are picked up without restarting
# the kernel. `reload` and `diff_gfdn` are imported here because, in
# top-to-bottom execution, they are otherwise first imported in a later cell.
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_subband_amplitudes

# compare the true subband levels with those of the final-epoch approximation
final_approx_rir = h_approx_list[-1].clone().detach()
plot_subband_amplitudes(h_true, final_approx_rir, room_data.sample_rate,
                        config_dict.num_groups, amplitudes,
                        np.squeeze(room_data.common_decay_times),
                        room_data.band_centre_hz)

### Get the final trained parameters and investigate them

In [None]:
from scipy.io import loadmat

# Optimised parameters saved in the training directory. Join with `/` so the
# path is correct whether or not `train_dir` ends with a separator.
param_path = (Path(trainer_config.train_dir) / 'parameters_opt.mat').resolve()
opt_params = loadmat(param_path)

#### Observe the individual mixing matrices

In [None]:
from importlib import reload
import diff_gfdn
reload(diff_gfdn.utils)
from diff_gfdn.utils import is_unitary

if config_dict.feedback_loop_config.coupling_matrix_type in (CouplingMatrixType.SCALAR, CouplingMatrixType.FILTER):
    M_list = opt_params['feedback_loop.M']
    num_groups = model.num_groups
    # squeeze=False always returns a 2-D array of Axes, so ax[i, 0] also works
    # for a single group (otherwise plt.subplots returns a bare Axes object)
    fig, ax = plt.subplots(num_groups, 1, figsize=(6, 6), squeeze=False)

    for i in range(num_groups):
        M = torch.from_numpy(M_list[i, ...])
        with torch.no_grad():
            # apply the feedback loop's orthogonal parametrisation to the stored matrix
            M_ortho = model.feedback_loop.ortho_param(M)
        ax[i, 0].matshow(torch.abs(M_ortho))
        ax[i, 0].set_title(f'Room {i}')
        # report the unitarity check rather than discarding its result
        is_ortho, max_val = is_unitary(M_ortho)
        logger.info(f'Room {i} feedback matrix unitary: {is_ortho}, max_val: {max_val}')
    plt.tight_layout()
    plt.savefig(f'{fig_path}/individual_feedback_matrices.png')

#### Observe the input gains, coupling matrix and the coupled mixing matrix

In [None]:
coupled_feedback_matrix = opt_params['coupled_feedback_matrix']
input_gains = opt_params['input_gains'][0]
output_gains = opt_params['output_gains'][0]
print(f'Norm of input gains {np.linalg.norm(input_gains)}')
print(f'Norm of output gains {np.linalg.norm(output_gains)}')

if config_dict.feedback_loop_config.coupling_matrix_type == CouplingMatrixType.SCALAR:
    # assert is_unitary(torch.from_numpy(coupled_feedback_matrix))[0]
    coupling_matrix = opt_params['coupling_matrix']
    plt.figure()
    plt.subplot(211)
    plt.matshow(np.abs(coupling_matrix), fignum=False)
    plt.colorbar()
    plt.title('Coupling matrix')
    plt.subplot(212)
    plt.matshow(np.abs(coupled_feedback_matrix), fignum=False)
    plt.colorbar()
    plt.title('Coupled feedback matrix')
    plt.tight_layout()
    plt.savefig(f'{fig_path}/scalar_coupling_matrix.png')

elif config_dict.feedback_loop_config.coupling_matrix_type == CouplingMatrixType.FILTER:
    # assert is_paraunitary(torch.from_numpy(coupled_feedback_matrix))[0]
    coupling_matrix = opt_params['coupling_matrix']
    num_freq_bins = 2**10
    plot_polynomial_matrix_magnitude_response(coupling_matrix, model.sample_rate, num_freq_bins,
                                              'Coupling matrix response', save_path=f'{fig_path}/pu_coupling_matrix.png')
else:
    # same array as coupled_feedback_matrix, loaded above
    feedback_matrix = coupled_feedback_matrix
    unit_flag, max_val = is_unitary(torch.tensor(feedback_matrix), max_tol=1e-4)
    assert unit_flag, f'Optimised feedback matrix is not unitary (max_val={max_val})'
    plt.figure()
    # fignum=False draws into the figure just created; without it matshow
    # opens a second figure and leaves this one empty
    plt.matshow(np.abs(feedback_matrix), fignum=False)
    plt.title('Optimised feedback matrix')
    plt.savefig(f'{fig_path}/random_coupling_matrix.png')

#### Plot magnitude response of each sub-FDN to inspect colouration

In [None]:
# Magnitude response of each sub-FDN, used to inspect colouration; the figure
# is written to <fig_path>/<config_name>_mag_spectrum.png.
save_path = f"{fig_path}/{config_name}_mag_spectrum.png"
plot_magnitude_response(room_data, config_dict, model,
                        save_path)

#### Plot the normalised echo density (NED) before and after optimisation

In [None]:
fs = room_data.sample_rate
# analyse from 20 ms (the assumed mixing time) onwards and drop the final 5 ms
mixing_time_samp = ms_to_samps(20.0, fs)
crop_end_samp = ms_to_samps(5.0, fs)
trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
len_ir = len(trunc_true_ir)
time = np.linspace(0, (len_ir - 1) / config_dict.sample_rate, len_ir - 1)

iterate_over_epochs = [-1, max_epochs - 1]
# One column per evaluated epoch, indexed by position in iterate_over_epochs.
# Indexing by the epoch number itself would map both -1 and max_epochs-1 to
# the last column, so the first result would be overwritten.
ned_fdn = np.zeros((len_ir - 1, len(iterate_over_epochs)))
ned_true = normalised_echo_density(trunc_true_ir,
                                   config_dict.sample_rate, window_length_ms=50)

fig, ax = plt.subplots()
# ax.plot(time, ned_true, label='Reference')
for k, epoch in enumerate(iterate_over_epochs):
    h_cur = h_approx_list[k]
    # same window of the approximated RIR as used for the true RIR
    ned_fdn[:, k] = normalised_echo_density(h_cur[mixing_time_samp: mixing_time_samp + len_ir],
                                            config_dict.sample_rate, window_length_ms=50)
    ax.plot(time, ned_fdn[:, k], label=f'GFDN, Epoch={epoch}')
ax.set_xlabel('Time (s)')
ax.set_ylabel('NED')
ax.legend()
ax.set_xlim([0.001, max(time)])
fig.savefig(f'{fig_path}/{config_name}_ned.png')
plt.show()

### Plot the amplitude distribution for each RIR as a function of position in space for the 1 kHz band

In [None]:
from importlib import reload

import diff_gfdn

# pick up any edits to the plotting module before importing from it
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_amps_in_space, plot_edc_error_in_space

# EDC error over all collected receiver positions (final-epoch RIRs);
# freq_to_plot=None is passed through unchanged to the plotting helper.
plot_edc_error_in_space(
    room_data,
    all_rirs,
    all_pos,
    freq_to_plot=None,
    save_path=f'{fig_path}/{config_name}',
)