In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix, plot_subband_edc, plot_learned_svf_response, plot_amps_in_space
from diff_gfdn.analysis import get_decay_fit_net_params
from run_model import load_and_validate_config
# Change the working directory to the repository root (DiffGFDN), one level
# above this notebook. The guard keeps the cell idempotent: once the relative
# 'data/config' folder used by the cells below is reachable we are already at
# the root, so re-running this cell must not climb another level.
if not Path('data/config').is_dir():
    os.chdir('..')

In [None]:
# Relative locations of the YAML configs and of the figure output folder.
config_path, fig_path = 'data/config/', 'figures/'

# Configuration of the training run inspected by this notebook.
config_name = 'synth_data_subband_two_coupled_rooms_grid_training.yml'
config_file = f'{config_path}{config_name}'

# Load the YAML file into a DiffGFDNConfig object.
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Build the room dataset from the common-slopes RIR data referenced in the config.
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# The number of groups in the model follows the number of rooms in the dataset.
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# This notebook always runs on the CPU, whatever device the config asks for.
trainer_config = config_dict.trainer_config
trainer_config = (
    trainer_config
    if trainer_config.device == 'cpu'
    else trainer_config.copy(update={"device": 'cpu'})
)

# Prepare the data for DiffGFDN. A split ratio of 1.0 with shuffle=False is
# passed so that (presumably) every position ends up in train_dataset, in a
# fixed order - TODO confirm against load_dataset.
train_dataset, valid_dataset = load_dataset(
    room_data,
    trainer_config.device,
    train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size,
    shuffle=False,
)

# Instantiate the model; the trained weights are loaded per epoch further down.
model = DiffGFDNVarReceiverPos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    use_absorption_filters=True,
    common_decay_times=room_data.common_decay_times,
    band_centre_hz=room_data.band_centre_hz,
)

In [None]:
# Output / input locations used by the cells below.
audio_directory = Path("audio/")
fig_path = Path("figures").resolve()
# NOTE(review): plain string concatenation - this only yields the intended
# folder when trainer_config.train_dir ends with a path separator.
checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
max_epochs = trainer_config.max_epochs
plot_ir = True

# Receiver position (to 2 decimals) whose RIR is tracked across epochs.
pos_to_investigate = [2.41, 5.54, 1.10]  # [0.56, 4.27, 0.65]
desired_filename = (
    f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, '
    f'{pos_to_investigate[2]:.2f}).wav'
)

# Find the dataset entry for this receiver position. np.isclose is used
# instead of == so that the match survives float32 storage or a one-ulp
# difference introduced by np.round.
rec_pos_idx = np.where(
    np.all(
        np.isclose(np.round(room_data.receiver_position, 2), pos_to_investigate),
        axis=1,
    )
)[0]
if rec_pos_idx.size == 0:
    # Fail here with a clear message instead of carrying an empty h_true
    # into the plotting cells below.
    raise ValueError(
        f'No receiver position in room_data matches {pos_to_investigate}')

# Amplitudes and reference RIR at that position.
amplitudes = room_data.amplitudes[..., rec_pos_idx].T
h_true = np.squeeze(room_data.rirs[rec_pos_idx, :])
if room_data.num_rooms == 1:
    amplitudes = amplitudes[:, np.newaxis]

### Iterate through epochs

In [None]:
# Per-epoch history of the learned parameters and of the synthesised RIR at
# the receiver position under investigation.
h_approx_list = []
output_gains = []
input_gains = []
coupled_feedback_matrix = []
coupling_matrix = []
output_biquad_coeffs = []
svf_params = []
# Positions and RIRs of every receiver, collected at the final epoch only.
all_pos = []
all_rirs = []

# Optional entries of the parameter dict and the list that stores each one.
# 'random_feedback_matrix' shares the coupled_feedback_matrix history.
optional_param_history = (
    ('output_gains', output_gains),
    ('coupled_feedback_matrix', coupled_feedback_matrix),
    ('coupling_matrix', coupling_matrix),
    ('random_feedback_matrix', coupled_feedback_matrix),
)

for epoch in range(max_epochs):
    # load the trained weights for the particular epoch
    checkpoint = torch.load(f'{checkpoint_dir}/model_e{epoch}.pt',
                            weights_only=True,
                            map_location=torch.device('cpu'))
    # Load the trained model state
    model.load_state_dict(checkpoint)
    # switch to evaluation mode; gradients are disabled by torch.no_grad() below
    model.eval()

    with torch.no_grad():
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))

        for key, history in optional_param_history:
            if key in param_dict:
                history.append(deepcopy(param_dict[key]))

        for data in train_dataset:
            position = data['listener_position']
            H, h = get_response(data, model)

            for num_pos in range(position.shape[0]):
                if epoch == max_epochs - 1:
                    all_pos.append(position[num_pos])
                    all_rirs.append(h[num_pos, :])
                filename = (f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, '
                            f'{position[num_pos, 2]:.2f}).wav')

                if filename != desired_filename:
                    continue

                # get the ir at this position
                h_approx_list.append(h[num_pos, :])

                # get the output filter parameters for this position
                # NOTE(review): param_dict is read once per epoch, before the
                # batches are iterated, while num_pos is the index inside the
                # current batch - confirm this indexing is what is intended.
                try:
                    output_biquad_coeffs.append(deepcopy(param_dict['output_biquad_coeffs'][num_pos]))
                    svf_params.append(deepcopy(param_dict['output_svf_params'][num_pos]))
                except Exception as e:
                    # Any failure here is only logged, and the plots below are
                    # skipped for this position; the two lists may end up
                    # with different lengths.
                    logger.warning(e)
                    continue

                # plot the EDRs of the true and estimated
                plot_edr(h_true, model.sample_rate, title=f'True RIR EDR, epoch={epoch}',
                         save_path=f'{fig_path}/true_edr_{filename}_{config_name}_epoch={epoch}.png')

                plot_edr(h[num_pos, :], model.sample_rate, title=f'Estimated RIR EDR, epoch={epoch}',
                         save_path=f'{fig_path}/approx_edr_{filename}_{config_name}_epoch={epoch}.png')

                # Compare the two IRs on a figure of their own, so the curves
                # never land on axes left open by plot_edr.
                fig, ax = plt.subplots()
                ir_lines = ax.plot(torch.stack((h_true, h[num_pos, :len(h_true)]), dim=-1))
                ax.legend(ir_lines, ['True', 'Estimated'])
                ax.set_xlim([0, int(1.5 * model.sample_rate)])
                ax.set_xlabel('Time (samples)')
                ax.set_ylabel('Amplitude')
                ax.set_title(f'True vs estimated RIR, epoch={epoch}')
                fig.savefig(f'{fig_path}/ir_compare_{filename}_{config_name}_epoch={epoch}.png')
                plt.show()
                # release the figure so that memory does not grow with the epoch count
                plt.close(fig)

### Plot subband EDC as a function of epoch number

In [None]:
# Compare the reference RIR with the per-epoch estimates at the investigated
# receiver position, in the frequency bands given by room_data.band_centre_hz.
plot_subband_edc(
    h_true,
    h_approx_list,
    config_dict.sample_rate,
    room_data.band_centre_hz,
    pos_to_investigate,
    save_path=f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png',
)

### Investigate output SVFs as a function of epoch number

In [None]:
# plot_learned_svf_response is already imported from diff_gfdn.plot in the
# imports cell at the top of the notebook, so it is called directly here (the
# development-time importlib.reload of diff_gfdn.plot has been dropped).
# Plots the output filter parameters collected per epoch for the investigated
# receiver position.
plot_learned_svf_response(
    config_dict.num_groups,
    config_dict.sample_rate,
    output_biquad_coeffs,
    pos_to_investigate,
    svf_params=svf_params,
    save_path=f'{fig_path}/{config_name}',
)

### Plot the amplitude distribution of the RIRs as a function of position in space for the 1 kHz band

In [None]:
from diff_gfdn.plot import plot_amps_in_space

# Scatter plot built from the RIRs and receiver positions gathered at the
# final epoch of the loop above.
plot_amps_in_space(
    room_data,
    all_rirs,
    all_pos,
    scatter=True,
    save_path=f'{fig_path}/{config_name}',
)