In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
import soundfile as sf
from torch import nn
from scipy.fft import rfft, rfftfreq
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List
from IPython import display
from loguru import logger
from copy import deepcopy
from scipy.signal import sosfreqz, sos2zpk, tf2zpk

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNSinglePos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.absorption_filters import decay_times_to_gain_filters_geq
from diff_gfdn.filters.geq import eq_freqs
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix, plot_subband_edc, plot_subband_amplitudes, plot_learned_svf_response
from diff_gfdn.analysis import get_decay_fit_net_params, get_edc_params, get_decay_times_for_rirs
from diff_gfdn.gain_filters import SOSFilter
from slope2noise.slope2noise.utils import octave_filtering, calculate_amplitudes_least_squares
from run_model import load_and_validate_config

# Move to the DiffGFDN repository root (presumably one level above this
# notebook) so that the relative paths used below ('data/config/', 'figures/')
# resolve. The guard keeps the cell idempotent: re-running it no longer climbs
# one more directory each time.
if not Path('data/config').is_dir():
    os.chdir('..')

In [None]:
# Experiment configuration: which YAML config to load and where figures go.
config_path = 'data/config/'
fig_path = 'figures/'
config_name = 'single_rir_fit_subband_two_stage_decay_scalar_coupling.yml'
config_file = f'{config_path}{config_name}'
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Build the room dataset from the common-slope RIR data; its number of rooms
# becomes the number of GFDN groups.
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# Force the training batch size to equal the dataset's number of frequency bins.
trainer_config = config_dict.trainer_config
if room_data.num_freq_bins != trainer_config.batch_size:
    trainer_config = trainer_config.copy(update={"batch_size": room_data.num_freq_bins})

# Initialise the single-position DiffGFDN with absorption filters enabled,
# driven by the dataset's common decay times and band centre frequencies.
model = DiffGFDNSinglePos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    use_absorption_filters=True,
    common_decay_times=room_data.common_decay_times,
    band_centre_hz=room_data.band_centre_hz,
)

### Plot the delay-line absorption filters designed from the common T60s, and their pole locations

In [None]:
# Design each group's delay-line absorption filters from that group's common
# decay times (plot_response=True asks the helper to plot their responses).
delays = config_dict.delay_length_samps
num_del_per_group = config_dict.num_delay_lines // config_dict.num_groups
delays_by_group = [delays[i:i + num_del_per_group]
                   for i in range(0, config_dict.num_delay_lines, num_del_per_group)]
common_decay_times = room_data.common_decay_times
band_centre_hz = room_data.band_centre_hz
gain_per_sample = [decay_times_to_gain_filters_geq(band_centre_hz,
                                                   common_decay_times[:, i],
                                                   delays_by_group[i],
                                                   config_dict.sample_rate,
                                                   plot_response=True)
                   for i in range(config_dict.num_groups)]

# Pole locations of every delay line's filter on a polar (z-plane) plot, one
# colour per group; the unit circle (r = 1) is the stability boundary.
fig2, ax2 = plt.subplots(subplot_kw={'projection': 'polar'})
for n in range(config_dict.num_groups):
    group_poles = []
    for k in range(num_del_per_group):
        if gain_per_sample[n].ndim == 4:
            # the last axis holds (numerator, denominator) coefficients per section
            cur_filt = gain_per_sample[n][:, k, ...].detach().numpy()
            sos_coeffs = np.concatenate((cur_filt[..., 0], cur_filt[..., 1]), axis=-1)
            # scipy's SOS format requires a0 == 1 in every section
            sos_coeffs = sos_coeffs / sos_coeffs[:, 3:4]
            _, poles, _ = sos2zpk(sos_coeffs)
        else:
            _, poles, _ = tf2zpk(gain_per_sample[n][k, :, 0], gain_per_sample[n][k, :, 1])
        group_poles.append(np.ravel(poles))
    # Flatten into one series per group: plotting the list of per-line arrays
    # directly draws one series per pole index (wrong colours, repeated labels).
    group_poles = np.concatenate(group_poles)
    ax2.plot(np.angle(group_poles), np.abs(group_poles), 'x', label=f'Group {n}')

ax2.set_rmax(1.1)
ax2.set_rticks([0.25, 0.5, 0.75, 1])  # fewer radial ticks
ax2.set_rlabel_position(-22.5)  # move radial labels away from the plotted points
ax2.grid(True)
ax2.set_title('Absorption filter poles')
ax2.legend()
plt.show()

### Plot output data and compare RIR

In [None]:
# Load the true RIR and the DiffGFDN approximation written to trainer_config.ir_dir.
true_ir, true_fs = sf.read(Path(config_dict.ir_path).resolve())

# the IR file name carries an 'ir_(x, y, z)' tag used to name outputs
ir_name_match = re.search(r'ir_\([^)]+\)', config_dict.ir_path)
if ir_name_match is None:
    raise ValueError(f'No "ir_(...)" tag found in {config_dict.ir_path}')
ir_name = ir_name_match.group()
dir_name = Path(trainer_config.ir_dir).parts[-1]
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
final_approx_ir, fs = sf.read(Path(approx_ir_path).resolve())
if fs != true_fs:
    logger.warning(f'Sample rate mismatch: true RIR {true_fs} Hz, approximation {fs} Hz')

# Overlay both IRs over the first 1.5 s (approximation uses its first channel).
fig_ir, ax_ir = plt.subplots()
ax_ir.plot(true_ir, label='True RIR')
ax_ir.plot(final_approx_ir[:len(true_ir), 0], label='DiffGFDN RIR')
ax_ir.set_xlim([0, int(1.5 * config_dict.sample_rate)])
ax_ir.set_xlabel('Time (samples)')
ax_ir.set_ylabel('Amplitude')
ax_ir.legend()
fig_ir.savefig(f'{fig_path}/compare_{dir_name}_{ir_name}.png')
plt.show()

# Parse the receiver position from the 'ir_(x, y, z)' tag.
rec_pos_match = re.search(r'ir_\(([^,]+), ([^,]+), ([^,]+)\)', ir_name)
if rec_pos_match is None:
    raise ValueError(f'Could not parse a receiver position from {ir_name}')
rec_pos = np.array([float(coord) for coord in rec_pos_match.groups()])

# Find the amplitudes for this receiver; dataset positions are rounded to
# 2 decimals before the exact comparison.
rec_pos_hits = np.argwhere(
    np.all(np.round(room_data.receiver_position, 2) == rec_pos, axis=1))
if rec_pos_hits.size == 0:
    raise ValueError(f'Receiver position {rec_pos} not found in the room dataset')
rec_pos_idx = rec_pos_hits[0]
amplitudes = np.squeeze(room_data.amplitudes[rec_pos_idx, :]).T
if room_data.num_rooms == 1:
    # keep a (bands, rooms) 2-D layout when there is a single room
    amplitudes = amplitudes[:, np.newaxis]

# Energy decay relief of the true and approximated RIRs.
plot_edr(torch.tensor(true_ir), config_dict.sample_rate,
         title='True RIR EDR', save_path=f'{fig_path}/{dir_name}_true_edr_{ir_name}.png')
plot_edr(torch.tensor(final_approx_ir[:, 0]), config_dict.sample_rate,
         title='Estimated RIR EDR', save_path=f'{fig_path}/{dir_name}_approx_edr_{ir_name}.png')


### Compare the true RIR's subband EDFs and amplitudes with the DecayFitNet fit

In [None]:
# Directory holding the per-epoch model checkpoints. Path joining works whether
# or not train_dir ends with a slash (plain string concatenation did not).
checkpoint_dir = (Path(trainer_config.train_dir) / 'checkpoints').resolve()

# Create the RIR dataset for the target receiver.
rir_data = RIRData(config_dict.ir_path,
                   common_decay_times=room_data.common_decay_times,
                   band_centre_hz=room_data.band_centre_hz,
                   amplitudes=amplitudes,
                   nfft=config_dict.trainer_config.num_freq_bins,
                   )

# Prepare the (unshuffled, training-only) data used to query the DiffGFDN.
train_dataset = load_dataset(
    rir_data, trainer_config.device, train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size, shuffle=False)

# Truncate the true RIR: drop everything before the assumed mixing time and a
# short tail at the end before fitting decay parameters.
MIXING_TIME_MS = 20.0
CROP_END_MS = 5.0
mixing_time_samp = ms_to_samps(MIXING_TIME_MS, fs)
crop_end_samp = ms_to_samps(CROP_END_MS, fs)
# explicit stop index: a `-0` stop would give an empty slice if CROP_END_MS were 0
trunc_true_ir = true_ir[mixing_time_samp:len(true_ir) - crop_end_samp]
num_bands = len(band_centre_hz)
filtered_true_ir = octave_filtering(trunc_true_ir, fs, band_centre_hz)

# DecayFitNet estimates (T60s, amplitudes, fitted EDC) for the truncated RIR.
og_est_t60, og_est_amp, _, _, og_fitted_edc = get_edc_params(trunc_true_ir,
                                                             config_dict.num_groups,
                                                             config_dict.sample_rate,
                                                             f_bands=room_data.band_centre_hz)

In [None]:
# Per octave band: energy decay function (backward-integrated energy) of the
# truncated true RIR versus the DecayFitNet fit, with the dataset amplitudes
# (crosses) and the fitted amplitudes (circles) drawn at t = 0.
time = np.linspace(0, (len(trunc_true_ir) - 1) / fs, len(trunc_true_ir))
# squeeze=False keeps `ax` indexable even when there is only one band
fig, ax = plt.subplots(num_bands, 1, figsize=(6, 12), squeeze=False)
ax = ax[:, 0]
for k in range(num_bands):
    band_energy = filtered_true_ir[:, k] ** 2
    true_edf = np.flipud(np.cumsum(np.flipud(band_energy), axis=-1))

    ax[k].plot(time, db(true_edf, is_squared=True), label='True EDF')
    ax[k].plot(time[:og_fitted_edc.shape[-1]], db(og_fitted_edc[k, :], is_squared=True),
               label='DecayFitNet EDF')
    ax[k].plot(np.zeros(room_data.num_rooms), db(amplitudes[k, :], is_squared=True),
               'kx', label='Original amplitudes')
    ax[k].plot(np.zeros(room_data.num_rooms), db(og_est_amp[k, :], is_squared=True),
               'ko', label='Fitted amplitudes')
    ax[k].set_title(f'{band_centre_hz[k]:.0f} Hz')
    ax[k].set_ylabel('dB')

ax[-1].set_xlabel('Time (s)')
fig.subplots_adjust(hspace=0.7)
ax[-1].legend()
plt.show()

### Get parameters as a function of epoch number

In [None]:
# Replay the saved checkpoints epoch by epoch, recording the learned parameters
# and the model's impulse response for each loaded epoch.
max_epochs = trainer_config.max_epochs
coupled_feedback_matrix = []
coupling_matrix = []
input_gains = []
output_gains = []
output_biquad_coeffs = []
svf_params = []
absorption_filters = []
h_approx_list = []

for epoch in range(max_epochs):
    checkpoint_path = Path(checkpoint_dir) / f'model_e{epoch}.pt'
    if not checkpoint_path.exists():
        # Training stopped before max_epochs: epochs 0..epoch-1 were loaded.
        logger.info(f'No checkpoint for epoch {epoch}; loaded {epoch} epochs')
        max_epochs = epoch
        break
    checkpoint = torch.load(checkpoint_path, weights_only=True,
                            map_location=torch.device('cpu'))

    # load the trained weights and switch to evaluation mode
    model.load_state_dict(checkpoint)
    model.eval()

    with torch.no_grad():
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))
        output_gains.append(deepcopy(param_dict['output_gains']))

        # Optional parameters: each is recorded only if this model exposes it
        # (output_svf_params is checked on its own to avoid a KeyError).
        for key, history in (('coupling_matrix', coupling_matrix),
                             ('coupled_feedback_matrix', coupled_feedback_matrix),
                             ('output_biquad_coeffs', output_biquad_coeffs),
                             ('output_svf_params', svf_params),
                             ('absorption_filters', absorption_filters)):
            if key in param_dict:
                history.append(deepcopy(param_dict[key]))

        # one response per batch of the training dataset
        for data in train_dataset:
            _, approx_ir = get_response(data, model)
            h_approx_list.append(approx_ir)


In [None]:
# Compare the true RIR's subband EDCs with every second recorded GFDN response.
every_other_response = h_approx_list[::2]
plot_subband_edc(true_ir,
                 every_other_response,
                 config_dict.sample_rate,
                 band_centre_hz,
                 rec_pos,
                 save_path=f'{fig_path}/compare_synth_edf_{rec_pos}_{config_name}.png')

### Investigate output SVFs as a function of epoch number

In [None]:
# Learned output SVF responses across epochs. plot_learned_svf_response comes
# from the imports cell; the development-time importlib.reload is unnecessary on
# a fresh kernel (Restart & Run All).
if output_biquad_coeffs:
    plot_learned_svf_response(config_dict.num_groups, config_dict.sample_rate,
                              output_biquad_coeffs, rec_pos,
                              save_path=f'{fig_path}/{config_name}')
else:
    logger.info('Model has no output biquad filters; skipping SVF response plot')

### Compare the subband EDC parameters of the true RIR with the GFDN's IR

#### Check the subband T60s

In [None]:
def plot_absorption_filter_response(absorption_filters: List[SOSFilter], fs: float):
    """Plot the magnitude response of every delay line's absorption filter.

    Args:
        absorption_filters: one SOSFilter per delay line. Each one's
            ``biquad_cascade.num_coeffs`` and ``biquad_cascade.den_coeffs`` are
            concatenated along the last axis into scipy's (n_sections, 6) SOS layout.
        fs: sample rate in Hz.

    Returns:
        The ``(figure, axes)`` pair, so callers can save or adjust the plot.
    """
    def _as_numpy(coeffs):
        # detach in case the coefficients are (grad-tracking) torch tensors
        if isinstance(coeffs, torch.Tensor):
            return coeffs.detach().cpu().numpy()
        return np.asarray(coeffs)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    # iterate over the filters actually passed in, not a global delay-line count
    for n, cur_filter in enumerate(absorption_filters):
        sos_coeffs = np.concatenate((_as_numpy(cur_filter.biquad_cascade.num_coeffs),
                                     _as_numpy(cur_filter.biquad_cascade.den_coeffs)),
                                    axis=-1)
        # scipy's SOS format requires a0 == 1 in every section
        sos_coeffs = sos_coeffs / sos_coeffs[:, 3:4]
        freq_axis, freq_response = sosfreqz(sos_coeffs, worN=1024, fs=fs)
        ax.semilogx(freq_axis, db(freq_response), label=f'Delay line {n}')

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Magnitude (dB)')
    ax.set_xlim([20, fs / 2])
    ax.set_title('Delay line absorption filters')
    ax.legend(fontsize='x-small', ncol=2)
    return fig, ax
        

# Compare subband T60s of the true RIR with the GFDN response from the last
# loaded epoch, then show that epoch's delay-line absorption filters.
# A distinct name is used so the file-loaded `final_approx_ir` is not
# overwritten with a model tensor.
if not h_approx_list:
    raise RuntimeError('No GFDN responses were collected; check the checkpoint directory')
last_epoch_ir = h_approx_list[-1]
true_t60, est_t60 = get_decay_times_for_rirs(torch.tensor(true_ir), last_epoch_ir,
                                             room_data.num_rooms, fs, band_centre_hz,
                                             plot_edc=True)
logger.info(f'Original RIR T60s are {np.round(np.squeeze(true_t60),3)} s')
logger.info(f'Estimated RIR T60s are {np.round(np.squeeze(est_t60), 3)} s')
if absorption_filters:
    plot_absorption_filter_response(absorption_filters[-1], fs)
else:
    logger.info('Model has no absorption filters to plot')

#### Check the subband amplitudes

In [None]:
# Compare the subband amplitudes of the true RIR with those of the GFDN
# response recorded last.
final_approx_rir = h_approx_list[-1]
plot_subband_amplitudes(true_ir,
                        final_approx_rir,
                        room_data.sample_rate,
                        config_dict.num_groups,
                        amplitudes,
                        common_decay_times,
                        band_centre_hz)