In the new architecture of the DiffGFDN, we have decoupled FDNs, each one having the common decay time associated with one room in the coupled space. The block-diagonal mixing matrix, $\mathbf{A}$, and the input and output gains, $\mathbf{b}, \mathbf{c}$, are determined using colourless FDN optimisation. The only things left to be determined are the source and receiver filters, $\mathbf{g_i}(z), \mathbf{g_o}(z)$. Using the common-slopes model, we see that these filters model the position-dependent amplitudes associated with each common slope. Can we find these filters by using a rank-1 decomposition of the common-slope amplitudes in each frequency band?

Let us start by investigating the broadband case, where the source and receiver filters are replaced with broadband gains. We test our hypothesis on a synthetically generated dataset of shaped white noise with two slopes, whose decay times are $0.5$ s and $1.7$ s. The amplitudes are spatial functions of the room geometry and the source-receiver positions. Let the amplitude of the $k$th slope at source position $\mathbf{x_s}$ and receiver position $\mathbf{x_r}$ be denoted as $A_k(\mathbf{x_s, x_r})$. On doing a common-slope analysis of the RIRs at source positions $\mathbf{x_{s_1}}, \cdots, \mathbf{x_{s_M}}$ and receiver positions $\mathbf{x_{r_1}}, \cdots, \mathbf{x_{r_N}}$, we get an $M \times N$ matrix of amplitudes for the $k$th slope,

\begin{aligned}
\mathbf{A}_k = \begin{bmatrix}A_k(\mathbf{x_{s_1}, x_{r_1}}) & A_k(\mathbf{x_{s_1}, x_{r_2}}) & \cdots  & A_k(\mathbf{x_{s_1}, x_{r_N}}) \\
\vdots & \vdots & \ddots & \vdots \\
A_k(\mathbf{x_{s_M}, x_{r_1}}) & A_k(\mathbf{x_{s_M}, x_{r_2}}) & \cdots  & A_k(\mathbf{x_{s_M}, x_{r_N}})
\end{bmatrix}
\end{aligned}

We can do a rank-1 decomposition of this matrix using the singular value decomposition (SVD),
\begin{aligned}
\mathbf{A}_k  &= \mathbf{U}_k \Sigma_k \mathbf{V}_k^{H} \\
& \approx \sigma_{1k} \mathbf{u}_{1k} \mathbf{v}_{1k}^{H}
\end{aligned}
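where $\mathbf{u}_{1k} \in \mathbb{C}^{M \times 1}, \mathbf{v}_{1k} \in \mathbb{C}^{N \times 1}$ are the left and right singular vectors associated with the largest singular value, $\sigma_{1k}$. In our case the amplitudes are real, so the singular vectors are real and $\mathbf{v}_{1k}^H = \mathbf{v}_{1k}^T$. Using this decomposition, we can set the source and receiver gains for the $k$th FDN as follows, so that $\mathbf{g_{i_k}} \mathbf{g_{o_k}}^T = \sigma_{1k} \mathbf{u}_{1k} \mathbf{v}_{1k}^{H}$: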

\begin{aligned}
\mathbf{g_{i_k}} &= \begin{bmatrix} g_{i_k}(\mathbf{x_{s_1}}) & \cdots & g_{i_k}(\mathbf{x_{s_M}})\end{bmatrix}^T = \sqrt{\sigma_{1k}}\mathbf{u}_{1k} \\
\mathbf{g_{o_k}} &= \begin{bmatrix} g_{o_k}(\mathbf{x_{r_1}}) & \cdots & g_{o_k}(\mathbf{x_{r_N}})\end{bmatrix}^T = \sqrt{\sigma_{1k}}\mathbf{v}_{1k}
\end{aligned}

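We can pre-compute these gains from the common-slope analysis and interpolate between them for dynamic source and listener positions. Now, suppose we do not have a regular grid of source-receiver positions. For example, for source location $\mathbf{x_{s_m}}$ we might only have measured RIRs at $P < N$ receiver positions. Then some entries of the matrix $\mathbf{A}_k$ are unknown and are filled with 0.

These zeros are not measurements, so they must be excluded from the fit. Otherwise the optimum of the unmasked Frobenius loss is exactly the truncated SVD (Eckart–Young), which reproduces the zeros. Let $\mathbf{M}$ be a binary mask of the same size as $\mathbf{A}_k$ that is 1 for measured source-receiver pairs and 0 otherwise, and let $\odot$ denote the element-wise product. We can still recover the source and receiver gains by minimising the masked loss:

\begin{aligned}
\mathcal{L}(\mathbf{g}_{ik}, \mathbf{g}_{ok}) := ||\mathbf{M} \odot (\mathbf{g}_{ik} \mathbf{g}_{ok}^T - \mathbf{A}_k)||_F
\end{aligned}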
We can initialise $\mathbf{g}_{ik}, \mathbf{g}_{ok}$ with the rank-1 decomposition suggested above, and then use gradient descent to converge to an optimal solution, $\mathbf{g}_{ik}^*, \mathbf{g}_{ok}^*$.
This solution is only defined up to a scale factor: $\alpha\mathbf{g}^*_{ik}$ and $\frac{\mathbf{g}^*_{ok}}{\alpha}$ achieve the same loss for any non-zero $\alpha$. To fix this scaling ambiguity, we consider a modified loss that penalises any imbalance between the energies of the source and receiver gains, weighted by $\lambda > 0$:

\begin{aligned}
\mathcal{L}^* (\mathbf{g}_{ik}, \mathbf{g}_{ok}) := \mathcal{L}(\mathbf{g}_{ik}, \mathbf{g}_{ok}) + \lambda \left(\mathbf{g}_{ik}^T \mathbf{g}_{ik} - \mathbf{g}_{ok}^T\mathbf{g}_{ok}\right)^2
\end{aligned}

The energy difference is squared so that the penalty is non-negative. The unsquared difference could be driven to $-\infty$ by shrinking $\mathbf{g}_{ik}$ and growing $\mathbf{g}_{ok}$.

In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike
from typing import Optional, List
from IPython import display
from tqdm import tqdm
import soundfile as sf
from loguru import logger
from copy import deepcopy
os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset, run_training_colorless_fdn
from diff_gfdn.model import DiffGFDNSinglePos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import (plot_edr, animate_coupled_feedback_matrix, plot_subband_edc, 
                            plot_learned_svf_response, plot_amps_in_space, plot_edc_error_in_space)
from diff_gfdn.analysis import get_decay_fit_net_params
from src.run_model import load_and_validate_config

### Create config file and dataset

In [None]:
config_path = 'data/config/'
fig_path = 'figures/'

config_params = {
    'room_dataset_path': 'resources/synthetic_dataset/two_coupled_rooms_multi_source/bb_wgn_0000.pkl',
    'sample_rate': 48000.0,
    'num_delay_lines': 8,
    'use_absorption_filters': False,
    'feedback_loop_config':{
      'coupling_matrix_type': 'scalar_matrix',
    },
    'output_filter_config':
    {
      'use_svfs': False,
    },
    'colorless_fdn_config':{
      'use_colorless_prototype': True,
      'batch_size': 4000,
      'max_epochs': 15,
    },
    'trainer_config':{
          'batch_size': 1,
          'num_freq_bins': 96000,
          'train_dir': 'output/single_rir_out_gains_synth_data_colorless_prototype_low_rank_decomp/',
          'ir_dir': 'audio/single_rir_out_gains_synth_data_colorless_prototype_low_rank_decomp/',
    },
}

config_dict = DiffGFDNConfig(**config_params)
config_name = 'synth_data_broadband_two_coupled_rooms_colorless_prototype_low_rank_decomp'
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# Create every output directory the notebook writes to later:
# - true/synthesised IRs go to ir_dir,
# - the .mat results go to train_dir,
# - saved figures go to fig_path.
# mkdir(exist_ok=True) avoids the check-then-create race of os.path.exists + os.makedirs.
ir_dir = Path(config_dict.trainer_config.ir_dir).resolve()
for out_dir in (ir_dir,
                Path(config_dict.trainer_config.train_dir).resolve(),
                Path(fig_path).resolve()):
    out_dir.mkdir(parents=True, exist_ok=True)

# one GFDN group per room in the coupled space
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

trainer_config = config_dict.trainer_config
# prepare the training and validation data for DiffGFDN:
# the batch size is set equal to the number of frequency bins of the dataset
if trainer_config.batch_size != room_data.num_freq_bins:
    trainer_config = trainer_config.copy(
        update={"batch_size": room_data.num_freq_bins})

# Late-reverberation analysis window used later: start 20 ms into the RIR
# and crop the final 5 ms.
mixing_time_samp = ms_to_samps(20.0, config_dict.sample_rate)
crop_end_samp = ms_to_samps(5.0, config_dict.sample_rate)

### Get the source-receiver gains with rank-1 decomposition

In [None]:
# Rank-1 decomposition of each slope's amplitude matrix.
# room_data.amplitudes is indexed as [src_idx, rec_idx, slope] elsewhere in this notebook,
# so A_matrix[..., k] is the (num_src x num_rec) amplitude matrix of slope k.
A_matrix = np.asarray(room_data.amplitudes)
g_in = np.zeros((room_data.num_src, room_data.num_rooms))
g_out = np.zeros((room_data.num_rec, room_data.num_rooms))
A_recons = np.zeros_like(A_matrix)

for k in range(room_data.num_rooms):
    cur_amp_matrix = A_matrix[..., k]
    U, S, Vh = np.linalg.svd(cur_amp_matrix, full_matrices=False)
    # np.linalg.svd returns singular values in descending order, so index 0 is the largest.
    # Left singular vectors are the COLUMNS of U; right singular vectors are the ROWS of Vh
    # (already conjugated, so sigma * outer(U[:, 0], Vh[0, :]) is the rank-1 approximation).
    u1 = U[:, 0]
    v1 = Vh[0, :]
    # SVD is only defined up to a joint sign flip of (u1, v1); choose the sign that makes
    # the gains predominantly positive (their product, and hence A_recons, is unchanged).
    if np.sum(u1) < 0:
        u1, v1 = -u1, -v1
    g_in[:, k] = np.sqrt(S[0]) * u1
    g_out[:, k] = np.sqrt(S[0]) * v1
    # outer product of the source and receiver gains (a 1-D `@` would give a scalar)
    cur_recons_matrix = np.outer(g_in[:, k], g_out[:, k])
    A_recons[..., k] = cur_recons_matrix
    # fraction of the squared Frobenius norm captured by the first component
    print(f'Variance explained by first principal component: {S[0] ** 2 / np.sum(S ** 2)}')
    print(f'Reconstruction error: {np.linalg.norm(cur_amp_matrix - cur_recons_matrix)}')

### Run the DiffGFDN with the source-receiver gains

In [None]:
# Optionally pre-train the colourless FDN prototype; its parameters are handed to the
# DiffGFDN constructor below (None means no prototype is used).
colorless_fdn_params = (
    run_training_colorless_fdn(config_dict, num_freq_bins=trainer_config.num_freq_bins)
    if config_dict.colorless_fdn_config.use_colorless_prototype
    else None
)

# Build the single-position DiffGFDN with one group per room (config_dict.num_groups),
# the dataset's common decay times and absorption filters disabled.
model = DiffGFDNSinglePos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    use_absorption_filters=False,
    common_decay_times=room_data.common_decay_times,
    colorless_fdn_params=colorless_fdn_params,
)

In [None]:
# Loop over all source and receiver positions: install the rank-1 source/receiver
# gains in the model, rescale them so the late-response energy matches the target,
# and store / write out the synthesised RIR.
all_rir_recons = np.zeros((room_data.num_src, room_data.num_rec, room_data.rir_length), dtype=np.float32)
# join paths with pathlib so file names do not depend on ir_dir ending in '/'
out_ir_dir = Path(trainer_config.ir_dir).resolve()

for src_idx in tqdm(range(room_data.num_src)):
    # positions come from room_data (source_locs / receiver_locs are never defined here)
    src_pos_to_investigate = np.squeeze(np.round(room_data.source_position[src_idx, :], 2))
    logger.info(f'Running GFDN for source pos : {src_pos_to_investigate}')
    for rec_idx in range(room_data.num_rec):
        rec_pos_to_investigate = np.squeeze(np.round(room_data.receiver_position[rec_idx, :], 2))
        true_ir = np.squeeze(room_data.rirs[src_idx, rec_idx, :])
        amplitudes = np.squeeze(room_data.amplitudes[src_idx, rec_idx, :])
        filename = f'ir_src={src_pos_to_investigate}_rec={rec_pos_to_investigate}.wav'
        sf.write(out_ir_dir / f'true_{filename}', true_ir, int(config_dict.sample_rate))

        # create RIRDataset
        rir_data = RIRData(rir=true_ir,
                           sample_rate=config_dict.sample_rate,
                           common_decay_times=room_data.common_decay_times,
                           band_centre_hz=room_data.band_centre_hz,
                           amplitudes=amplitudes,
                           nfft=config_dict.trainer_config.num_freq_bins,
                           )

        # prepare the training and validation data for DiffGFDN
        train_dataset = load_dataset(
            rir_data, trainer_config.device, train_valid_split_ratio=1.0,
            batch_size=trainer_config.batch_size, shuffle=False)

        model.eval()
        cur_gin = g_in[src_idx, :]
        cur_gout = g_out[rec_idx, :]
        # create the gain parameters on the same device the model was built for
        model.input_scalars = nn.Parameter(
            torch.tensor(cur_gin[:, np.newaxis], dtype=torch.float32, device=trainer_config.device))
        model.output_scalars = nn.Parameter(
            torch.tensor(cur_gout[:, np.newaxis], dtype=torch.float32, device=trainer_config.device))

        with torch.no_grad():
            for data in train_dataset:
                H, approx_ir = get_response(data, model)

                # the input and output scalars should be scaled to match the energy of the desired RIR
                H_late = torch.fft.rfft(approx_ir[mixing_time_samp:])
                energyH = torch.mean(torch.pow(torch.abs(H_late), 2))
                energyH_target = torch.mean(
                    torch.pow(torch.abs(data['target_late_response']), 2))
                energy_diff = torch.div(energyH, energyH_target)
                # both gains are divided by energy_diff**(1/4), so their product scales by
                # energy_diff**(-1/2) and the output energy by 1 / energy_diff
                for name, prm in model.named_parameters():
                    if name in ('input_scalars', 'output_scalars'):
                        prm.data.copy_(
                            torch.div(prm.data, torch.pow(energy_diff, 1 / 4)))
                H, approx_ir = get_response(data, model)
                # move to host memory once; works for CPU and GPU tensors alike
                approx_ir_np = approx_ir.detach().cpu().numpy()
                sf.write(out_ir_dir / f'approx_{filename}', approx_ir_np, int(config_dict.sample_rate))
                all_rir_recons[src_idx, rec_idx, :] = approx_ir_np[:room_data.rir_length]

### Evaluate the output at one position

In [None]:
def get_edc_params(rir: ArrayLike, n_slopes: int, fs:float):
    """Estimate multi-slope decay parameters of ``rir`` via ``get_decay_fit_net_params``.

    Each returned quantity is averaged over axis 0 of what the estimator returns
    (presumably the analysis sub-bands -- verify against diff_gfdn.analysis).

    Args:
        rir: room impulse response to analyse.
        n_slopes: number of decay slopes to fit.
        fs: sample rate in Hz.

    Returns:
        Tuple ``(t60, amplitudes, noise, fitted_edc)``: the first three are NumPy
        means of entries 0, 1 and 2 of the estimated parameters, and ``fitted_edc``
        is a squeezed torch tensor.
    """
    decay_params, _, edc_per_band = get_decay_fit_net_params(rir, None, n_slopes, fs)
    # entries 0, 1, 2 of decay_params -> T60s, amplitudes, noise levels
    est_t60, est_amp, est_noise = (np.mean(decay_params[i], axis=0) for i in range(3))
    fitted_edc = torch.squeeze(torch.mean(edc_per_band, dim=0))
    return est_t60, est_amp, est_noise, fitted_edc

def plot_final_decay_fit_net_edc(num_groups:int, fs: float, og_amps: List, est_amps: List, og_edc: ArrayLike, synth_edc:ArrayLike):
    """Show the fitted EDCs of the original and synthesised RIRs in dB.

    The original and synthesised slope amplitudes are drawn as markers at t = 0.

    Args:
        num_groups: number of slopes (markers per amplitude set).
        fs: sample rate in Hz, used for the time axis.
        og_amps: slope amplitudes of the original RIR.
        est_amps: slope amplitudes estimated from the synthesised RIR.
        og_edc: fitted EDC of the original RIR; defines the time axis.
        synth_edc: fitted EDC of the synthesised RIR; plotted over its own length.
    """
    n_samples = len(og_edc)
    time = np.linspace(0, (n_samples - 1) / fs, n_samples)
    amp_time = np.zeros(num_groups)

    _, ax = plt.subplots(figsize=(6, 4))
    ax.plot(time, db(og_edc, is_squared=True), linestyle='-',
            label='Original estimated EDC (norm)')
    ax.plot(time[:len(synth_edc)], db(synth_edc, is_squared=True), linestyle='--',
            label='Final estimated EDC (norm)')
    for amps, marker, label in ((og_amps, 'kx', 'Original amps'),
                                (est_amps, 'gd', 'Synth amps')):
        ax.plot(amp_time, db(amps, is_squared=True), marker, label=label)

    ax.legend()
    plt.show()

def plot_edc(h_true: ArrayLike, h_approx: ArrayLike, fs: float, mixing_time_ms:float=20.0,
             og_amps: Optional[List]=None, crop_end_ms: float = 5.0):
    """Plot and save the energy decay functions (EDFs) of a true and a synthesised RIR.

    The EDF is the backward-integrated squared RIR, shown in dB. Both RIRs are
    analysed from the mixing time onwards. The last ``crop_end_ms`` of the true RIR
    are discarded, and the synthesised RIR is cut to the same length.

    Args:
        h_true: true RIR (assumed 1-D, as passed by this notebook).
        h_approx: synthesised RIR (assumed 1-D).
        fs: sample rate in Hz.
        mixing_time_ms: start of the analysed segment, in ms.
        og_amps: optional slope amplitudes, drawn as markers at t = 0.
        crop_end_ms: duration removed from the end of the true RIR, in ms.

    The figure is saved to ``{fig_path}/compare_synth_edf_{config_name}.png`` using the
    notebook-level ``fig_path`` and ``config_name``.
    """
    mixing_time_samp = ms_to_samps(mixing_time_ms, fs)
    crop_end_samp = ms_to_samps(crop_end_ms, fs)

    # explicit end index: h[a:-0] would silently be empty if crop_end_samp were 0
    trunc_true_ir = h_true[mixing_time_samp:len(h_true) - crop_end_samp]
    true_edf = np.flipud(np.cumsum(np.flipud(trunc_true_ir**2), axis=-1))
    time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                       len(trunc_true_ir))

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
    if og_amps is not None:
        ax.plot(np.zeros(len(og_amps)), db(og_amps, is_squared=True), 'kd')
    trunc_approx_ir = h_approx[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
    synth_edf = np.flipud(np.cumsum(np.flipud(trunc_approx_ir**2), axis=-1))
    # h_approx may be shorter than h_true; plot only the samples that exist
    ax.plot(time[:len(synth_edf)], db(synth_edf, is_squared=True), label='Synth EDF')
    ax.legend()
    fig.savefig(Path(f'{fig_path}/compare_synth_edf_{config_name}.png').resolve())
    plt.show()

In [None]:
# Pick one random source/receiver pair (source drawn first, then receiver).
src_idx, rec_idx = (np.random.randint(0, high=n_pos, size=1, dtype=int)
                    for n_pos in (room_data.num_src, room_data.num_rec))

true_ir = np.squeeze(room_data.rirs[src_idx, rec_idx, :])
approx_ir = np.squeeze(all_rir_recons[src_idx, rec_idx, :])
amplitudes = np.squeeze(room_data.amplitudes[src_idx, rec_idx, :])

# Overlay both time-domain responses over the first 1.5 s.
plt.plot(np.stack((true_ir, approx_ir), axis=-1))
plt.xlim([0, int(1.5 * config_dict.sample_rate)])
plt.show()

# Fit decay parameters to the late part of both responses and compare them.
late_slice = slice(mixing_time_samp, -crop_end_samp)
trunc_true_ir = true_ir[late_slice]
trunc_approx_ir = approx_ir[late_slice]
og_est_t60, og_est_amp, og_noise_floor, og_fitted_edc = get_edc_params(
    trunc_true_ir, config_dict.num_groups, config_dict.sample_rate)
est_t60, est_amp, _, fitted_edc = get_edc_params(
    trunc_approx_ir, config_dict.num_groups, config_dict.sample_rate)

plot_final_decay_fit_net_edc(config_dict.num_groups, config_dict.sample_rate,
                             amplitudes, est_amp, og_fitted_edc, fitted_edc)
plot_edc(true_ir, approx_ir, config_dict.sample_rate, og_amps=amplitudes)

### Get the overall amplitude and EDC mismatch as a function of spatial location

In [None]:
# Reload diff_gfdn.plot so edits to the plotting module made during the session are
# picked up without restarting the kernel. Then re-import the two functions so the
# notebook-level names refer to the reloaded versions (order matters here).
from importlib import reload
import diff_gfdn
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_amps_in_space, plot_edc_error_in_space

# Project helper; judging by its name it plots the EDC mismatch between room_data's RIRs
# and all_rir_recons over the receiver positions (assumption -- see diff_gfdn.plot).
plot_edc_error_in_space(room_data, all_rir_recons, room_data.receiver_position, freq_to_plot=None, 
                         scatter=True, save_path=f'{fig_path}/{config_name}', pos_sorted=True)

# Project helper; its return value is kept as estimated_amps (presumably the slope
# amplitudes estimated from all_rir_recons -- verify in diff_gfdn.plot).
estimated_amps = plot_amps_in_space(room_data, all_rir_recons, room_data.receiver_position, 
                                    freq_to_plot=None, scatter=True, save_path=f'{fig_path}/{config_name}', pos_sorted=True)

In [None]:
from scipy.io import savemat

# Store positions, synthesised RIRs and estimated amplitudes in a MATLAB .mat file.
mdic = {'src_pos' : room_data.source_position, 'rec_pos': room_data.receiver_position, 
        'approx_rir': all_rir_recons, 'est_amps': estimated_amps}
# join with pathlib so the result does not depend on train_dir ending in '/'
save_path = (Path(config_dict.trainer_config.train_dir) / 'low_rank_decomp_GFDN_approx_rirs.mat').resolve()
# savemat does not create missing directories
save_path.parent.mkdir(parents=True, exist_ok=True)
savemat(save_path, mdic)