### Load dataset and save IRs

In [None]:
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import get_response, db, is_unitary, is_paraunitary
from diff_gfdn.losses import get_stft_torch, get_edr_from_stft
from diff_gfdn.plot import plot_polynomial_matrix_magnitude_response, plot_edr
from run_model import load_and_validate_config

from pathlib import Path
from typing import Tuple, Optional
from numpy.typing import ArrayLike
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchaudio
import soundfile as sf
from loguru import logger

### Read config files and dataset

In [None]:
# Locate the experiment configuration, validate it against DiffGFDNConfig and
# load the three-room dataset that the configuration points to.
config_path = '../data/config/'
config_name = 'antialiasing_reg_loss_more_layers_filter_coupling'
config_file = f'{config_path}{config_name}.yml'
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)
room_data = ThreeRoomDataset(
    Path(config_dict.room_dataset_path).resolve(),
    config_dict,
)

In [None]:
# add number of groups (one per room in the dataset) to the config dictionary
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# the dataset's sample rate is authoritative: bring the config in line with it.
# (previously this assigned an undefined name `sample_rate` and used loguru's
# non-existent `logger.warn`)
if config_dict.sample_rate != room_data.sample_rate:
    logger.warning("Config sample rate does not match data, altering it")
    config_dict = config_dict.copy(update={"sample_rate": room_data.sample_rate})

# get the training config
trainer_config = config_dict.trainer_config

# force the trainer config device to be CPU
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.copy(update={"device": 'cpu'})

# prepare the training and validation data for DiffGFDN
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, trainer_config.train_valid_split,
    trainer_config.batch_size)

### Check output data and compare with true IR

In [None]:
# Construct the receiver-position-dependent DiffGFDN on the trainer device,
# using the dataset's sample rate, room count, decay times and band centres.
model = DiffGFDNVarReceiverPos(
    room_data.sample_rate,
    room_data.num_rooms,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    config_dict.use_absorption_filters,
    common_decay_times=room_data.common_decay_times,
    band_centre_hz=room_data.band_centre_hz,
)


In [None]:
# Input/output locations and switches for the checkpoint evaluation loop.
audio_directory = Path("../audio/")
fig_path = Path("../figures").resolve()
# Join with pathlib so this works whether or not train_dir ends with a path
# separator (plain concatenation gave '<train_dir>checkpoints' when it did not).
checkpoint_dir = Path(trainer_config.train_dir, 'checkpoints').resolve()
max_epochs = trainer_config.max_epochs
save_ir_dir = Path(trainer_config.ir_dir).resolve()
# exactly one mode is active by default: write IRs to disk, or plot them
save_ir = False
plot_ir = not save_ir
# listener position whose IRs/EDRs are plotted; compared against
# data['listener_position'] entries in the evaluation loop
pos_to_investigate = [0.20, 5.60, 1.50]

In [None]:
# Evaluate the saved checkpoints. When plot_ir is set, the true and estimated
# RIRs/EDRs of pos_to_investigate are plotted for every epoch; when save_ir is
# set, the estimated IRs of the final epoch are written to save_ir_dir.
# NOTE(review): only train_dataset is iterated, so validation positions are
# never plotted or saved -- confirm this is intended.
true_ir_dir = (audio_directory / 'true').resolve()
if save_ir:
    save_ir_dir.mkdir(parents=True, exist_ok=True)

# Intermediate checkpoints are only needed for plotting; when only saving IRs,
# evaluate just the final epoch instead of running inference for every epoch.
first_epoch = 0 if plot_ir else max(max_epochs - 1, 0)

for epoch in range(first_epoch, max_epochs):
    # load the trained weights for this epoch onto the CPU
    checkpoint = torch.load(checkpoint_dir / f'model_e{epoch}.pt',
                            weights_only=True,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    # eval() only switches layer behaviour; gradient tracking is disabled
    # separately with torch.no_grad() around inference below
    model.eval()

    for data in train_dataset:
        position = data['listener_position']
        with torch.no_grad():
            H, h = get_response(data, model)

        for num_pos in range(position.shape[0]):
            cur_pos = position[num_pos, :]
            filename = f'ir_({cur_pos[0]:.2f}, {cur_pos[1]:.2f}, {cur_pos[2]:.2f}).wav'

            # compare with a tolerance: if positions are stored as float32 they
            # need not compare exactly equal to the float literals above
            if plot_ir and np.allclose(np.asarray(cur_pos, dtype=float),
                                       pos_to_investigate, atol=1e-3):
                # biquad coefficients of this position's output filters
                param_dict = model.get_param_dict()
                output_biquad_coeffs = param_dict['output_biquad_coeffs'][num_pos]

                # the reference IR is only read when it is actually plotted
                h_true = torch.from_numpy(sf.read(str(true_ir_dir / filename))[0])
                h_est = h[num_pos, :]

                # plot the EDRs of the true and estimated IRs
                plot_edr(h_true, model.sample_rate, title=f'True RIR EDR, epoch={epoch}',
                         save_path=f'{fig_path}/true_edr_{filename}_{config_name}_epoch={epoch}.png')
                plot_edr(h_est, model.sample_rate, title=f'Estimated RIR EDR, epoch={epoch}',
                         save_path=f'{fig_path}/approx_edr_{filename}_{config_name}_epoch={epoch}.png')

                # overlay both IRs on a fresh figure (rather than whatever figure
                # plot_edr may have left current), trimmed to their common length
                num_samps = min(len(h_true), len(h_est))
                fig_ir, ax_ir = plt.subplots()
                ax_ir.plot(torch.stack((h_true[:num_samps],
                                        h_est[:num_samps].to(h_true.dtype)), dim=-1))
                ax_ir.set_xlim([0, int(1.5 * model.sample_rate)])
                fig_ir.savefig(f'{fig_path}/ir_compare_{filename}_{config_name}_epoch={epoch}.png')
                plt.show()

            if save_ir and epoch == max_epochs - 1:
                filepath = os.path.join(save_ir_dir, filename)
                # the estimate is duplicated into two channels, laid out as
                # (samples, channels) because channels_first=False
                torchaudio.save(filepath,
                                torch.stack((h[num_pos, :], h[num_pos, :]),
                                            dim=1).cpu(),
                                int(model.sample_rate),
                                bits_per_sample=32,
                                channels_first=False)


#### Plot the output filter response for the position under investigation

In [None]:
# Plot the magnitude response and pole/zero locations of each group's output
# biquad cascade for the position under investigation.
from scipy.signal import sosfreqz, sos2zpk

fig, ax = plt.subplots()
fig2, ax2 = plt.subplots(subplot_kw={'projection': 'polar'})
# radial limit of the pole/zero plot; grown below so no root is clipped
max_radius = 1.0

for n in range(room_data.num_rooms):
    # work on a float copy so output_biquad_coeffs is not modified in place
    cur_biquad_coeffs = np.array(output_biquad_coeffs[n], dtype=float)
    # scipy's sos rows are [b0, b1, b2, a0, a1, a2] and require a0 = 1
    cur_biquad_coeffs = cur_biquad_coeffs / cur_biquad_coeffs[:, 3:4]

    freqs, filt_response = sosfreqz(cur_biquad_coeffs, worN=2**9, fs=room_data.sample_rate)
    line, = ax.semilogx(freqs, 20 * np.log10(np.abs(filt_response)), label=f'Group {n}')

    # also plot the poles and zeros, in the same colour as the group's response
    zeros, poles, gain = sos2zpk(cur_biquad_coeffs)
    color = line.get_color()
    ax2.plot(np.angle(zeros), np.abs(zeros), 'o', color=color, label=f'Group {n} zeros')
    ax2.plot(np.angle(poles), np.abs(poles), 'x', color=color, label=f'Group {n} poles')
    max_radius = max(max_radius,
                     float(np.abs(zeros).max(initial=0.0)),
                     float(np.abs(poles).max(initial=0.0)))

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Magnitude (dB)')
ax.set_title(f'Output filter for position {pos_to_investigate}')
ax.legend()
ax.grid(True)
fig.savefig(f'{fig_path}/{config_name}_output_filter_response.png')

# show every root, including any outside the unit circle
ax2.set_rmax(max_radius)
ax2.set_rticks([0.25, 0.5, 1])  # Less radial ticks
ax2.set_rlabel_position(-22.5)  # Move radial labels away from plotted line
ax2.grid(True)
ax2.legend(loc='upper left', bbox_to_anchor=(1.1, 1.0))
fig2.savefig(f'{fig_path}/{config_name}_output_filter_pz_plot.png', bbox_inches='tight')



### Get the final trained parameters and investigate them

In [None]:
# Load the optimised parameters saved as a MATLAB file in the training directory.
from scipy.io import loadmat

# pathlib join, consistent with checkpoint_dir; independent of whether
# train_dir carries a trailing separator
param_path = Path(trainer_config.train_dir) / 'parameters_opt.mat'
opt_params = loadmat(param_path.resolve())
print(opt_params.keys())

#### Observe the individual mixing matrices

In [None]:
# Reload diff_gfdn.utils (presumably to pick up edits without restarting the
# kernel) and visualise each group's individual feedback matrix.
from importlib import reload
import diff_gfdn
reload(diff_gfdn.utils)
from diff_gfdn.utils import is_unitary

if config_dict.feedback_loop_config.coupling_matrix_type in (CouplingMatrixType.SCALAR, CouplingMatrixType.FILTER):
    M_list = opt_params['feedback_loop.M']
    num_groups = model.num_groups
    # squeeze=False keeps an array of Axes even when num_groups == 1
    fig, ax = plt.subplots(num_groups, 1, figsize=(6, 6), squeeze=False)
    ax = ax[:, 0]

    for i in range(num_groups):
        M = torch.from_numpy(M_list[i, ...])
        with torch.no_grad():
            # NOTE(review): ortho_param presumably maps the stored parameters to
            # the orthogonal matrix used by the model -- confirm
            M_ortho = model.feedback_loop.ortho_param(M)
        ax[i].matshow(torch.abs(M_ortho))
        ax[i].set_title(f'Room {i}')
        is_ortho, max_val = is_unitary(M_ortho)
        print(f'Room {i}: is_unitary={is_ortho}, max_val={max_val}')
    fig.tight_layout()
    fig.savefig(f'{fig_path}/individual_feedback_matrices.png')

#### Observe the input gains, coupling matrix and the coupled mixing matrix

In [None]:
# Report the norms of the optimised input/output gains and visualise the
# coupling according to the configured coupling-matrix type.
coupled_feedback_matrix = opt_params['coupled_feedback_matrix']
input_gains = opt_params['input_gains'][0]
output_gains = opt_params['output_gains'][0]
print(f'Norm of input gains {np.linalg.norm(input_gains)}')
print(f'Norm of output gains {np.linalg.norm(output_gains)}')

coupling_type = config_dict.feedback_loop_config.coupling_matrix_type

if coupling_type == CouplingMatrixType.SCALAR:
    is_unit, max_val = is_unitary(torch.from_numpy(coupled_feedback_matrix))
    assert is_unit, f'Coupled feedback matrix is not unitary (max_val={max_val})'
    coupling_matrix = opt_params['coupling_matrix']
    plt.figure()
    plt.subplot(211)
    plt.matshow(np.abs(coupling_matrix), fignum=False)
    plt.colorbar()
    plt.title('Coupling matrix')
    plt.subplot(212)
    plt.matshow(np.abs(coupled_feedback_matrix), fignum=False)
    plt.colorbar()
    plt.title('Coupled feedback matrix')
    plt.tight_layout()
    plt.savefig(f'{fig_path}/scalar_coupling_matrix.png')

elif coupling_type == CouplingMatrixType.FILTER:
    # NOTE(review): the paraunitarity check is currently disabled
    # assert is_paraunitary(torch.from_numpy(coupled_feedback_matrix))[0]
    coupling_matrix = opt_params['coupling_matrix']
    num_freq_bins = 2**10
    plot_polynomial_matrix_magnitude_response(coupling_matrix, model.sample_rate, num_freq_bins, 'Coupling matrix response')
    plt.savefig(f'{fig_path}/pu_coupling_matrix.png')

else:
    feedback_matrix = coupled_feedback_matrix
    unit_flag, max_val = is_unitary(torch.tensor(feedback_matrix), max_tol=1e-4)
    assert unit_flag, f'Optimised feedback matrix is not unitary (max_val={max_val})'
    # plt.matshow without fignum opens its own figure; a preceding plt.figure()
    # only left an empty figure behind
    plt.matshow(np.abs(feedback_matrix))
    plt.title('Optimised feedback matrix')
    plt.savefig(f'{fig_path}/random_coupling_matrix.png')