In [None]:
import os
import re
import warnings
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from numpy.typing import ArrayLike
from scipy.fft import rfft, rfftfreq
from scipy.io import loadmat
from torch import nn

from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps
from diff_gfdn.plot import plot_edr
from run_model import load_and_validate_config
# The relative paths used below ('data/config/', 'figures/') assume the DiffGFDN
# repository root is the working directory. Only step up one directory when that
# is not already the case, so re-running this cell does not keep climbing the tree.
if not Path('data/config').is_dir():
    os.chdir('..')  # This changes the working directory to DiffGFDN

In [None]:
# Paths are relative to the DiffGFDN repository root.
config_path = 'data/config/'
fig_path = 'figures/'
# Figures are saved into fig_path further down; create it up front so that
# plt.savefig / plot_edr(save_path=...) do not fail on a fresh checkout.
os.makedirs(fig_path, exist_ok=True)

config_file = config_path + 'single_rir_fit_broadband_two_stage_decay_scalar_coupling.yml'
config_dict = load_and_validate_config(config_file,
                                       DiffGFDNConfig)
trainer_config = config_dict.trainer_config

In [None]:
# Build the room dataset from the common-slopes RIRs, then size the model so it
# has one group per room in that dataset.
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
)
config_dict = config_dict.copy(update=dict(num_groups=room_data.num_rooms))

### Compare the synthesised RIR with the true RIR

In [None]:
# Load the true RIR and the DiffGFDN approximation of it.
true_ir, fs = sf.read(Path(config_dict.ir_path).resolve())

# The IR name is the "ir_(...)" token in the IR file path. It is used to locate
# the synthesised IR and to name the saved figures.
match = re.search(r'ir_\([^)]+\)', config_dict.ir_path)
if match is None:
    raise ValueError(f'No "ir_(...)" token found in ir_path {config_dict.ir_path!r}')
dir_name = Path(trainer_config.ir_dir).parts[-1]
ir_name = match.group()
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
approx_ir, approx_fs = sf.read(Path(approx_ir_path).resolve())

# Both files and the config must agree on the sample rate: fs drives the EDF time
# axis, while config_dict.sample_rate drives the xlim and the EDR plots below.
if not fs == approx_fs == config_dict.sample_rate:
    raise ValueError(
        f'Sample rate mismatch: true IR {fs} Hz, approx IR {approx_fs} Hz, '
        f'config {config_dict.sample_rate} Hz')

# Overlay the waveforms. Truncate both to the shorter IR so np.stack cannot fail
# on a length mismatch.
num_samps = min(len(true_ir), len(approx_ir))
plt.plot(np.stack((true_ir[:num_samps], approx_ir[:num_samps, 0]), axis=-1))
plt.xlim([0, int(1.5 * config_dict.sample_rate)])
plt.xlabel('Time (samples)')
plt.ylabel('Amplitude')
plt.legend(['True', 'Synthesised'])
plt.savefig(os.path.join(fig_path, f'compare_{dir_name}_{ir_name}.png'))
plt.show()

# Parse the receiver position (x, y, z) in metres from the IR name. float()
# strips the whitespace around each captured value.
match = re.search(r'ir_\(([^,]+),([^,]+),([^,)]+)\)', ir_name)
if match is None:
    raise ValueError(f'Could not parse a receiver position from {ir_name!r}')
x, y, z = map(float, match.groups())
rec_pos = np.array([x, y, z])

# Look up the amplitudes for this receiver. rec_pos is parsed back from text, so
# compare with np.isclose rather than exact float equality.
rec_pos_idx = np.where(
    np.all(np.isclose(room_data.receiver_position, rec_pos), axis=1))[0]
if rec_pos_idx.size == 0:
    warnings.warn(f'Receiver position {rec_pos} not found in '
                  'room_data.receiver_position; amplitudes will be empty')
amplitudes = room_data.amplitudes[..., rec_pos_idx]


def _schroeder_edf(ir: np.ndarray) -> np.ndarray:
    """Energy decay function: reverse cumulative sum of ir**2 along axis 0 (time)."""
    return np.flip(np.cumsum(np.flip(ir**2, axis=0), axis=0), axis=0)


# Plot the EDFs.
mixing_time_samp = ms_to_samps(20.0, fs)
# NOTE(review): the synthesised EDF skips the first 20 ms (mixing_time_samp) of
# approx_ir, whereas the waveform overlay above uses no offset. Confirm which
# alignment is intended.
synth_segment = approx_ir[mixing_time_samp:mixing_time_samp + len(true_ir), 0]
true_edf = _schroeder_edf(true_ir)
synth_edf = _schroeder_edf(synth_segment)
time = np.linspace(0, (len(true_ir) - 1) / fs,
                   len(true_ir))
plt.plot(time, db(true_edf, is_squared=True), label='True EDF')
# synth_segment is shorter than true_ir if approx_ir ends early, so plot only the
# samples it actually has.
plt.plot(time[:len(synth_edf)], db(synth_edf, is_squared=True), label='Synthesised EDF')
plt.xlabel('Time (s)')
plt.ylabel('Energy (dB)')
plt.title(
    f'EDF at position {x:.2f}, {y:.2f}, {z:.2f} m'
)
plt.legend()
plt.show()

# Plot the EDRs. os.path.join avoids the double slash that fig_path's trailing '/' would give.
plot_edr(torch.tensor(true_ir), config_dict.sample_rate,
         title='True RIR EDR',
         save_path=os.path.join(fig_path, f'{dir_name}_true_edr_{ir_name}.png'))
plot_edr(torch.tensor(approx_ir[:, 0]), config_dict.sample_rate,
         title='Estimated RIR EDR',
         save_path=os.path.join(fig_path, f'{dir_name}_approx_edr_{ir_name}.png'))

### Inspect the optimised parameters

In [None]:
# Load the optimised parameters from the training directory and list the
# variables stored in the .mat file.
param_file = trainer_config.train_dir + '/parameters_opt.mat'
param_path = Path(param_file)
opt_params = loadmat(param_path.resolve())
print(opt_params.keys())

In [None]:
# Check that the optimised feedback matrix is unitary, then visualise its magnitude.
feedback_matrix = opt_params['coupled_feedback_matrix']
if feedback_matrix.ndim == 3:
    # Per the error message, a 3-D matrix is a paraunitary (PU) filter matrix,
    # which this broadband-gain notebook does not handle.
    raise NotImplementedError("PU matrix cannot be used for broadband gains")

if config_dict.feedback_loop_config.coupling_matrix_type == CouplingMatrixType.SCALAR:
    unit_flag, max_val = is_unitary(torch.from_numpy(feedback_matrix))
    assert unit_flag, f'Coupled feedback matrix is not unitary (is_unitary returned {max_val})'
    coupling_matrix = opt_params['coupling_matrix']

    fig, (ax_coupling, ax_feedback) = plt.subplots(2, 1)
    im_coupling = ax_coupling.matshow(np.abs(coupling_matrix))
    fig.colorbar(im_coupling, ax=ax_coupling)
    ax_coupling.set_title('Coupling matrix')
    im_feedback = ax_feedback.matshow(np.abs(feedback_matrix))
    fig.colorbar(im_feedback, ax=ax_feedback)
    ax_feedback.set_title('Coupled feedback matrix')
    fig.tight_layout()
    # fig.savefig(os.path.join(fig_path, 'scalar_coupling_matrix.png'))
else:
    # NOTE(review): this branch passes max_tol=1e-4 while the scalar branch uses
    # is_unitary's default tolerance. Confirm the difference is intended.
    unit_flag, max_val = is_unitary(torch.tensor(feedback_matrix), max_tol=1e-4)
    assert unit_flag, f'Feedback matrix is not unitary (is_unitary returned {max_val})'
    # plt.matshow() without fignum opens a new figure of its own, so pairing it
    # with plt.figure() leaves a blank figure. Draw into explicit axes instead.
    fig, ax = plt.subplots()
    im = ax.matshow(np.abs(feedback_matrix))
    fig.colorbar(im, ax=ax)
    ax.set_title('Optimised feedback matrix')
plt.show()

In [None]:
# Report the L2 norm of the optimised input and output gain vectors.
# `[0]` takes the first row of each array returned by loadmat (presumably a
# 1 x N row vector; verify if the saved shape changes).
input_gains = opt_params['input_gains'][0]
input_norm = np.linalg.norm(input_gains)
print(f'Norm of input gain vector is {input_norm:.4f}')

output_gains = opt_params['output_gains'][0]
output_norm = np.linalg.norm(output_gains)
print(f'Norm of output gain vector is {output_norm:.4f}')

### Compare the magnitude responses of the true and synthesised IRs

In [None]:
def plot_mag_response(true_ir: ArrayLike, synth_ir: ArrayLike, fs: float, nfft: Optional[int]=None):
    """Plot the magnitude responses of a true and a synthesised IR on a log-frequency axis.

    Args:
        true_ir: 1-D reference impulse response.
        synth_ir: 1-D synthesised impulse response (labelled 'DiffGFDN' in the plot).
        fs: sampling rate in Hz.
        nfft: FFT length. Defaults to the next power of two that covers the
            longer of the two IRs, so that neither IR is truncated by the FFT.

    Raises:
        ValueError: if either IR is empty.
    """
    true_ir = np.asarray(true_ir)
    synth_ir = np.asarray(synth_ir)
    if true_ir.size == 0 or synth_ir.size == 0:
        raise ValueError('Both IRs must contain at least one sample')

    if nfft is None:
        # Integer exponentiation instead of np.pow, which only exists in NumPy >= 2.0.
        max_len = max(len(true_ir), len(synth_ir))
        nfft = 2 ** int(np.ceil(np.log2(max_len)))
    true_resp = rfft(true_ir, nfft)
    synth_resp = rfft(synth_ir, nfft)
    freq_bins = rfftfreq(nfft, d=1.0 / fs)

    plt.figure(figsize=(8, 6))
    # Pass real magnitudes to db(); how it treats complex input is not visible here.
    plt.semilogx(freq_bins, db(np.abs(true_resp)), label='True')
    plt.semilogx(freq_bins, db(np.abs(synth_resp)), linestyle='--', label='DiffGFDN')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude (dB)')
    plt.title('Magnitude response')
    plt.xlim([20, fs / 2])
    plt.ylim([-60, 20])
    plt.legend()

In [None]:
# Compare the spectrum of the true IR with the first channel of the synthesised IR.
plot_mag_response(true_ir, synth_ir=approx_ir[:, 0], fs=fs)