In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

# Change the working directory to the parent of the kernel's starting directory
# (the DiffGFDN folder). The starting directory is remembered on the first run
# so that re-running this cell does not keep climbing up the directory tree.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = Path.cwd()
os.chdir(NOTEBOOK_START_DIR.parent)

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarSourceReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix, plot_amps_in_space, plot_edc_error_in_space
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from src.run_model import load_and_validate_config

In [None]:
# Folders (relative to the working directory) holding the experiment
# configurations and the generated figures.
config_path = 'data/config/'
fig_path = 'figures/'

# Configuration of the run inspected in this notebook.
config_name = 'synth_data_broadband_two_coupled_rooms_multi_source_colorless_loss.yml'
config_file = f'{config_path}{config_name}'

# Load the configuration through the project helper, using DiffGFDNConfig.
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Build the room dataset from the path given in the configuration.
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# Use as many groups as there are rooms in the dataset.
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# Everything in this notebook runs on the CPU, whatever the config says.
trainer_config = config_dict.trainer_config
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.copy(update={"device": 'cpu'})

# Prepare the training and validation data for DiffGFDN.
# NOTE(review): a split ratio of 1.0 presumably puts every position in the
# training set - confirm against `load_dataset`.
train_dataset, valid_dataset = load_dataset(
    room_data,
    trainer_config.device,
    train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size,
    shuffle=False,
)

# Colorless FDN parameters are only computed when the prototype is enabled.
colorless_fdn_params = (
    get_colorless_fdn_params(config_dict)
    if config_dict.colorless_fdn_config.use_colorless_prototype
    else None
)

# Instantiate the model; trained weights are loaded per epoch further down.
model = DiffGFDNVarSourceReceiverPos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    config_dict.input_filter_config,
    use_absorption_filters=False,
    common_decay_times=room_data.common_decay_times,
    colorless_fdn_params=colorless_fdn_params,
    use_colorless_loss=trainer_config.use_colorless_loss,
)

In [None]:
audio_directory = Path("audio/")
fig_path = Path("figures").resolve()
# Join with Path so the checkpoint folder is found whether or not
# `train_dir` ends with a path separator.
checkpoint_dir = (Path(trainer_config.train_dir) / 'checkpoints').resolve()
max_epochs = trainer_config.max_epochs
plot_ir = True

# Either inspect a fixed source/receiver pair or draw one at random from the dataset.
use_fixed_pos = True
if use_fixed_pos:
    rec_pos_to_investigate = [1.12, 1.47, 1.69]
    src_pos_to_investigate = [4.48, 2.59, 1.29]
else:
    # NOTE(review): unseeded draw - the inspected position changes on every run.
    rec_idx = np.random.randint(0, high=room_data.num_rec, size=1, dtype=int)
    src_idx = np.random.randint(0, high=room_data.num_src, size=1, dtype=int)
    rec_pos_to_investigate = np.round(np.squeeze(room_data.receiver_position[rec_idx,:]), 2)
    src_pos_to_investigate = np.round(np.squeeze(room_data.source_position[src_idx,:]), 2)

# File name used later to recognise this source/receiver pair inside a batch.
desired_filename = f'ir_src_pos=({src_pos_to_investigate[0]:.2f}, {src_pos_to_investigate[1]:.2f}, {src_pos_to_investigate[2]:.2f})'\
+f'_rec_pos=({rec_pos_to_investigate[0]:.2f}, {rec_pos_to_investigate[1]:.2f}, {rec_pos_to_investigate[2]:.2f}).wav'


def _find_position_index(positions, target, kind):
    """Return the index (as a length-1 array) of the first row of `positions` matching `target`.

    Rows are rounded to two decimals and compared with `np.isclose`, so the
    match does not depend on the floating-point precision of the stored
    positions.

    Raises:
        ValueError: if no row of `positions` matches `target`.
    """
    matches = np.argwhere(
        np.all(np.isclose(np.round(positions, 2), target), axis=1))
    if len(matches) == 0:
        raise ValueError(f'No {kind} position in the dataset matches {target}')
    return matches[0]


# find amplitudes and the reference RIR corresponding to the chosen positions
rec_pos_idx = _find_position_index(room_data.receiver_position, rec_pos_to_investigate, 'receiver')
src_pos_idx = _find_position_index(room_data.source_position, src_pos_to_investigate, 'source')
amps_at_pos = np.squeeze(room_data.amplitudes[src_pos_idx, rec_pos_idx, :])
h_true = np.squeeze(room_data.rirs[src_pos_idx, rec_pos_idx, :])

### Iterate through epochs

In [None]:
# Per-epoch snapshots of the model parameters and of the RIR synthesised at the
# investigated source/receiver pair.
h_approx_list = []
output_gains = []
input_gains = []
input_scalars = []
output_scalars = []
coupled_feedback_matrix = []
coupling_matrix = []
# RIRs synthesised by the final epoch, indexed like `room_data.rirs`.
all_rirs_mat = np.zeros_like(room_data.rirs)


for epoch in range(max_epochs):
    # load the trained weights for the particular epoch
    checkpoint = torch.load(f'{checkpoint_dir}/model_e{epoch}.pt', weights_only=True, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.eval()
    is_last_epoch = epoch == max_epochs - 1

    # inference only, so gradient tracking is switched off
    with torch.no_grad():
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))
        output_gains.append(deepcopy(param_dict['output_gains']))
        coupled_feedback_matrix.append(deepcopy(param_dict['coupled_feedback_matrix']))
        coupling_matrix.append(deepcopy(param_dict['coupling_matrix']))

        for data in train_dataset:
            rec_position = data['listener_position']
            src_position = data['source_position']

            # Size the index arrays by the actual batch, not the configured
            # batch size: the last batch can be smaller, which would otherwise
            # break the assignment into `all_rirs_mat` below. Batch-local names
            # keep the notebook-level `rec_pos_idx` / `src_pos_idx` intact.
            num_in_batch = min(src_position.shape[0], rec_position.shape[0])
            batch_rec_idx = np.zeros(num_in_batch, dtype=np.int32)
            batch_src_idx = np.zeros(num_in_batch, dtype=np.int32)

            if trainer_config.use_colorless_loss:
                H, H_sub_fdn, h = get_response(data, model)
            else:
                H, h = get_response(data, model)

            for pos in range(num_in_batch):
                if is_last_epoch:
                    # locate this batch entry in the full dataset grid
                    batch_rec_idx[pos] = np.argwhere(np.all(np.isclose(room_data.receiver_position, rec_position[pos]), axis=1))[0][0]
                    batch_src_idx[pos] = np.argwhere(np.all(np.isclose(room_data.source_position, src_position[pos]), axis=1))[0][0]

                filename = f'ir_src_pos=({src_position[pos, 0]:.2f}, {src_position[pos, 1]:.2f}, {src_position[pos, 2]:.2f})'\
                f'_rec_pos=({rec_position[pos,0]:.2f}, {rec_position[pos, 1]:.2f}, {rec_position[pos, 2]:.2f}).wav'

                if filename == desired_filename:
                    # get parameter dictionary used in inferencing
                    inf_param_dict = model.get_param_dict_inference(data)

                    # get the ir at this position
                    h_approx_list.append(h[pos])

                    # get the gains for this position
                    if 'output_scalars' in inf_param_dict.keys():
                        output_scalars.append(deepcopy(inf_param_dict['output_scalars'][pos]))
                    if 'input_scalars' in inf_param_dict.keys():
                        input_scalars.append(deepcopy(inf_param_dict['input_scalars'][pos]))

                    # Deliberately no early exit: the remaining batches are
                    # still needed to collect all the RIRs of the last epoch.

            if is_last_epoch:
                all_rirs_mat[batch_src_idx, batch_rec_idx, :] = h


#### Plot the EDCs as a function of epoch number

In [None]:
def plot_edc(h_true: ArrayLike, h_approx: List[ArrayLike], fs: float, pos_to_investigate: List, amps_at_pos: List,
             mixing_time_ms: float = 20.0, crop_end_ms: float = 5.0, epoch_step: int = 4):
    """Plot the true energy decay function (EDF) against the EDFs synthesised over training.

    Each EDF is the reversed cumulative sum of the squared, truncated impulse
    response. The figure is redrawn in place after every plotted epoch and is
    finally saved under the notebook-level `fig_path`, with `config_name` in
    the file name.

    Args:
        h_true: reference impulse response.
        h_approx: synthesised impulse responses, one per epoch.
        fs: sampling rate used to convert milliseconds to samples and to build the time axis.
        pos_to_investigate: three coordinates shown in the title and used in the saved file name.
        amps_at_pos: amplitudes drawn as markers at time zero.
        mixing_time_ms: length discarded from the start of every response.
        crop_end_ms: length discarded from the end of the reference response.
        epoch_step: stride between the epochs that are plotted.
    """
    mixing_time_samp = ms_to_samps(mixing_time_ms, fs)
    crop_end_samp = ms_to_samps(crop_end_ms, fs)

    # An explicit end index keeps the slice valid when the crop length is zero
    # (`h_true[a:-0]` would be empty).
    trunc_true_ir = h_true[mixing_time_samp:len(h_true) - crop_end_samp]
    true_edf = np.flipud(np.cumsum(np.flipud(trunc_true_ir**2), axis=-1))
    time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                       len(trunc_true_ir))

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
    ax.plot(np.zeros(len(amps_at_pos)), db(amps_at_pos, is_squared=True), 'kx')
    ax.set_title(
        f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
    )

    num_epochs = len(h_approx)
    for epoch in range(0, num_epochs, epoch_step):
        approx_ir = h_approx[epoch]
        trunc_approx_ir = approx_ir[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
        synth_edf = np.flipud(np.cumsum(np.flipud(trunc_approx_ir**2), axis=-1))
        # The synthesised response may be shorter than the reference window;
        # trim the time axis so both arrays always have the same length.
        ax.plot(time[:len(synth_edf)], db(synth_edf, is_squared=True), label=f'Epoch={epoch}')
        ax.legend()

        display.display(fig)  # Display the updated figure
        display.clear_output(wait=True)  # Clear the previous output to keep updates in place
        plt.pause(0.1)

    fig.savefig(Path(f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png').resolve())
    plt.show()

# Compare the reference EDF with the EDFs synthesised over the epochs at the
# investigated receiver position.
plot_edc(
    h_true=h_true,
    h_approx=h_approx_list,
    fs=config_dict.sample_rate,
    pos_to_investigate=rec_pos_to_investigate,
    amps_at_pos=amps_at_pos,
)


#### Plot the desired and final EDC

In [None]:
def get_edc_params(rir: ArrayLike, n_slopes: int, fs:float):
    """Estimate decay parameters of an impulse response and average them over axis 0.

    Calls `get_decay_fit_net_params(rir, None, n_slopes, fs)` and reduces each
    of its first three parameter sets, and the fitted EDC, with a mean over
    the leading axis (presumably the subband axis - confirm against the helper).

    Returns:
        Tuple `(est_t60, est_amp, est_noise, fitted_edc)`; the first three are
        NumPy means, `fitted_edc` is a squeezed torch tensor.
    """
    decay_params, _, edc_per_band = get_decay_fit_net_params(rir, None, n_slopes, fs)
    # entries 0, 1 and 2 are averaged in that order
    est_t60, est_amp, est_noise = (np.mean(decay_params[k], axis=0) for k in range(3))
    fitted_edc = torch.mean(edc_per_band, dim=0).squeeze()
    return est_t60, est_amp, est_noise, fitted_edc

def plot_final_decay_fit_net_edc(num_groups:int, fs: float, og_amps: List, est_amps: List,
                                 og_edc: ArrayLike, synth_edc:ArrayLike,
                                 src_pos:Optional[List]=None, rec_pos: Optional[List]=None):
    """Plot the original and the final synthesised EDC together with their amplitudes.

    Args:
        num_groups: number of amplitude markers drawn at time zero for each set.
        fs: sampling rate used to build the time axis.
        og_amps: amplitudes of the original response (black crosses).
        est_amps: amplitudes estimated from the synthesised response (green diamonds).
        og_edc: EDC of the original response; its length defines the time axis.
        synth_edc: EDC of the synthesised response, plotted over the first `len(synth_edc)` time points.
        src_pos: source position for the title (optional).
        rec_pos: receiver position for the title (optional).

    The title is only set when both positions are given.
    """
    num_samples = len(og_edc)
    time = np.linspace(0, (num_samples - 1) / fs, num_samples)
    # every amplitude marker sits at t = 0
    marker_times = np.zeros(num_groups)

    fig, ax = plt.subplots(figsize=(6, 4))

    # decay curves
    ax.plot(time, db(og_edc, is_squared=True), linestyle='-', label='Original estimated EDC (norm)')
    ax.plot(time[:len(synth_edc)], db(synth_edc, is_squared=True), linestyle='--', label='Final estimated EDC (norm)')

    # amplitudes
    ax.plot(marker_times, db(og_amps, is_squared=True), 'kx', label='Original amps')
    ax.plot(marker_times, db(est_amps, is_squared=True), 'gd', label='Synth amps')

    if not (src_pos is None or rec_pos is None):
        ax.set_title(f"Source pos = {np.round(src_pos, 2)}, receiver pos = {np.round(rec_pos, 2)}")

    ax.legend()
    plt.show()

# Discard the first 20 ms and the last 5 ms of the reference response.
mixing_time_samp = ms_to_samps(20.0, config_dict.sample_rate)
crop_end_samp = ms_to_samps(5.0, config_dict.sample_rate)
trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]

# Take the same window from the response of the last epoch. The end index has
# to be offset by the mixing time (as in `plot_edc`); otherwise the synthesised
# segment is `mixing_time_samp` samples shorter than the reference.
trunc_approx_ir = h_approx_list[-1][mixing_time_samp:mixing_time_samp + len(trunc_true_ir)]

# Fit decay parameters to both responses.
og_est_t60, og_est_amp, og_noise_floor, og_fitted_edc = get_edc_params(trunc_true_ir, config_dict.num_groups, config_dict.sample_rate)
est_t60, est_amp, _, fitted_edc = get_edc_params(trunc_approx_ir, config_dict.num_groups, config_dict.sample_rate)

plot_final_decay_fit_net_edc(config_dict.num_groups, config_dict.sample_rate, amps_at_pos, est_amp, og_fitted_edc, fitted_edc,
                             src_pos=src_pos_to_investigate, rec_pos=rec_pos_to_investigate)

### Get the overall amplitude and EDC mismatch as a function of spatial location

In [None]:
# `plot_amps_in_space` and `plot_edc_error_in_space` are already imported in the
# first cell, so the manual `reload` / re-import used during development is
# dropped (use `%autoreload 2` to pick up live edits of `diff_gfdn.plot`).

# Options shared by both spatial plots.
space_plot_kwargs = dict(freq_to_plot=None, scatter=True, pos_sorted=True)

# NOTE(review): the two calls save to different folders (`fig_path` itself vs
# `fig_path/multiple_sources`) - confirm this is intended.
plot_edc_error_in_space(room_data, all_rirs_mat, room_data.receiver_position,
                        save_path=f'{fig_path}/{config_name}_mlp', **space_plot_kwargs)

estimated_amps = plot_amps_in_space(room_data, all_rirs_mat, room_data.receiver_position,
                                    save_path=f'{fig_path}/multiple_sources/{config_name}_mlp', **space_plot_kwargs)