In [None]:
import numpy as np
import torch
import re
import pyfar as pf
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy
# Multiply matplotlib's journal-style font sizes by `scale` for all figures in this notebook.
scale = 2
plt.rcParams.update({
    'font.size': scale * 8,  # base font size
    'axes.labelsize': scale * 9,  # x/y label
    'xtick.labelsize': scale * 8,
    'ytick.labelsize': scale * 8,
    'legend.fontsize': scale * 8,
    'axes.titlesize': scale * 10,  # usually unused in journal figures
})

# The `diff_gfdn` imports below and the relative data/figure paths expect the working
# directory to be the DiffGFDN root, assumed to be the parent of the notebook's directory.
# Only step up when the package is not already visible, so re-running this cell
# does not keep walking further up the directory tree.
if not Path('diff_gfdn').is_dir():
    os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, ThreeRoomDataset, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType, SubbandProcessingConfig
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, normalised_echo_density
from diff_gfdn.plot import animate_coupled_feedback_matrix, plot_edc_error_in_space
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.model import ColorlessFDN
from diff_gfdn.inference import InferDiffGFDN
from diff_gfdn.config.config_loader import load_and_validate_config

from slope2noise.utils import schroeder_backward_int

In [None]:
# All paths are relative to the working directory set in the import cell.
config_path = 'data/config/'
fig_path = Path("figures").resolve()
audio_directory = Path("audio/")
# Octave-band centre frequencies (Hz) to evaluate; full list kept for reference.
# freq_to_plot = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
freq_to_plot = [2000]
# Train/test split ratio; used to locate the trained run under output/train_split_test/.
split_ratio = 0.8
# True: inspect the hard-coded receiver position; False: pick a receiver at random.
use_fixed_pos = True
coupling_matrix = []

def diagonal_measure(matrix: NDArray) -> float:
    """Fraction of a matrix's energy that lies on its main diagonal.

    Computes ``sum(|diag(M)|**2) / sum(|M|**2)``. The result is 1 for a perfectly
    diagonal matrix and approaches 0 as energy moves off the diagonal.

    Args:
        matrix: 2-D real or complex array (anything ``np.asarray`` accepts).

    Returns:
        The diagonal energy ratio in [0, 1], or NaN for an all-zero matrix.
    """
    mat = np.asarray(matrix)
    # |x|**2 instead of x**2 so complex entries add their magnitudes
    # rather than cancelling each other out.
    energy = np.abs(mat) ** 2
    total_energy = np.sum(energy)
    if total_energy == 0:
        # ratio is undefined; return NaN explicitly instead of a 0/0 RuntimeWarning
        return float('nan')
    return float(np.sum(np.diag(energy)) / total_energy)

In [None]:
for k, freq in enumerate(freq_to_plot):
    # load and validate the DiffGFDN training config for this frequency band
    config_name = f'treble_data_grid_training_{freq}Hz_colorless_loss_diff_delays.yml'
    config_file = config_path + config_name
    config_dict = load_and_validate_config(config_file,
                                           DiffGFDNConfig)
    # set to True for both dataset types handled below
    dataset_has_cs_params = True
    if "3room_FDTD" in config_dict.room_dataset_path:
        room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)
    else:
        room_data = convert_common_slopes_rir_to_room_dataset(config_dict.room_dataset_path,
                                                              num_freq_bins=config_dict.trainer_config.num_freq_bins,
                                                              )

    config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})
    trainer_config = config_dict.trainer_config
    # force the trainer config device to be CPU
    if trainer_config.device != 'cpu':
        trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

    # point train_dir at the run stored under the train/test-split output tree
    folder_name = trainer_config.train_dir.rstrip("/").rsplit("/", 1)[-1]
    trainer_config = trainer_config.model_copy(update={"train_dir": f"output/train_split_test/{split_ratio:.1f}/{folder_name}/"})

    # only for testing GPU output
    # trainer_config = trainer_config.model_copy(update={"train_dir": f"output/train_split_test/fixed_test_set/new_results_latest/{split_ratio:.1f}/{folder_name}/"})
    # trainer_config = trainer_config.model_copy(update={"max_epochs": 20})

    # write it back into the parent model
    config_dict = config_dict.model_copy(update={"trainer_config": trainer_config})

    # get the colorless FDN params
    if config_dict.colorless_fdn_config.use_colorless_prototype:
        colorless_fdn_params = get_colorless_fdn_params(config_dict, colorless_dir=config_dict.colorless_fdn_config.saved_param_path)
    else:
        colorless_fdn_params = None

    # initialise the model
    model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                                   config_dict.num_groups,
                                   config_dict.delay_length_samps,
                                   trainer_config.device,
                                   config_dict.feedback_loop_config,
                                   config_dict.output_filter_config,
                                   config_dict.decay_filter_config.use_absorption_filters,
                                   common_decay_times=room_data.common_decay_times if config_dict.decay_filter_config.initialise_with_opt_values else None,
                                   learn_common_decay_times=config_dict.decay_filter_config.learn_common_decay_times,
                                   colorless_fdn_params=colorless_fdn_params,
                                   use_colorless_loss=trainer_config.use_colorless_loss
                                   )

    # get checkpoint and investigate single position
    checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
    # for GPU training
    max_epochs = trainer_config.max_epochs
    if use_fixed_pos:
        pos_to_investigate = [8.5, 3.5, 1.5]  # alternatives: [6.9, 8.7, 1.5], [0.8, 4.4, 1.5], [7.5, 10.2, 1.5]
    else:
        rec_idx = np.random.randint(0, high=room_data.num_rec, size=1, dtype=int)
        pos_to_investigate = np.round(np.squeeze(room_data.receiver_position[rec_idx, :]), 2)
    desired_filename = f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}).wav'

    # find the receiver matching the requested position; compare the rounded positions
    # with a tolerance instead of exact float equality, and fail with a clear message
    # instead of an opaque IndexError when the position is not in the dataset
    pos_match = np.all(np.isclose(np.round(room_data.receiver_position, 2), pos_to_investigate,
                                  rtol=0.0, atol=1e-6),
                       axis=1)
    if not np.any(pos_match):
        raise ValueError(f'No receiver found at position {pos_to_investigate} '
                         f'in dataset {config_dict.room_dataset_path}')
    rec_pos_idx = np.argwhere(pos_match)[0]  # first matching receiver, shape (1,)
    amps_at_pos = np.squeeze(room_data.amplitudes[rec_pos_idx, :])

    h_true = np.squeeze(room_data.rirs[rec_pos_idx, :])

    ### Get inferred output
    infer_gfdn = InferDiffGFDN(room_data, config_dict, model, use_direct_cs_params=False)
    # epoch -1 presumably denotes the untrained/initial model — TODO confirm in InferDiffGFDN
    epoch_list = [-1, max_epochs - 1]
    h_approx_list, all_pos, all_rirs, all_output_scalars, all_learned_params = infer_gfdn.get_model_output(epoch_list,
                                                                                                           desired_filename,
                                                                                                           False,
                                                                                                           h_true,
                                                                                                           )
    if not config_dict.feedback_loop_config.use_zero_coupling:
        logger.info(f'Diagonality measure of initial coupling matrix = {diagonal_measure(all_learned_params.coupling_matrix[0]):.3f}')
        logger.info(f'Diagonality measure of final coupling matrix = {diagonal_measure(all_learned_params.coupling_matrix[-1]):.3f}')
        # make sure the target directory exists before the animation is written
        Path(f'{fig_path}/animation').mkdir(parents=True, exist_ok=True)
        save_path = f'{fig_path}/animation/{config_name}_learned_feedback_matrix.gif'
        animate_coupled_feedback_matrix(all_learned_params.coupled_feedback_matrix, all_learned_params.coupling_matrix, save_path)

    logger.debug(f"Plotting EDC error for frequency: {freq}Hz for split ratio = {split_ratio:.1f}")
    Path(fig_path).mkdir(parents=True, exist_ok=True)
    plot_edc_error_in_space(room_data, all_rirs[-1], all_pos[-1], freq_to_plot=None, scatter=True,
                            save_path=f'{fig_path}/{config_name}_split={split_ratio:.1f}_test_set_only',
                            norm_edc=False,
                            use_amp_preserving_filterbank=trainer_config.subband_process_config.use_amp_preserving_filterbank)
    plt.show()

In [None]:
# Compare the true and synthesized energy decay curves after the mixing time
# for the receiver position inspected in the previous cell.
fs = room_data.sample_rate
mixing_time_samp = ms_to_samps(20.0, fs)  # samples before the 20 ms mixing time are skipped
crop_end_samp = ms_to_samps(5.0, fs)  # samples dropped from the end of each IR

fig, ax = plt.subplots(figsize=(6, 4))
# explicit end index: a `[a:-0]` slice would be empty if crop_end_samp were 0
trunc_true_ir = h_true[mixing_time_samp:len(h_true) - crop_end_samp]
time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                   len(trunc_true_ir))

true_edf = schroeder_backward_int(trunc_true_ir, normalize=False, discard_last_zeros=False)
ax.plot(time, db(true_edf, is_squared=True), label='True EDC', linestyle='-')
ax.set_title(
    f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
)

# last entry of h_approx_list corresponds to the final epoch requested in the loop cell
trunc_approx_ir_subband_gfdn = h_approx_list[-1].cpu().detach().numpy()
synth_edf_subband = schroeder_backward_int(
    trunc_approx_ir_subband_gfdn[mixing_time_samp:len(trunc_approx_ir_subband_gfdn) - crop_end_samp],
    normalize=False, discard_last_zeros=False)
# plot only the overlapping span so a shorter synthesized IR does not break the plot
n_common = min(len(time), len(true_edf), len(synth_edf_subband))
ax.plot(time[:n_common], db(synth_edf_subband[:n_common], is_squared=True),
        label='Subband DiffGFDN', linestyle='--')
ax.set_xlabel('Time after mixing time (s)')
ax.set_ylabel('EDC (dB)')
ax.legend()
plt.show()