In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
import soundfile as sf
import pickle
from torch import nn
from scipy.fft import rfft, rfftfreq
from numpy.typing import ArrayLike
from typing import Optional, List
from copy import deepcopy
from IPython import display
from loguru import logger

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNSinglePos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from run_model import load_and_validate_config

# The relative paths used below ('data/config/', 'figures/') are resolved against the
# working directory, so move one level up first
# (assumes the notebook is started from a direct sub-folder of the repository - TODO confirm).
# NOTE: the path is relative to the current working directory, so re-running this cell
# without restarting the kernel moves one more level up each time.
os.chdir('..')  # This changes the working directory to DiffGFDN

In [None]:
# folders (relative to the working directory) holding the configs and the saved figures
config_path = 'data/config/'
fig_path = 'figures/'
# configuration of the training run analysed in this notebook
config_name = 'single_rir_fit_broadband_two_stage_decay_colorless_prototype_pos2.yml'
config_file = f'{config_path}{config_name}'
# read the configuration file, passing DiffGFDNConfig as the expected config class
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)


In [None]:
# build the room dataset from the common-slopes data referenced in the config
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# one group per room found in the dataset
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# prepare the training and validation data for DiffGFDN:
# the batch size has to equal the number of frequency bins of the dataset
trainer_config = config_dict.trainer_config
if trainer_config.batch_size != room_data.num_freq_bins:
    trainer_config = trainer_config.copy(update={"batch_size": room_data.num_freq_bins})

# get the colorless FDN params (only when a colorless prototype is requested)
colorless_fdn_params = (
    get_colorless_fdn_params(config_dict)
    if config_dict.colorless_fdn_config.use_colorless_prototype
    else None
)

# initialise the model
model = DiffGFDNSinglePos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    use_absorption_filters=False,
    common_decay_times=room_data.common_decay_times,
    colorless_fdn_params=colorless_fdn_params,
)

### Plot output data and compare RIR

In [None]:
# read the reference RIR and work out the name of the matching synthesised RIR
true_ir, fs = sf.read(Path(config_dict.ir_path).resolve())
match = re.search(r'ir_\([^)]+\)', config_dict.ir_path)
if match is None:
    # without this check the failure would surface as an opaque AttributeError on match.group()
    raise ValueError(f"No 'ir_(...)' tag found in the RIR path {config_dict.ir_path}")
dir_name = Path(trainer_config.ir_dir).parts[-1]
ir_name = match.group()
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
# NOTE: fs is overwritten with the sample rate of the synthesised file
final_approx_ir, fs = sf.read(Path(approx_ir_path).resolve())

# overlay both RIRs; trim to the common length so that stacking cannot fail
# when the synthesised RIR is shorter than the reference
num_compare_samps = min(len(true_ir), len(final_approx_ir))
plt.plot(np.stack((true_ir[:num_compare_samps], final_approx_ir[:num_compare_samps, 0]), axis=-1))
plt.xlim([0, int(1.5 * config_dict.sample_rate)])
plt.savefig(f'{fig_path}/compare_{dir_name}_{ir_name}.png')
plt.show()

# find receiver position from string
match = re.search(r'ir_\(([^,]+), ([^,]+), ([^,]+)\)', ir_name)
if match is None:
    raise ValueError(f"Could not parse a receiver position 'ir_(x, y, z)' from {ir_name}")
# Convert the extracted values to floats
x, y, z = map(float, match.groups())
rec_pos = np.array([x, y, z])

# find amplitudes corresponding to the receiver position
# (positions in the dataset are rounded to 2 decimals before the exact comparison)
matching_rows = np.argwhere(
    np.all(np.round(room_data.receiver_position, 2) == rec_pos, axis=1))
if len(matching_rows) == 0:
    raise ValueError(f'Receiver position {rec_pos} is not part of the room dataset')
rec_pos_idx = matching_rows[0]
amplitudes = np.squeeze(room_data.amplitudes[rec_pos_idx, :])

# plot EDR of the reference and of the synthesised RIR
plot_edr(torch.tensor(true_ir), config_dict.sample_rate,
         title='True RIR EDR', save_path=f'{fig_path}/{dir_name}_true_edr_{ir_name}.png')
plot_edr(torch.tensor(final_approx_ir[:, 0]), config_dict.sample_rate,
         title='Estimated RIR EDR', save_path=f'{fig_path}/{dir_name}_approx_edr_{ir_name}.png')


### Plot EDF and estimated amplitudes for each epoch

In [None]:
from slope2noise.slope2noise.utils import decay_kernel

def get_edc_params(rir: ArrayLike, n_slopes: int, fs:float):
    """
    Fit decay parameters to an RIR with get_decay_fit_net_params and average them over axis 0.

    Args:
        rir: impulse response to analyse
        n_slopes: number of decay slopes passed on to the fit
        fs: sampling rate passed on to the fit
    Returns:
        tuple of (T60s, amplitudes, noise term, fitted EDC), each averaged over the first axis
        of what the fit returns (presumably the subband axis - TODO confirm)
    """
    # the second returned item (normalisation values) is not needed here
    decay_params, _, subband_edc = get_decay_fit_net_params(rir, None, n_slopes, fs)
    # entries 0, 1, 2 of the fitted parameters: T60s, amplitudes, noise term
    t60_mean, amp_mean, noise_mean = (np.mean(decay_params[idx], axis=0) for idx in range(3))
    mean_edc = torch.squeeze(torch.mean(subband_edc, dim=0))
    return t60_mean, amp_mean, noise_mean, mean_edc

def get_custom_edc(t60_vals:List, amp_vals:List, noise_val: float, num_groups:int, ir_len_samps: int, fs: float):
    """
    Build a decay curve from decay times, amplitudes and a noise term.

    Each of the num_groups decay kernels is shifted so that it ends at zero, weighted by its
    amplitude and summed; a linearly decreasing noise term is then added.

    Args:
        t60_vals: decay times, one per group
        amp_vals: amplitudes, one per group
        noise_val: scale of the linearly decreasing noise term
        num_groups: number of decay kernels to sum
        ir_len_samps: length of the returned curve in samples
        fs: sampling rate in Hz
    Returns:
        array of length ir_len_samps
        NOTE(review): the caller plots this curve directly as an EDF, so it is presumably
        already an energy decay model and needs no further backward integration - confirm.
    """
    time = np.linspace(0, (ir_len_samps - 1) / fs, ir_len_samps)
    # plain sequences do not support the multi-axis indexing below, arrays are left untouched
    if isinstance(t60_vals, (list, tuple)):
        t60_vals = np.asarray(t60_vals)
    t60_expanded = t60_vals[:, np.newaxis, np.newaxis]
    decay_envelope = np.zeros_like(time)
    for k in range(num_groups):
        kernel = np.squeeze(decay_kernel(t60_expanded[k, ...], time, fs, normalise_envelope=False, add_noise=False))
        # account for limited upper limit of integration
        exp_offset = kernel[-1]
        decay_envelope += amp_vals[k] * (kernel - exp_offset)

    # noise contribution, falling linearly from ir_len_samps * noise_val to noise_val
    decay_envelope += noise_val * np.linspace(ir_len_samps, 1, ir_len_samps)

    # the unused backward integration of decay_envelope**2 (never returned) has been removed
    return decay_envelope

In [None]:
# folder holding one checkpoint with the model parameters per epoch
checkpoint_dir = Path(f'{trainer_config.train_dir}checkpoints/').resolve()

# create RIRDataset
rir_data = RIRData(
    config_dict.ir_path,
    common_decay_times=room_data.common_decay_times,
    band_centre_hz=room_data.band_centre_hz,
    amplitudes=amplitudes,
    nfft=config_dict.trainer_config.num_freq_bins,
)

# prepare the training and validation data for DiffGFDN
train_dataset = load_dataset(rir_data,
                             trainer_config.device,
                             train_valid_split_ratio=1.0,
                             batch_size=trainer_config.batch_size,
                             shuffle=False)

# drop the first 20 ms and the last 5 ms of the reference RIR
mixing_time_samp = ms_to_samps(20.0, fs)
crop_end_samp = ms_to_samps(5.0, fs)
trunc_true_ir = true_ir[mixing_time_samp:-crop_end_samp]
# backward-integrated energy of the truncated reference RIR
true_edf = np.flipud(np.cumsum(np.flipud(trunc_true_ir**2), axis=-1))
time = np.linspace(0, (len(trunc_true_ir) - 1) / fs, len(trunc_true_ir))
# decay parameters fitted to the reference RIR, and the decay curve rebuilt from them
og_est_t60, og_est_amp, og_noise_floor, og_fitted_edc = get_edc_params(
    trunc_true_ir, config_dict.num_groups, config_dict.sample_rate)
custom_edc = get_custom_edc(
    og_est_t60, og_est_amp, og_noise_floor,
    config_dict.num_groups, len(trunc_true_ir), config_dict.sample_rate)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
ax.plot(time[:len(og_fitted_edc)], db(og_fitted_edc, is_squared=True), label='DecayFitNet EDF')
ax.plot(time, db(custom_edc, is_squared=True), linestyle='--', label='Custom EDF')
# amplitude markers are drawn at time zero
ax.plot(np.zeros(room_data.num_rooms), db(amplitudes, is_squared=True), 'kx', label='Original amplitudes')
ax.plot(np.zeros(room_data.num_rooms), db(og_est_amp, is_squared=True), 'ko', label='Fitted amplitudes')
ax.legend()

In [None]:
max_epochs = trainer_config.max_epochs
# decay parameters fitted to the synthesised RIR, one column per epoch
# (columns of epochs without a checkpoint stay zero)
est_t60 = np.zeros((config_dict.num_groups, max_epochs))
est_amp = np.zeros_like(est_t60)
# model parameters and synthesised RIRs collected per epoch
coupled_feedback_matrix = []
coupling_matrix = []
input_gains = []
output_gains = []
input_scalars = []
output_scalars = []
h_approx_list = []

for epoch in range(max_epochs):
    # load the trained weights for the particular epoch
    checkpoint_path = f'{checkpoint_dir}/model_e{epoch}.pt'
    try:
        checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=torch.device('cpu'))
    except FileNotFoundError:
        # only a missing checkpoint ends the loop; any other error is allowed to propagate.
        # Epochs 0 .. epoch - 1 have been loaded, i.e. `epoch` epochs in total.
        logger.warning(f'No checkpoint found at {checkpoint_path}, stopping after {epoch} epochs')
        max_epochs = epoch
        break

    # Load the trained model state
    model.load_state_dict(checkpoint)
    model.eval()

    # no gradients are needed for the analysis
    with torch.no_grad():
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))
        output_gains.append(deepcopy(param_dict['output_gains']))

        # the remaining parameters are only stored when the model provides them
        if 'coupling_matrix' in param_dict:
            coupling_matrix.append(deepcopy(param_dict['coupling_matrix']))

        if 'coupled_feedback_matrix' in param_dict:
            coupled_feedback_matrix.append(deepcopy(param_dict['coupled_feedback_matrix']))

        if 'input_scalars' in param_dict:
            input_scalars.append(deepcopy(param_dict['input_scalars']))
            output_scalars.append(deepcopy(param_dict['output_scalars']))

        # the response of the last batch is the one kept for this epoch
        for data in train_dataset:
            H, approx_ir = get_response(data, model)

    # same truncation as for the reference RIR before fitting the decay parameters
    trunc_approx_ir = approx_ir[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
    h_approx_list.append(approx_ir)
    est_t60[:, epoch], est_amp[:, epoch], _, fitted_edc = get_edc_params(
        trunc_approx_ir, config_dict.num_groups, config_dict.sample_rate)
    

In [None]:
def plot_edc(h_true: ArrayLike, h_approx: List[ArrayLike], fs: float, pos_to_investigate: List, mixing_time_ms:float=20.0):
    """
    Plot true and synthesised EDC curves

    The EDF of the truncated reference RIR is drawn first, then the EDF of every fourth
    synthesised RIR is added, re-displaying the figure after each one. The final figure is
    saved to a file whose name is built from the notebook-level fig_path, config_dict and
    trainer_config.

    Args:
        h_true: reference RIR
        h_approx: synthesised RIRs, one per epoch
        fs: sampling rate in Hz
        pos_to_investigate: receiver position (x, y, z), used in the title and the file name
        mixing_time_ms: duration removed from the start of the RIRs, in ms
    """

    mixing_time_samp = ms_to_samps(mixing_time_ms, fs)
    # the last 5 ms of the reference RIR are dropped as well
    crop_end_samp = ms_to_samps(5.0, fs)

    trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
    true_edf = np.flipud(np.cumsum(np.flipud(trunc_true_ir**2), axis=-1))
    time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                       len(trunc_true_ir))

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(time, db(true_edf, is_squared=True), label='True EDF')

    num_epochs = len(h_approx)
    ax.set_title(
        f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
    )
    for epoch in range(0, num_epochs, 4):
        approx_ir = h_approx[epoch]
        trunc_approx_ir = approx_ir[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
        synth_edf = np.flipud(np.cumsum(np.flipud(trunc_approx_ir**2), axis=-1))
        # a synthesised RIR shorter than the reference yields a shorter EDF,
        # so the time axis is trimmed to match
        ax.plot(time[:len(synth_edf)], db(synth_edf, is_squared=True), label=f'Epoch={epoch}')
        ax.legend()

        display.display(fig)  # Display the updated figure
        display.clear_output(wait=True)  # Clear the previous output to keep updates in place
        plt.pause(0.1)

    # the file name uses the position passed in rather than the notebook-level rec_pos
    fig.savefig(Path(f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_dict.feedback_loop_config.coupling_matrix_type.value}_random_init_lr={trainer_config.io_lr}.png').resolve())
    plt.show()
    
# EDF of the reference RIR against the RIRs synthesised during training,
# at the receiver position parsed from the RIR file name
plot_edc(true_ir, h_approx_list, config_dict.sample_rate, rec_pos)

In [None]:
def plot_final_decay_fit_net_edc(num_groups:int, og_amps: List, est_amps: List, og_edc: ArrayLike, synth_edc:ArrayLike):
    """
    Plot the EDC fitted to the reference RIR against the EDC fitted to the synthesised RIR.

    The amplitudes of both fits are drawn as markers at time zero. The time axis is built
    from the length of og_edc and the notebook-level sampling rate fs.

    Args:
        num_groups: number of amplitude markers per fit
        og_amps: amplitudes belonging to the reference RIR
        est_amps: amplitudes fitted to the synthesised RIR
        og_edc: EDC fitted to the reference RIR
        synth_edc: EDC fitted to the synthesised RIR
    """

    time = np.linspace(0, (len(og_edc) - 1) / fs, len(og_edc))
    fig, ax = plt.subplots(figsize=(6, 4))
    # both curves are trimmed to the common length; the previous slicing only produced
    # matching lengths when synth_edc was exactly one sample longer than og_edc
    num_synth_samps = min(len(time), len(synth_edc))

    ax.plot(time, db(og_edc, is_squared=True), linestyle='-', label='Original estimated EDC (norm)')
    ax.plot(time[:num_synth_samps], db(synth_edc[:num_synth_samps], is_squared=True), linestyle='--',
            label='Final estimated EDC (norm)')
    ax.plot(np.zeros(num_groups), db(og_amps, is_squared=True), 'kx', label='Original amps')
    ax.plot(np.zeros(num_groups), db(est_amps, is_squared=True), 'gd', label='Synth amps')

    ax.legend()
    plt.show()

# compare the fit of the reference RIR with the fit of the last synthesised RIR:
# est_amp is indexed with the last epoch that was loaded, fitted_edc is the value
# left over from the final iteration of the epoch loop
plot_final_decay_fit_net_edc(config_dict.num_groups, amplitudes, est_amp[:, len(h_approx_list)-1], 
                             og_fitted_edc, fitted_edc)

#### Investigate the T60s as a function of epoch

In [None]:
# reference values: decay parameters fitted to the truncated true RIR
logger.info(f'Original params: T60 = {np.round(og_est_t60, 3)} s, amp = {np.round(db(og_est_amp, is_squared=True), 3)} dB')
# decay parameters fitted to the synthesised RIR, one log line per epoch
for epoch in range(max_epochs):
    epoch_t60 = np.round(est_t60[:, epoch], 3)
    epoch_amp_db = np.round(db(est_amp[:, epoch], is_squared=True), 3)
    logger.info(f'Est params for epoch {epoch}: T60 = {epoch_t60}s, amp = {epoch_amp_db} dB')

#### Investigate feedback matrices and input-output gains as a function of epoch number

In [None]:
from importlib import reload
import diff_gfdn
# re-import the plotting module so that edits to it are picked up without a kernel restart
reload(diff_gfdn.plot)
from diff_gfdn.plot import animate_coupled_feedback_matrix

# animate the magnitude of the coupled feedback matrix over the epochs;
# the coupling matrix is passed as well when the coupling type is scalar
if config_dict.feedback_loop_config.coupling_matrix_type != CouplingMatrixType.SCALAR:
    animate_coupled_feedback_matrix(
        np.abs(coupled_feedback_matrix),
        save_path=Path(f'{fig_path}/animation/{config_name}_{rec_pos}_random_coupling_matrix.gif').resolve())
else:
    animate_coupled_feedback_matrix(
        np.abs(coupled_feedback_matrix),
        np.abs(coupling_matrix),
        save_path=Path(f'{fig_path}/animation/{config_name}_{rec_pos}_scalar_coupling_matrix_lr={trainer_config.coupling_angle_lr}.gif').resolve())
       

In [None]:
# Stack the per-epoch scalar vectors and transpose, so that the epochs run along the
# horizontal axis of the plots (presumably one row per group - TODO confirm)
input_gain_matrix = np.stack(input_scalars).T
output_gain_matrix = np.stack(output_scalars).T

# Plot the magnitude of the scalars as two heat-maps
fig, ax = plt.subplots(2, figsize=(6,6))
# np.abs(...) with vmin=0 is a linear magnitude, so the colourbars are labelled accordingly
in_plot = ax[0].matshow(np.abs(input_gain_matrix), aspect='auto', cmap='viridis', vmin=0)
fig.colorbar(in_plot, label="Magnitude (linear)", ax=ax[0])
ax[0].set_ylabel("Group number")
ax[0].set_xlabel("Epoch number")
ax[0].set_title("Input scalars vs epoch")

out_plot = ax[1].matshow(np.abs(output_gain_matrix), aspect='auto', cmap='viridis', vmin=0)
fig.colorbar(out_plot, label="Magnitude (linear)", ax=ax[1])
ax[1].set_ylabel("Group number")
ax[1].set_xlabel("Epoch number")
ax[1].set_title("Output scalars vs epoch")
fig.subplots_adjust(hspace=0.5)
# save before showing, as done for the RIR comparison figure
fig.savefig(Path(f'{fig_path}/{config_name}_io_scalars_{rec_pos}_random_init_lr={trainer_config.io_lr}.png').resolve())
plt.show()

In [None]:
# product of output and input scalars in dB, for the first and the last stored epoch
for epoch_idx in (0, -1):
    print(db(output_gain_matrix[:, epoch_idx] * input_gain_matrix[:, epoch_idx]))

### Compare the magnitude response of the true and synthesized RIR

In [None]:
def plot_mag_response(true_ir: ArrayLike, synth_ir: ArrayLike, fs: float, nfft: Optional[int]=None):
    """
    Plot the magnitude responses of the true and the synthesised RIR on a log-frequency axis.

    Args:
        true_ir: reference RIR
        synth_ir: synthesised RIR
        fs: sampling rate in Hz
        nfft: FFT size; defaults to the next power of two that is at least len(true_ir)
    """
    if nfft is None:
        # np.pow only exists in NumPy 2.x, the ** operator works with every version
        nfft = 2 ** int(np.ceil(np.log2(len(true_ir))))
    # rfft zero-pads or truncates each signal to nfft samples
    true_resp = rfft(true_ir, nfft)
    synth_resp = rfft(synth_ir, nfft)
    freq_bins = rfftfreq(nfft, d=1.0 / fs)

    _, ax = plt.subplots(figsize=(6, 4))
    ax.semilogx(freq_bins, db(true_resp), label='True')
    ax.semilogx(freq_bins, db(synth_resp), linestyle='--', label='DiffGFDN')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Magnitude (dB)')
    ax.set_xlim([20, fs / 2])
    ax.set_ylim([-60, 20])
    ax.legend()
    

# magnitude response of the reference RIR against the first channel of the final synthesised RIR
plot_mag_response(true_ir, final_approx_ir[:, 0], fs)