In [None]:
import numpy as np
import torch
import re
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.signal import sosfreqz, sos2zpk, tf2zpk
import soundfile as sf

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.dataloader import ThreeRoomDataset
from diff_gfdn.filters import decay_times_to_gain_filters
from diff_gfdn.utils import is_unitary
from diff_gfdn.plot import plot_edr, plot_polynomial_matrix_magnitude_response
from run_model import load_and_validate_config

In [None]:
# Directory holding the YAML configs, and directory where figures are written.
config_path = '../data/config/'
fig_path = '../figures/'

# Config file of the experiment inspected in this notebook.
config_file = f'{config_path}single_rir_fit_random_coupling_out_filters.yml'

# Load the config through the project helper, passing DiffGFDNConfig as the
# expected config class; the trainer settings are kept in their own variable
# because later cells read the output directories from them.
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)
trainer_config = config_dict.trainer_config

In [None]:
# Build the dataset from the resolved (absolute) path stored in the config.
dataset_path = Path(config_dict.room_dataset_path).resolve()
room_data = ThreeRoomDataset(dataset_path)

# Set the number of groups in the config to the number of rooms in the dataset.
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# Dataset quantities reused by the cells below.
band_centre_hz = room_data.band_centre_hz
common_decay_times = room_data.common_decay_times

### Plot the T60 of the delay line filters

In [None]:
# Split the delay lines into equally sized, contiguous groups.
delays = config_dict.delay_length_samps
# Floor division keeps the arithmetic in integers (no round trip through float).
num_del_per_group = config_dict.num_delay_lines // config_dict.num_groups
delays_by_group = [
    delays[i:i + num_del_per_group]
    for i in range(0, config_dict.num_delay_lines, num_del_per_group)
]

# One set of gain filters per group, computed from that group's column of the
# common decay times; plot_response=True asks the helper to plot as it goes.
gain_per_sample = [
    decay_times_to_gain_filters(band_centre_hz,
                                common_decay_times[:, i],
                                delays_by_group[i],
                                config_dict.sample_rate,
                                plot_response=True)
    for i in range(config_dict.num_groups)
]

# Pole-zero plot of all gain filters, one colour per group.
fig2, ax2 = plt.subplots(subplot_kw={'projection': 'polar'})
for n in range(config_dict.num_groups):
    all_zeros = []
    all_poles = []
    for k in range(num_del_per_group):
        # Last axis of the filter array: index 0 is passed as numerator,
        # index 1 as denominator.
        zeros, poles, _gain = tf2zpk(gain_per_sample[n][k, :, 0],
                                     gain_per_sample[n][k, :, 1])
        all_zeros.append(np.atleast_1d(zeros))
        all_poles.append(np.atleast_1d(poles))

    # Flatten to 1-D so that filters with differing numbers of roots are
    # handled and each group is drawn as a single artist.
    group_zeros = np.concatenate(all_zeros)
    group_poles = np.concatenate(all_poles)

    (pole_markers, ) = ax2.plot(np.angle(group_poles), np.abs(group_poles),
                                'x', label=f'Group {n}')
    # Zeros share the colour of their group's poles and are left out of the
    # legend to avoid duplicate entries.
    ax2.plot(np.angle(group_zeros), np.abs(group_zeros), 'o',
             fillstyle='none', color=pole_markers.get_color())

ax2.set_rmax(1.1)
ax2.set_rticks([0.25, 0.5, 0.75, 1,])  # Less radial ticks
ax2.set_rlabel_position(-22.5)  # Move radial labels away from plotted line
ax2.grid(True)
ax2.set_title('Delay line gain filters: poles (x) and zeros (o)')
ax2.legend(loc='upper left', bbox_to_anchor=(1.05, 1.0))

### Check output data and compare with true IR

In [None]:
# Load the reference RIR.
true_ir, true_fs = sf.read(Path(config_dict.ir_path).resolve())

# The RIR file name carries a tag of the form 'ir_(...)'; the approximated RIR
# is stored under the same tag in the trainer's IR directory.
match = re.search(r'ir_\([^)]+\)', config_dict.ir_path)
if match is None:
    raise ValueError(
        f"Could not find an 'ir_(...)' tag in the IR path {config_dict.ir_path!r}")
dir_name = Path(trainer_config.ir_dir).parts[-1]
ir_name = match.group()
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
approx_ir, fs = sf.read(Path(approx_ir_path).resolve())
if fs != true_fs:
    # Both signals are plotted on a common sample axis below.
    print(f'Warning: sample rate mismatch, true RIR is at {true_fs} Hz '
          f'but approximated RIR is at {fs} Hz')

# Compare over the common length so that a shorter approximated RIR does not
# make np.stack fail; only the first channel of the approximation is shown.
num_samps = min(len(true_ir), len(approx_ir))
plt.plot(np.stack((true_ir[:num_samps], approx_ir[:num_samps, 0]), axis=-1))
plt.legend(['True RIR', 'Approximated RIR'])
plt.xlabel('Time (samples)')
plt.xlim([0, int(1.5 * config_dict.sample_rate)])
plt.savefig(f'{fig_path}/compare_{dir_name}_{ir_name}.png')
plt.show()

# find receiver position from string
match = re.search(r'ir_\(([^,]+), ([^,]+), ([^,]+)\)', ir_name)
if match is None:
    raise ValueError(
        f"Could not parse a receiver position '(x, y, z)' from {ir_name!r}")
# Convert the extracted values to floats
x, y, z = map(float, match.groups())
rec_pos = np.array([x, y, z])

# find amplitudes corresponding to the receiver position; the coordinates come
# from text, so compare with a tolerance instead of exact float equality
rec_pos_idx = np.where(
    np.all(np.isclose(room_data.receiver_position, rec_pos), axis=1))[0]
if rec_pos_idx.size == 0:
    print(f'Warning: receiver position {rec_pos} not found in the dataset, '
          'amplitudes will be empty')
amplitudes = room_data.amplitudes[..., rec_pos_idx]

In [None]:
# EDR of the reference RIR and of the first channel of the approximated RIR;
# each call is given a title and the path its figure is saved to.
edr_jobs = (
    ('True RIR EDR', true_ir,
     f'{fig_path}/{dir_name}_true_edr_{ir_name}.png'),
    ('Estimated RIR EDR', approx_ir[:, 0],
     f'{fig_path}/{dir_name}_approx_edr_{ir_name}.png'),
)
for edr_title, ir_signal, edr_save_path in edr_jobs:
    plot_edr(torch.tensor(ir_signal), config_dict.sample_rate,
             title=edr_title, save_path=edr_save_path)

### Investigate optimised outputs

In [None]:
# Read the 'parameters_opt.mat' file from the training directory.
param_path = Path(f'{trainer_config.train_dir}/parameters_opt.mat')
opt_params = loadmat(param_path.resolve())

# Show which variables the file contains.
print(opt_params.keys())

In [None]:
feedback_matrix = opt_params['coupled_feedback_matrix']
if feedback_matrix.ndim == 2:
    # 2-D case: a single matrix, checked for unitarity and shown as an image.
    unit_flag, max_val = is_unitary(torch.tensor(feedback_matrix), max_tol=1e-4)
    assert unit_flag, (
        'Optimised feedback matrix failed the unitarity check '
        f'(is_unitary returned {max_val} with max_tol=1e-4)')
    # plt.matshow opens its own figure; calling plt.figure() first would leave
    # an empty extra figure behind.
    plt.matshow(np.abs(feedback_matrix))
    plt.title('Optimised feedback matrix')
elif feedback_matrix.ndim == 3:
    # 3-D case: plot the magnitude response of the stored coupling matrix.
    coupling_matrix = opt_params['coupling_matrix']
    plot_polynomial_matrix_magnitude_response(coupling_matrix, config_dict.sample_rate,
                                              num_bins=2**11, title='Optimised coupling matrix')

In [None]:
input_gains = opt_params['input_gains'][0]
print(f'Norm of input gain vector is {np.linalg.norm(input_gains):.4f}')

if 'output_gains' in opt_params:
    # Scalar output gains: only report their norm.
    output_gains = opt_params['output_gains'][0]
    print(f'Norm of output gain vector is {np.linalg.norm(output_gains):.4f}')
elif 'output_biquad_coeffs' in opt_params:
    # Output filters stored as biquad cascades: plot the magnitude response
    # and the pole-zero map of every group's filter.
    output_biquad_coeffs = opt_params['output_biquad_coeffs']
    svf_params = opt_params['output_svf_params']
    fig, ax = plt.subplots()
    fig2, ax2 = plt.subplots(subplot_kw={'projection': 'polar'})
    for n in range(config_dict.num_groups):
        cur_svf_params = svf_params[n, ...]

        # ensure a0 = 1 (needed by scipy); the division creates a new array,
        # so the coefficients stored in opt_params are left unmodified
        cur_biquad_coeffs = output_biquad_coeffs[n] / output_biquad_coeffs[n][:, 3:4]

        freqs, filt_response = sosfreqz(cur_biquad_coeffs, worN=2**9, fs=config_dict.sample_rate)
        ax.semilogx(freqs, 20*np.log10(np.abs(filt_response)), label=f'Group {n}')
        # also plot the amplitudes estimated from the common slope model
        ax.semilogx(band_centre_hz, 20*np.log10(np.abs(amplitudes[:, n])), marker='o', label=f'Amplitude {n}')

        # also plot the poles and zeros, using one colour per group
        zeros, poles, _gain = sos2zpk(cur_biquad_coeffs)
        (zero_markers, ) = ax2.plot(np.angle(zeros), np.abs(zeros), 'o', label=f'Group {n}')
        # poles reuse the colour of the zeros and stay out of the legend
        ax2.plot(np.angle(poles), np.abs(poles), 'x', color=zero_markers.get_color())

        print(f'Pole frequencies (rad): {np.angle(poles)}')
        print(f'SVF Q factor : {cur_svf_params[:, 1]}')

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Magnitude (dB)')
    ax.set_title('Output filter response')
    ax.set_ylim([-100, 20])
    ax.legend()

    ax2.set_rmax(2)
    ax2.set_rticks([0.5, 1, 1.5, 2])  # Less radial ticks
    ax2.set_rlabel_position(
        -22.5)  # Move radial labels away from plotted line
    ax2.grid(True)
    ax2.set_title('Output filters: zeros (o) and poles (x)')
    ax2.legend(loc='upper left', bbox_to_anchor=(1.05, 1.0))


In [None]:
import numpy as np
from numpy.typing import ArrayLike
from scipy.fft import rfftfreq
from diff_gfdn.losses import scaled_shifted_sigmoid_inverse

# Frequency axis (Hz) of a real FFT of length nfft at the dataset sample rate.
nfft = 2**9
x = rfftfreq(nfft, d=1.0/room_data.sample_rate)

# Arguments passed to scaled_shifted_sigmoid_inverse; the scale factor is
# swept over logarithmically spaced values between 1e-3 and 1.
cutoff = 1e3
num_scale = 10
scale_factor = np.logspace(-3, 0, num_scale)
top = 2
bottom = 1.0

plt.figure()
# Do not name the loop variable `sf`: that would shadow the soundfile module
# (imported as `sf`) and break any later call to sf.read.
for scale in scale_factor:
    fn = scaled_shifted_sigmoid_inverse(torch.tensor(x), scale, cutoff, top, bottom)
    plt.plot(x, fn, label=f'scale_factor: {scale:.3f}')
plt.xlabel('Frequency (Hz)')
plt.legend()
plt.show()
