In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

# Move to the DiffGFDN repository root so that the project imports and the relative data/figure
# paths used below resolve. Guarded so that re-running this cell does not keep walking up the tree.
if not Path('diff_gfdn').is_dir():
    os.chdir('..')

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix, plot_amps_in_space
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from src.run_model import load_and_validate_config

In [None]:
# Experiment configuration (paths are relative to the repository root)
config_path = 'data/config/'
fig_path = 'figures/'
config_name = 'synth_data_broadband_two_coupled_rooms_grid_training_colorless_loss.yml'

config_file = f'{config_path}{config_name}'
# load the YAML file and validate it as a DiffGFDNConfig
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Build the room dataset (RIRs, receiver positions, amplitudes, common decay times) from the config
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# one group per room in the dataset
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# Force the trainer config device to be CPU. The override is also written back into config_dict,
# so everything that reads config_dict (e.g. get_colorless_fdn_params) agrees with the device
# used for the data loaders and the model below.
trainer_config = config_dict.trainer_config
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.copy(update={"device": 'cpu'})
    config_dict = config_dict.copy(update={"trainer_config": trainer_config})

# prepare the data for DiffGFDN; train_valid_split_ratio=1.0 is intended to put every receiver
# position in train_dataset (NOTE(review): confirm in load_dataset). shuffle=False keeps a fixed order.
train_dataset, valid_dataset = load_dataset(
    room_data,
    trainer_config.device,
    train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size,
    shuffle=False,
)

# get the colorless FDN params only when the config asks for the colorless prototype
if config_dict.colorless_fdn_config.use_colorless_prototype:
    colorless_fdn_params = get_colorless_fdn_params(config_dict)
else:
    colorless_fdn_params = None

# initialise the model (absorption filters disabled; the dataset's common decay times are passed in)
model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                               config_dict.num_groups,
                               config_dict.delay_length_samps,
                               trainer_config.device,
                               config_dict.feedback_loop_config,
                               config_dict.output_filter_config,
                               use_absorption_filters=False,
                               common_decay_times=room_data.common_decay_times,
                               colorless_fdn_params=colorless_fdn_params,
                               use_colorless_loss=trainer_config.use_colorless_loss
                               )

In [None]:
audio_directory = Path("audio/")
fig_path = Path("figures").resolve()
# join with `/` so the path is correct whether or not train_dir ends with a path separator
checkpoint_dir = (Path(trainer_config.train_dir) / 'checkpoints').resolve()
max_epochs = trainer_config.max_epochs
plot_ir = True
# receiver position (m) whose RIR / EDF is tracked across epochs
pos_to_investigate = [1.21, 2.92, 0.83]
# file name that the epoch loop builds per position; used there to pick out pos_to_investigate
desired_filename = f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}).wav'

# find the dataset entry whose receiver position, rounded to 2 decimals, equals pos_to_investigate
pos_matches = np.flatnonzero(
    np.all(np.round(room_data.receiver_position, 2) == pos_to_investigate, axis=1))
if pos_matches.size == 0:
    raise ValueError(f'Receiver position {pos_to_investigate} was not found in the room dataset')
# length-1 index array (first match), so the selections below keep a leading axis until squeezed
rec_pos_idx = pos_matches[:1]
amps_at_pos = np.squeeze(room_data.amplitudes[rec_pos_idx, :])
h_true = np.squeeze(room_data.rirs[rec_pos_idx, :])

### Iterate through the saved checkpoints of every epoch

In [None]:
# Per-epoch snapshots of the learnt parameters and of the synthesised RIR at pos_to_investigate
h_approx_list = []
output_gains = []
input_gains = []
input_scalars = []
output_scalars = []
coupled_feedback_matrix = []
coupling_matrix = []
# receiver positions and RIRs for every position in train_dataset, collected at the final epoch only
all_pos = []
all_rirs = []

for epoch in range(max_epochs):
    # load the trained weights for the particular epoch onto the CPU
    checkpoint = torch.load(checkpoint_dir / f'model_e{epoch}.pt',
                            weights_only=True,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    # eval() switches off training-only behaviour; gradients are disabled by torch.no_grad() below
    model.eval()
    is_last_epoch = epoch == max_epochs - 1

    with torch.no_grad():
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))
        input_scalars.append(deepcopy(param_dict['input_scalars']))
        output_gains.append(deepcopy(param_dict['output_gains']))
        coupled_feedback_matrix.append(deepcopy(param_dict['coupled_feedback_matrix']))
        coupling_matrix.append(deepcopy(param_dict['coupling_matrix']))

        for data in train_dataset:
            position = data['listener_position']
            # get_response returns 3 values with the colorless loss and 2 otherwise;
            # only the time-domain responses h (last value) are used here
            if trainer_config.use_colorless_loss:
                _, _, h = get_response(data, model)
            else:
                _, h = get_response(data, model)

            for num_pos in range(position.shape[0]):
                if is_last_epoch:
                    all_pos.append(position[num_pos])
                    all_rirs.append(h[num_pos, :])

                filename = f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, {position[num_pos, 2]:.2f}).wav'
                if filename != desired_filename:
                    continue

                # get the ir at this position (the loop continues so all RIRs are still collected)
                h_approx_list.append(h[num_pos, :])

                # Re-query the parameters after this batch's forward pass: output_scalars is indexed
                # by the position within the current batch, so it must not come from the snapshot
                # taken before the loop (which reflects the previous epoch's last batch, if any).
                # NOTE(review): assumes the model refreshes 'output_scalars' in its forward pass — confirm.
                batch_param_dict = model.get_param_dict()
                if 'output_scalars' in batch_param_dict:
                    output_scalars.append(deepcopy(batch_param_dict['output_scalars'][num_pos]))

#### Plot the EDCs as a function of epoch number

In [None]:
def plot_edc(h_true: ArrayLike,
             h_approx: List[ArrayLike],
             fs: float,
             pos_to_investigate: List,
             amps_at_pos: List,
             mixing_time_ms: float = 20.0,
             save_path: Optional[Path] = None):
    """Plot the true EDF against the synthesised EDF of every epoch at one receiver position.

    Both RIRs are truncated to start at the mixing time. The last 5 ms of the true RIR are dropped,
    and each synthesised RIR is cut to the same length. The figure is re-displayed after each
    epoch, so the notebook output shows the EDF evolving over training.

    Args:
        h_true: reference RIR at the receiver position.
        h_approx: synthesised RIR at the same position, one entry per epoch.
        fs: sampling rate, used for the ms-to-sample conversion and the time axis.
        pos_to_investigate: receiver position (x, y, z), used in the title and the default file name.
        amps_at_pos: amplitudes at the position, drawn as markers at t = 0.
        mixing_time_ms: start of the truncation window, in ms.
        save_path: where to save the final figure. Defaults to
            `{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png` (notebook globals).
    """
    mixing_time_samp = ms_to_samps(mixing_time_ms, fs)
    crop_end_samp = ms_to_samps(5.0, fs)

    # explicit stop index: the slice `-crop_end_samp` would be empty if crop_end_samp were 0
    trunc_true_ir = h_true[mixing_time_samp:len(h_true) - crop_end_samp]
    num_samps = len(trunc_true_ir)
    true_edf = np.flipud(np.cumsum(np.flipud(trunc_true_ir**2), axis=-1))
    time = np.linspace(0, (num_samps - 1) / fs, num_samps)

    fig, ax = plt.subplots(figsize=(8, 6))
    # label the axes up front so every intermediate frame displayed below carries them
    ax.set_title(
        f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
    )
    ax.set_xlabel('Time from mixing time (s)')
    ax.set_ylabel('EDF (dB)')
    ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
    ax.plot(np.zeros(len(amps_at_pos)), db(amps_at_pos, is_squared=True), 'kx')

    for epoch, approx_ir in enumerate(h_approx):
        # align the synthesised RIR with the truncated true RIR
        trunc_approx_ir = approx_ir[mixing_time_samp:mixing_time_samp + num_samps]
        synth_edf = np.flipud(np.cumsum(np.flipud(trunc_approx_ir**2), axis=-1))
        # slice the time axis in case the synthesised RIR is shorter than the true one
        ax.plot(time[:len(synth_edf)], db(synth_edf, is_squared=True), label=f'Epoch={epoch}')
        ax.legend()

        display.display(fig)  # Display the updated figure
        display.clear_output(wait=True)  # Clear the previous output to keep updates in place
        plt.pause(0.1)

    if save_path is None:
        save_path = Path(f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png')
    fig.savefig(Path(save_path).resolve())
    plt.show()

plot_edc(h_true, h_approx_list, config_dict.sample_rate, pos_to_investigate, amps_at_pos)

#### Plot the learnt feedback matrix and the input scalars for the grid, and the output scalars for a single position

In [None]:
# animate_coupled_feedback_matrix is already imported in the first cell
is_scalar_coupling = (config_dict.feedback_loop_config.coupling_matrix_type
                      == CouplingMatrixType.SCALAR)
coupling_tag = 'scalar' if is_scalar_coupling else 'random'

# animate the evolution of the matrices over the epochs; make sure the output folder exists
animation_dir = Path(fig_path) / 'animation'
animation_dir.mkdir(parents=True, exist_ok=True)
animation_path = (
    animation_dir / f'{config_name}_{pos_to_investigate}_{coupling_tag}_coupling_matrix.gif'
).resolve()

if is_scalar_coupling:
    animate_coupled_feedback_matrix(np.abs(coupled_feedback_matrix),
                                    np.abs(coupling_matrix),
                                    save_path=animation_path)

    # also save the final optimised matrices; the coupled feedback matrix must pass is_unitary
    assert is_unitary(torch.from_numpy(coupled_feedback_matrix[-1]))[0]
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 6))

    # Coupling matrix subplot
    cax1 = ax1.matshow(np.abs(coupling_matrix[-1]))
    fig.colorbar(cax1, ax=ax1)
    ax1.set_title('Coupling matrix')

    # Coupled feedback matrix subplot
    cax2 = ax2.matshow(np.abs(coupled_feedback_matrix[-1]))
    fig.colorbar(cax2, ax=ax2)
    ax2.set_title('Coupled feedback matrix')

    fig.tight_layout()
    fig.savefig(f'{fig_path}/{config_name}_scalar_coupling_matrix.png', bbox_inches='tight')
else:
    animate_coupled_feedback_matrix(np.abs(coupled_feedback_matrix), save_path=animation_path)

    # also save the final optimised matrix; it must pass is_unitary with max_tol=1e-4
    unit_flag, max_val = is_unitary(torch.tensor(coupled_feedback_matrix[-1]), max_tol=1e-4)
    assert unit_flag, f'Optimised feedback matrix failed the unitarity check (is_unitary returned {max_val})'

    # draw on explicit axes: plt.matshow() always opens its own figure, so pairing it with
    # plt.figure() would leave an extra empty figure behind
    fig, ax = plt.subplots()
    cax = ax.matshow(np.abs(coupled_feedback_matrix[-1]))
    fig.colorbar(cax, ax=ax)
    ax.set_title('Optimised feedback matrix')
    fig.savefig(f'{fig_path}/{config_name}_random_coupling_matrix.png')
    

In [None]:
# Stack the per-epoch vectors into 2D matrices and transpose: rows = group, columns = epoch
input_gain_matrix = np.stack(input_scalars).T
if not output_scalars:
    raise ValueError("No output scalars were collected: 'output_scalars' was never found "
                     "in the model's parameter dict for the investigated position.")
output_gain_matrix = np.stack(output_scalars).T

# Plot the magnitudes; they are on a linear scale, so the colour bars are not labelled in dB
fig, ax = plt.subplots(2, figsize=(6, 6))
panels = ((ax[0], input_gain_matrix, "Input scalars vs epoch"),
          (ax[1], output_gain_matrix, "Output scalars vs epoch"))
for panel_ax, gain_matrix, title in panels:
    img = panel_ax.matshow(np.abs(gain_matrix), aspect='auto', cmap='viridis')
    fig.colorbar(img, label="Magnitude (linear)", ax=panel_ax)
    panel_ax.set_ylabel("Group number")
    panel_ax.set_xlabel("Epoch number")
    panel_ax.set_title(title)
fig.subplots_adjust(hspace=0.5)

# save before showing, since plt.show() may close the figure depending on the backend
fig.savefig(Path(f'{fig_path}/{config_name}_{pos_to_investigate}_io_scalars.png').resolve())
plt.show()

### Plot the amplitude distribution of each RIR as a function of position in space

In [None]:
# plot_amps_in_space is already imported in the first cell. While editing diff_gfdn/plot.py during
# development, prefer `%autoreload 2` over manually reloading the module in this cell.
if not all_rirs:
    raise ValueError('all_rirs is empty: the epoch loop collects RIRs only at the final epoch, '
                     'so at least one checkpoint must have been processed.')

plot_amps_in_space(room_data, all_rirs, all_pos, freq_to_plot=None, scatter=True, save_path=f'{fig_path}/{config_name}')