In [None]:
import pickle
import numpy as np
import torch
import re
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
import soundfile as sf

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.filters import decay_times_to_gain_filters
from diff_gfdn.plot import plot_edr, plot_polynomial_matrix_magnitude_response
from run_model import load_and_validate_config

In [None]:
fpath = Path('../resources/Georg_3room_FDTD/srirs.pkl')
# NOTE: unpickling can execute arbitrary code -- only load trusted, locally generated files.
srirs = pickle.loads(fpath.read_bytes())

# Remove the singleton axis 1 from the stored decay times. Later cells index the result as
# common_decay_times[:, group], i.e. presumably (num_bands, num_groups) -- TODO confirm.
common_decay_times = np.asarray(np.squeeze(srirs['common_decay_times'], axis=1))
band_centre_hz = srirs['band_centre_hz']

In [None]:
# Relative paths resolve against the notebook's current working directory.
config_path = '../data/config/'
fig_path = '../figures/'
config_file = f'{config_path}single_rir_fit_random_coupling.yml'

# Load and validate the YAML config, then override the number of groups to 3
# (presumably one group per room of the 3-room dataset loaded above).
# NOTE(review): if DiffGFDNConfig is a pydantic model, `.copy(update=...)` does not re-run
# validators, so any derived fields that depend on num_groups are not recomputed -- confirm.
config_dict = load_and_validate_config(config_file, DiffGFDNConfig).copy(update={"num_groups": 3})
trainer_config = config_dict.trainer_config

### Plot the T60 of the delay line filters

In [None]:
# Split the delay lines into `num_groups` contiguous, equally sized groups, then design each
# group's gain filters from that group's column of common decay times.
delays = config_dict.delay_length_samps
# Without this check, a remainder would produce an extra (short) group. The loop below only
# visits the first `num_groups` groups, so those delay lines would be silently dropped.
if config_dict.num_delay_lines % config_dict.num_groups != 0:
    raise ValueError(
        f'num_delay_lines ({config_dict.num_delay_lines}) must be divisible by '
        f'num_groups ({config_dict.num_groups})')
# Integer division avoids the float round-trip of int(a / b).
num_del_per_group = config_dict.num_delay_lines // config_dict.num_groups
delays_by_group = [delays[start:start + num_del_per_group]
                   for start in range(0, config_dict.num_delay_lines, num_del_per_group)]

# common_decay_times[:, i] is assumed to hold the decay times of group i, one per band in
# band_centre_hz -- TODO confirm against how srirs.pkl was generated.
gain_per_sample = [decay_times_to_gain_filters(band_centre_hz,
                                              common_decay_times[:, i],
                                              delays_by_group[i],
                                              config_dict.sample_rate,
                                              plot_response=True) for i in range(config_dict.num_groups)]

### Investigate optimised outputs

In [None]:
# Load the optimised parameters written by the trainer. Joining with pathlib works whether
# train_dir is a str or a Path (string concatenation fails for Path).
param_path = Path(trainer_config.train_dir) / 'parameters_opt.mat'
opt_params = loadmat(param_path.resolve())
# loadmat adds '__header__', '__version__' and '__globals__' metadata entries;
# list only the stored parameter names.
print([key for key in opt_params if not key.startswith('__')])

In [None]:
feedback_matrix = opt_params['coupled_feedback_matrix']
if feedback_matrix.ndim == 2:
    # Scalar matrix: show the magnitude of each entry. Use explicit axes: calling plt.figure()
    # before plt.matshow() leaves an empty figure, because matshow opens its own new figure.
    fig, ax = plt.subplots()
    im = ax.matshow(np.abs(feedback_matrix))
    fig.colorbar(im, ax=ax)
    ax.set_title('Optimised feedback matrix')
elif feedback_matrix.ndim == 3:
    # A 3-D array is presumably a polynomial (filter) matrix; plot the magnitude response
    # of the optimised coupling matrix instead.
    coupling_matrix = opt_params['coupling_matrix']
    plot_polynomial_matrix_magnitude_response(coupling_matrix, config_dict.sample_rate,
                                              num_bins=2**11, title='Optimised coupling matrix')
else:
    raise ValueError(
        f'Unexpected coupled_feedback_matrix with {feedback_matrix.ndim} dimensions')

In [None]:
# loadmat returns arrays with at least 2 dimensions, so a saved vector can come back as a
# (1, N) row or an (N, 1) column. Indexing with [0] is only correct for the row case, so
# flatten instead (identical result for a row).
input_gains = np.ravel(opt_params['input_gains'])
print(f'Norm of input gain vector is {np.linalg.norm(input_gains):.4f}')

# Output gains are optional: the key is not always present in the saved parameters.
if 'output_gains' in opt_params:
    output_gains = np.ravel(opt_params['output_gains'])
    print(f'Norm of output gain vector is {np.linalg.norm(output_gains):.4f}')

### Check output data and compare with true IR

In [None]:
# Load the reference RIR and the model's approximation of it.
true_ir, fs_true = sf.read(Path(config_dict.ir_path).resolve())
# The reference file name carries an "ir_(...)" tag; the approximated IR is saved under the same tag.
match = re.search(r'ir_\([^)]+\)', str(config_dict.ir_path))
if match is None:
    raise ValueError(f"Could not find an 'ir_(...)' tag in ir_path {config_dict.ir_path!r}")
ir_name = match.group()
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
approx_ir, fs = sf.read(Path(approx_ir_path).resolve())

# Downstream analysis uses config_dict.sample_rate, so both files must match it.
if {fs_true, fs} != {config_dict.sample_rate}:
    raise ValueError(
        f'Sample rate mismatch: true IR {fs_true} Hz, approximated IR {fs} Hz, '
        f'config {config_dict.sample_rate} Hz')

In [None]:
# Plot and save the EDRs of the reference and approximated RIRs.
# sf.read returns a 1-D array for mono files, so only index the first channel when the
# approximated IR is multichannel.
approx_ir_first_channel = approx_ir[:, 0] if approx_ir.ndim > 1 else approx_ir
plot_edr(torch.tensor(true_ir), config_dict.sample_rate,
         title='True RIR EDR', save_path=str(Path(fig_path) / f'true_edr_{ir_name}.png'))
plot_edr(torch.tensor(approx_ir_first_channel), config_dict.sample_rate,
         title='Estimated RIR EDR', save_path=str(Path(fig_path) / f'approx_edr_{ir_name}.png'))