In [None]:
import numpy as np
import torch
import re
import pyfar as pf
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, ThreeRoomDataset, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType, SubbandProcessingConfig
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix, plot_amps_in_space, plot_edc_error_in_space, plot_magnitude_response
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.model import ColorlessFDN
from slope2noise.utils import schroeder_backward_int
from src.run_model import load_and_validate_config

In [None]:
# Directories (relative to the repository root set by os.chdir above)
config_path = 'data/config/'
fig_path = 'figures/'

# Configuration file of the training run that this notebook inspects
config_name = 'treble_data_grid_training_8000Hz_colorless_loss.yml'
config_file = f'{config_path}{config_name}'

# Load the configuration; DiffGFDNConfig is passed as the config class
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Load the room dataset. Paths containing "3room_FDTD" are read with
# ThreeRoomDataset; every other path goes through
# convert_common_slopes_rir_to_room_dataset.
if "3room_FDTD" in config_dict.room_dataset_path:
    room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)
    dataset_has_cs_params = True
else:
    room_data = convert_common_slopes_rir_to_room_dataset(config_dict.room_dataset_path, 
                                                          num_freq_bins=config_dict.trainer_config.num_freq_bins,
                                                          )
    # NOTE(review): the flag is True in both branches, so the least-squares
    # fallback that is guarded by it later in the notebook never runs - confirm
    # this is intended.
    dataset_has_cs_params = True


# Set the number of groups in the config to the number of rooms in the dataset
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

trainer_config = config_dict.trainer_config
# force the trainer config device to be CPU
# NOTE(review): only this local copy is changed; config_dict.trainer_config keeps
# its original device. get_colorless_fdn_params(config_dict) below receives
# config_dict - confirm it does not depend on the device.
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

# prepare the training and validation data for DiffGFDN
# (split ratio 1.0 and shuffle=False: presumably every receiver ends up in
# train_dataset in dataset order - TODO confirm against load_dataset)
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size, shuffle=False)

# get the colorless FDN params
if config_dict.colorless_fdn_config.use_colorless_prototype:
    colorless_fdn_params = get_colorless_fdn_params(config_dict)
else:
    colorless_fdn_params = None

# initialise the model
# The arguments are positional, so their order must match the
# DiffGFDNVarReceiverPos constructor. The dataset's common decay times are only
# passed when the config asks for initialisation with the optimal values.
model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                 config_dict.num_groups,
                 config_dict.delay_length_samps,
                 trainer_config.device,
                 config_dict.feedback_loop_config,
                 config_dict.output_filter_config,
                 config_dict.decay_filter_config.use_absorption_filters,
                 common_decay_times=room_data.common_decay_times if config_dict.decay_filter_config.initialise_with_opt_values else None,
                 learn_common_decay_times=config_dict.decay_filter_config.learn_common_decay_times,
                 colorless_fdn_params=colorless_fdn_params,
                 use_colorless_loss=trainer_config.use_colorless_loss
                 )

In [None]:
audio_directory  = Path("audio/")
fig_path = Path("figures").resolve()
checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
max_epochs = trainer_config.max_epochs
plot_ir = False
use_fixed_pos = True
if use_fixed_pos:
    # other positions that have been inspected:
    # [2.00, 6.8, 1.5], [6.4, 3.8, 1.5], [0.20, 2.90, 1.50]
    pos_to_investigate = [9.3, 6.6, 1.5]
else:
    # pick a random receiver from the dataset
    rec_idx = np.random.randint(0, high=room_data.num_rec, size=1, dtype=int)
    pos_to_investigate = np.round(np.squeeze(room_data.receiver_position[rec_idx,:]), 2)
desired_filename = f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}).wav'

# find amplitudes corresponding to the receiver position
# Positions are compared after rounding to 2 decimals. np.isclose with a tiny
# absolute tolerance is used instead of `==` so that the match does not depend on
# the floating point precision in which the receiver positions are stored.
rounded_receiver_pos = np.round(room_data.receiver_position, 2)
pos_matches = np.argwhere(
    np.all(np.isclose(rounded_receiver_pos, pos_to_investigate, rtol=0.0, atol=1e-6), axis=1))
if pos_matches.size == 0:
    # fail with an explicit message instead of an opaque IndexError
    raise ValueError(
        f'Receiver position {pos_to_investigate} was not found in room_data.receiver_position')
# index of the first matching receiver, kept as a 1-element array
rec_pos_idx = pos_matches[0]
amps_at_pos = np.squeeze(room_data.amplitudes[rec_pos_idx, :])

h_true = np.squeeze(room_data.rirs[rec_pos_idx, :])

In [None]:
### function to normalise the RIR by the filter's energy while plotting
def get_norm_factor(subband_process_config: SubbandProcessingConfig, fs):
    """Return the L2 norm of the coefficients of one subband filter.

    A reconstructing fractional-octave filter bank is built with pyfar from
    `num_fraction_octaves` and `frequency_range` of the config. The band whose
    frequency is closest to `centre_frequency` is selected and the square root
    of the sum of its squared coefficients is returned.

    Args:
        subband_process_config: subband processing configuration; the fields
            num_fraction_octaves, frequency_range and centre_frequency are read.
        fs: sampling rate passed to the filter bank design.

    Returns:
        The norm of the selected filter's coefficients.
    """
    filter_bank, band_freqs = pf.dsp.filter.reconstructing_fractional_octave_bands(
        None,
        num_fractions=subband_process_config.num_fraction_octaves,
        frequency_range=subband_process_config.frequency_range,
        sampling_rate=fs,
    )
    # band whose frequency lies closest to the configured centre frequency
    distance_to_centre = np.abs(band_freqs - subband_process_config.centre_frequency)
    band_idx = np.argmin(distance_to_centre)
    band_coefficients = filter_bank.coefficients[band_idx, :]
    return np.sqrt(np.sum(np.power(band_coefficients, 2)))

### Iterate through epochs

In [None]:
# Per-epoch snapshots of the synthesised RIR and of the learnt parameters
h_approx_list, output_gains, input_gains = [], [], []
input_scalars, output_scalars = [], []
coupled_feedback_matrix, coupling_matrix = [], []

# One pre-allocated buffer per epoch slot; the epoch loop writes epoch `e` into
# slot `e + 1`, so slot 0 corresponds to epoch -1.
all_pos = [np.empty((room_data.num_rec, 3)) for _ in range(max_epochs + 1)]
all_rirs = [np.empty((room_data.num_rec, room_data.num_freq_bins)) for _ in range(max_epochs + 1)]
all_output_scalars = [np.empty((room_data.num_rec, room_data.num_rooms)) for _ in range(max_epochs + 1)]

# When True, the square root of the dataset amplitudes is passed to get_response
# as the output scalars
use_direct_cs_params = False

def find_listener_pos_in_room_data(list_pos: NDArray, room_data: RoomDataset, atol: float = 1e-4) -> ArrayLike:
    """Return the indices of list_pos found in room_data.receiver_position.

    Positions are matched with an absolute tolerance rather than exact equality,
    so that queries stored in a different floating point precision than the
    receiver positions (e.g. float32 vs float64) are still found.

    Args:
        list_pos: positions to look up, one position per row.
        room_data: object whose `receiver_position` attribute holds the known
            receiver positions, one position per row.
        atol: maximum absolute difference per coordinate for two positions to
            be considered identical.

    Returns:
        int32 array with one entry per row of list_pos: the index of the first
        matching receiver, or -1 when there is no match. Callers must check for
        -1 before indexing with the result, because -1 selects the last row.
    """
    receiver_positions = np.asarray(room_data.receiver_position, dtype=float)
    query_positions = np.asarray(list_pos, dtype=float)
    index = np.full(len(query_positions), -1, dtype=np.int32)  # Default to -1 for non-matches

    for i, pos in enumerate(query_positions):
        is_match = np.isclose(receiver_positions, pos, rtol=0.0, atol=atol).all(axis=-1)
        match = np.where(is_match)[0]
        if match.size > 0:
            index[i] = match[0]  # Take the first match if multiple exist
    return index
    
# Run inference with the checkpoint of each listed epoch and collect, for every
# receiver position, the synthesised RIR and the output scalars. The RIR and the
# output scalars at `desired_filename`'s position are additionally appended to
# h_approx_list / output_scalars.
for epoch in [max_epochs-1]:
    # load the trained weights for the particular epoch
    checkpoint = torch.load(f'{checkpoint_dir}/model_e{epoch}.pt', weights_only=True, map_location=torch.device('cpu'))
    # Load the trained model state
    # NOTE(review): strict=False ignores missing / unexpected keys, so a
    # checkpoint that does not match the model will not raise here.
    model.load_state_dict(checkpoint, strict=False)
    # in eval mode, no gradients are calculated
    model.eval()
    break_outer_loop = False  # only referenced by the commented-out early exit below
    npos = 0

    # The normalisation factor only depends on the subband configuration and the
    # sample rate, so it is computed once per epoch instead of rebuilding the
    # filter bank for every batch.
    subband_norm_factor = get_norm_factor(trainer_config.subband_process_config, model.sample_rate)

    with torch.no_grad():
        # snapshot of the position-independent parameters
        param_dict = model.get_param_dict()
        input_gains.append(deepcopy(param_dict['input_gains']))
        input_scalars.append(deepcopy(param_dict['input_scalars']))
        output_gains.append(deepcopy(param_dict['output_gains']))
        coupled_feedback_matrix.append(deepcopy(param_dict['coupled_feedback_matrix']))
        coupling_matrix.append(deepcopy(param_dict['coupling_matrix']))

        for data in train_dataset:
            position = data['listener_position']

            if use_direct_cs_params:
                # try using CS amps as receiver gains
                pos_idxs = find_listener_pos_in_room_data(position, room_data)
                cs_output_scalars = np.sqrt(room_data.amplitudes[pos_idxs, :])

                if trainer_config.use_colorless_loss:
                    H, H_sub_fdn, h = get_response(data, model, torch.tensor(cs_output_scalars))
                else:
                    H, h = get_response(data, model, torch.tensor(cs_output_scalars))
            else:
                if trainer_config.use_colorless_loss:
                    H, H_sub_fdn, h = get_response(data, model)
                else:
                    H, h = get_response(data, model)

            # this needs to be added to compensate for subband filter energy
            h *= subband_norm_factor

            # get parameter dictionary used in inferencing
            inf_param_dict = model.get_param_dict_inference(data)

            for num_pos in range(position.shape[0]):
                # slot `epoch + 1` of the pre-allocated buffers belongs to this epoch
                all_pos[epoch+1][npos, :] = position[num_pos]
                all_rirs[epoch+1][npos, :] = h[num_pos, :]
                # NOTE(review): 'output_scalars' is read unconditionally here but
                # guarded by a key check further down - confirm the key always exists.
                all_output_scalars[epoch+1][npos, :] = deepcopy(inf_param_dict['output_scalars'][num_pos])
                npos += 1
                filename = f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, {position[num_pos, 2]:.2f}).wav'

                if filename == desired_filename:

                    # get the ir at this position
                    h_approx_list.append(h[num_pos, :])

                    # get the gains for this position
                    if 'output_scalars' in inf_param_dict.keys():
                        output_scalars.append(deepcopy(inf_param_dict['output_scalars'][num_pos]))

                    if plot_ir:
                        # plot the EDRs of the true and estimated
                        plot_edr(torch.tensor(h_true), model.sample_rate, title=f'True RIR EDR, epoch={epoch}', 
                                 save_path=f'{fig_path}/true_edr_{filename}_{config_name}_epoch={epoch}.png')

                        plot_edr(h[num_pos, :], model.sample_rate, title=f'Estimated RIR EDR, epoch={epoch}', 
                                 save_path=f'{fig_path}/approx_edr_{filename}_{config_name}_epoch={epoch}.png')

                    # not breaking the loop to collect all the RIRs
                    # break_outer_loop = True
                    # break
     

#### Plot the EDCs as a function of epoch number

In [None]:
def plot_edc(h_true: ArrayLike, h_approx: List[ArrayLike], fs: float, pos_to_investigate: List, amps_at_pos: List, mixing_time_ms:float=20.0):
    """Plot true and synthesised EDC curves

    The true RIR is cropped to start at the mixing time and to end 5 ms before
    its last sample; each synthesised RIR is cropped to the same length starting
    at the mixing time. Both are passed through schroeder_backward_int and
    plotted in dB on the same axes. The figure is redrawn after each added curve
    and finally saved under the notebook-level `fig_path`, with `config_name`
    in the file name.

    Args:
        h_true: true RIR at the investigated position.
        h_approx: synthesised RIRs; every third entry is plotted and labelled
            with its list index. Entries must provide `.detach().numpy()`.
        fs: sampling rate used to convert milliseconds to samples and to build
            the time axis.
        pos_to_investigate: receiver position, used in the title and file name.
        amps_at_pos: amplitudes at the position, drawn as markers at time 0.
        mixing_time_ms: samples before this time (in ms) are discarded.
    """
    
    mixing_time_samp = ms_to_samps(mixing_time_ms, fs)
    crop_end_samp = ms_to_samps(5.0, fs)
    
    # crop the true RIR and compute its decay curve
    trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
    true_edf = schroeder_backward_int(trunc_true_ir)
    time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                       len(trunc_true_ir))
  
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
    # amplitudes at this position, drawn at t = 0
    ax.plot(np.zeros(len(amps_at_pos)), db(amps_at_pos, is_squared=True), 'kx')       
    ax.set_title(
        f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
    )
    
    num_epochs = len(h_approx)
    # only every third entry of h_approx is plotted
    for epoch in range(0, num_epochs, 3):
        approx_ir = h_approx[epoch]
        # same start and length as the cropped true RIR
        trunc_approx_ir = approx_ir[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
        synth_edf = schroeder_backward_int(trunc_approx_ir.detach().numpy())
        ax.plot(time, db(synth_edf, is_squared=True), label=f'Epoch={epoch}')
        ax.legend()
        
        display.display(fig)  # Display the updated figure
        display.clear_output(wait=True)  # Clear the previous output to keep updates in place
        plt.pause(0.1)

    # fig_path and config_name are read from the notebook's global scope
    fig.savefig(Path(f'{fig_path}/compare_synth_edf_{pos_to_investigate}_{config_name}.png').resolve())
    plt.show()

# plot_edc(h_true[:len(h_approx_list[-1])], h_approx_list, config_dict.sample_rate, pos_to_investigate, amps_at_pos)

#### Plot the desired and final EDC

In [None]:
def get_edc_params(rir: ArrayLike, n_slopes: int, fs:float):
    """Fit decay parameters to an RIR and average them over the first axis.

    Calls get_decay_fit_net_params(rir, None, n_slopes, fs) and averages each of
    its outputs over axis 0 (presumably the subband axis - TODO confirm).

    Args:
        rir: impulse response to analyse.
        n_slopes: number of slopes passed to the fit.
        fs: sampling rate passed to the fit.

    Returns:
        Tuple (t60, amplitudes, noise, fitted_edc): the first three are numpy
        means of entries 0, 1 and 2 of the estimated parameters, the last is the
        squeezed torch mean of the fitted EDC.
    """
    decay_params, _norm_vals, subband_edc = get_decay_fit_net_params(rir, None, n_slopes, fs)
    # entries 0, 1, 2 hold T60, amplitude and noise respectively
    mean_t60, mean_amp, mean_noise = (np.mean(decay_params[k], axis=0) for k in range(3))
    mean_edc = torch.squeeze(torch.mean(subband_edc, dim=0))
    return mean_t60, mean_amp, mean_noise, mean_edc

def plot_final_decay_fit_net_edc(num_groups:int, fs: float, og_amps: List, est_amps: List, 
                                 og_edc: ArrayLike, synth_edc:ArrayLike, 
                                 src_pos:Optional[List]=None, rec_pos: Optional[List]=None):
    """Plot the fitted EDC of the original RIR against that of the synthesised RIR.

    Both curves are drawn in dB over a time axis derived from the length of
    og_edc; the synthesised curve uses the first len(synth_edc) time points. The
    original and estimated amplitudes are drawn as markers at time 0.

    Args:
        num_groups: number of amplitude markers per set.
        fs: sampling rate used to build the time axis.
        og_amps: amplitudes estimated from the original RIR.
        est_amps: amplitudes estimated from the synthesised RIR.
        og_edc: fitted EDC of the original RIR.
        synth_edc: fitted EDC of the synthesised RIR.
        src_pos: optional source position for the title.
        rec_pos: optional receiver position for the title.
    """
    n_samples = len(og_edc)
    time = np.linspace(0, (n_samples - 1) / fs, n_samples)

    fig, ax = plt.subplots(figsize=(6, 4))

    # decay curves
    ax.plot(time, db(og_edc, is_squared=True), linestyle='-', label='Original estimated EDC (norm)')
    ax.plot(time[:len(synth_edc)], db(synth_edc, is_squared=True), linestyle='--', label='Final estimated EDC (norm)')

    # amplitude markers at t = 0
    for amps, marker, label in ((og_amps, 'kx', 'Original amps'), (est_amps, 'gd', 'Synth amps')):
        ax.plot(np.zeros(num_groups), db(amps, is_squared=True), marker, label=label)

    # the title is only set when both positions are given
    if not (src_pos is None or rec_pos is None):
        ax.set_title(f"Source pos = {np.round(src_pos, 2)}, receiver pos = {np.round(rec_pos, 2)}")

    ax.legend()
    plt.show()

# Discard the first 20 ms (mixing time) and the last 5 ms of the true RIR
mixing_time_samp = ms_to_samps(20.0, config_dict.sample_rate)
crop_end_samp = ms_to_samps(5.0, config_dict.sample_rate)
trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
# Crop the synthesised RIR to the same start and the same length as the true one.
# The slice end must be offset by the mixing time; otherwise the synthesised
# segment is mixing_time_samp samples shorter than the true segment.
trunc_approx_ir = h_approx_list[-1][mixing_time_samp:mixing_time_samp + len(trunc_true_ir)]

# Fit decay parameters to both responses, with one slope per group
og_est_t60, og_est_amp, og_noise_floor, og_fitted_edc = get_edc_params(trunc_true_ir, config_dict.num_groups, config_dict.sample_rate)
est_t60, est_amp, _, fitted_edc = get_edc_params(trunc_approx_ir, config_dict.num_groups, config_dict.sample_rate)

plot_final_decay_fit_net_edc(config_dict.num_groups, config_dict.sample_rate, og_est_amp, est_amp, og_fitted_edc, fitted_edc)

# estimated amplitudes in dB: original first, synthesised second
print(db(og_est_amp, is_squared=True))
print(db(est_amp, is_squared=True))

#### Plot the learnt feedback matrix and input gains for the grid, plot the output gains for a single position

In [None]:
# animate_coupled_feedback_matrix is already imported in the notebook's import cell

# Animate the magnitude of the feedback matrix over the stored epochs
if config_dict.feedback_loop_config.coupling_matrix_type == CouplingMatrixType.SCALAR:
    animate_coupled_feedback_matrix(np.abs(coupled_feedback_matrix), np.abs(coupling_matrix), 
                                    save_path=Path(f'{fig_path}/animation/{config_name}_{pos_to_investigate}_scalar_coupling_matrix.gif').resolve())
else:
    animate_coupled_feedback_matrix(np.abs(coupled_feedback_matrix), 
                                    save_path=Path(f'{fig_path}/animation/{config_name}_{pos_to_investigate}_random_coupling_matrix.gif').resolve())

# also save the final optimised matrix
if config_dict.feedback_loop_config.coupling_matrix_type == CouplingMatrixType.SCALAR:
    # assert is_unitary(torch.from_numpy(coupled_feedback_matrix[-1]))[0]    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 6))

    # Coupling matrix subplot
    cax1 = ax1.matshow(np.abs(coupling_matrix[-1]))
    fig.colorbar(cax1, ax=ax1)
    ax1.set_title('Coupling matrix')

    # Coupled feedback matrix subplot
    cax2 = ax2.matshow(np.abs(coupled_feedback_matrix[-1]))
    fig.colorbar(cax2, ax=ax2)
    ax2.set_title('Coupled feedback matrix')

    plt.tight_layout()
    plt.savefig(f'{fig_path}/{config_name}_scalar_coupling_matrix.png', bbox_inches='tight')
else:
    unit_flag, max_val = is_unitary(torch.tensor(coupled_feedback_matrix[-1]), max_tol=1e-4)
    # report the value returned by is_unitary so that a failure is diagnosable
    assert unit_flag, f'Final feedback matrix failed the unitarity check (is_unitary returned max_val={max_val})'
    # plt.matshow opens its own figure, so no plt.figure() call is needed;
    # calling it first would leave an empty extra figure in the output
    plt.matshow(np.abs(coupled_feedback_matrix[-1]))
    plt.title('Optimised feedback matrix')
    plt.savefig(f'{fig_path}/{config_name}_random_coupling_matrix.png')
    

In [None]:
# Convert the lists of per-epoch vectors to 2D numpy arrays; after the transpose
# the epoch runs along the columns (x-axis of the plots below)
input_gain_matrix = np.stack(input_scalars).T
output_gain_matrix = np.stack(output_scalars).T

# Plot the source / receiver scalars.
# The plotted values are absolute values, not decibels, so the colourbars are
# labelled accordingly.
fig, ax = plt.subplots(2, figsize=(6,6))
in_plot = ax[0].matshow(np.abs(input_gain_matrix), aspect='auto', cmap='viridis')
fig.colorbar(in_plot, label="Magnitude (linear)", ax=ax[0])
ax[0].set_ylabel("Group number")
ax[0].set_xlabel("Epoch number")
ax[0].set_title("Source gains vs epoch")

out_plot = ax[1].matshow(np.abs(output_gain_matrix), aspect='auto', cmap='viridis')
fig.colorbar(out_plot, label="Magnitude (linear)", ax=ax[1])
ax[1].set_ylabel("Group number")
ax[1].set_xlabel("Epoch number")
ax[1].set_title("Receiver gains vs epoch")
fig.subplots_adjust(hspace=0.5)
# save before showing, so that the file is written from the live figure
fig.savefig(Path(f'{fig_path}/{config_name}_{pos_to_investigate}_io_scalars.png').resolve())
plt.show()

# Same layout for the per-delay-line input (b) and output (c) gains
b_matrix = np.stack(input_gains).T
c_matrix = np.stack(output_gains).T

fig, ax = plt.subplots(2, figsize=(6,6))
in_plot = ax[0].matshow(np.abs(b_matrix), aspect='auto', cmap='viridis')
fig.colorbar(in_plot, label="Magnitude (linear)", ax=ax[0])
ax[0].set_ylabel("Delay line number")
ax[0].set_xlabel("Epoch number")
ax[0].set_title("Input gains vs epoch")
logger.info(f"B energy across epochs {np.linalg.norm(b_matrix, axis=0)}")

out_plot = ax[1].matshow(np.abs(c_matrix), aspect='auto', cmap='viridis')
fig.colorbar(out_plot, label="Magnitude (linear)", ax=ax[1])
ax[1].set_ylabel("Delay line number")
ax[1].set_xlabel("Epoch number")
ax[1].set_title("Output gains vs epoch")
fig.subplots_adjust(hspace=0.5)
logger.info(f"C energy across epochs {np.linalg.norm(c_matrix, axis=0)}")
fig.savefig(Path(f'{fig_path}/{config_name}_{pos_to_investigate}_bc_gains.png').resolve())
plt.show()

#### Plot magnitude response of each sub-FDN to inspect colouration

In [None]:
# Reload the project modules so that edits made to them while the kernel is
# running are picked up, then re-import the plotting function from the reloaded
# module.
import diff_gfdn
from importlib import reload
reload(diff_gfdn.utils)
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_magnitude_response

# Plot the magnitude response and save it next to the other figures
save_path = f'{fig_path}/{config_name}_mag_spectrum.png'
plot_magnitude_response(room_data, config_dict, model, save_path)

# When the decay times are learnt, log the dataset values next to the learnt
# ones (both multiplied by 1e3 and reported as ms)
if config_dict.decay_filter_config.learn_common_decay_times:
    logger.info(f'CS common decay times : {room_data.common_decay_times * 1e3} ms')
    logger.info(f'Learned common decay times: {model.feedback_loop.common_decay_times.detach() * 1e3} ms')

### Plot EDC error for each RIR as a function of position in space

In [None]:
from importlib import reload
# Import the submodules explicitly: `import package` alone does not guarantee
# that `package.submodule` is available as an attribute, which reload() needs.
import diff_gfdn.plot
import slope2noise.rooms
reload(diff_gfdn.plot)
reload(slope2noise.rooms)
from diff_gfdn.plot import plot_amps_in_space, plot_edc_error_in_space

# EDC error of the RIRs synthesised by the final checkpoint, at the positions
# stored in the last epoch slot
plot_edc_error_in_space(room_data, all_rirs[-1], all_pos[-1], freq_to_plot=None, scatter=False, 
                        save_path=f'{fig_path}/{config_name}', 
                        norm_edc=False)

#### Plot EDC error with common slopes model and LS estimation

In [None]:
# `reload` is imported in an earlier cell; slope2noise.utils is already loaded by
# the notebook's import cell.
import slope2noise
reload(slope2noise.utils)
from slope2noise.generate import shaped_wgn
from slope2noise.utils import calculate_amplitudes_least_squares, schroeder_backward_int, octave_filtering

# whether to plot the normalised common slope parameters from Georg's dataset
plot_norm_cs = False

# Repeat the common decay times once per receiver (one row per receiver)
t_vals_expanded = np.tile(np.squeeze(room_data.common_decay_times.T), (room_data.num_rec, 1))

# Use the amplitudes shipped with the dataset when available, otherwise estimate
# them from the RIRs by least squares.
# NOTE(review): dataset_has_cs_params is set to True on every path in the
# dataset-loading cell, so the least-squares branch is currently never taken.
common_slope_amps = room_data.amplitudes if dataset_has_cs_params else np.squeeze(
                                        calculate_amplitudes_least_squares(t_vals_expanded[..., np.newaxis], 
                                        room_data.sample_rate, 
                                        room_data.rirs[..., np.newaxis]))
# Only referenced by the commented-out n_vals argument below
common_slope_noise = room_data.noise_floor if dataset_has_cs_params else None



# Synthesise one RIR per receiver from the decay times and amplitudes; the first
# return value of shaped_wgn is discarded
_, ls_est_rirs = shaped_wgn(t_vals_expanded, common_slope_amps, 
                            room_data.sample_rate, room_data.rir_length, 
                            # n_vals=common_slope_noise
                           )

# Optionally repeat the synthesis with the normalised amplitudes of the dataset
if plot_norm_cs:
    cs_amps_norm = room_data.amplitudes_norm
    cs_noise_norm = room_data.noise_floor_norm
    _, ls_est_rirs_norm = shaped_wgn(t_vals_expanded, 
                                     cs_amps_norm, 
                                     room_data.sample_rate, room_data.rir_length, 
                                     # n_vals=cs_noise_norm
                                    )

In [None]:
# EDC error in space for the RIRs synthesised from the common-slopes parameters
# (ls_est_rirs), evaluated at the dataset's receiver positions
plot_edc_error_in_space(room_data, ls_est_rirs, room_data.receiver_position, freq_to_plot=None, scatter=False, 
                        save_path=f'{fig_path}/{config_name}_common_slopes_model', pos_sorted=True, norm_edc=False)

#### Compare EDC from both models at a single position

In [None]:
# Compare the decay curves of the true RIR, the DiffGFDN output and the
# common-slopes synthesis at the investigated receiver position.
fs = room_data.sample_rate
mixing_time_samp = ms_to_samps(20.0, fs)
crop_end_samp = ms_to_samps(5.0, fs)
# Bring all three responses to a common length
len_ir = min(len(h_true), len(h_approx_list[-1]))
# NOTE: this overwrites the h_true defined in an earlier cell with its truncated version
h_true = h_true[:len_ir]
h_approx_final = h_approx_list[-1][:len_ir]
h_approx_cs = np.squeeze(ls_est_rirs[rec_pos_idx, :len_ir])
# passed as `normalize` to every schroeder_backward_int call below
norm_flag = False

# True RIR: drop the first 20 ms and the last 5 ms
trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
true_edf = schroeder_backward_int(trunc_true_ir, normalize=norm_flag)
time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                   len(trunc_true_ir))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time, db(true_edf, is_squared=True), label='True EDF')
ax.set_title(
    f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
)

# DiffGFDN output, cropped to the same start and length as the true RIR
trunc_approx_ir = h_approx_final[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
synth_edf = schroeder_backward_int(trunc_approx_ir.detach().numpy(), normalize=norm_flag)
ax.plot(time, db(synth_edf, is_squared=True), label='DiffGFDN')

# Common-slopes synthesis, cropped the same way
trunc_approx_ir_common_slopes = h_approx_cs[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
synth_edf_common_slopes = schroeder_backward_int(trunc_approx_ir_common_slopes, normalize=norm_flag)
ax.plot(time, db(synth_edf_common_slopes, is_squared=True), label='Common slopes model')
# ax.plot(np.zeros(room_data.num_rooms), db(amps_at_pos), '*', label='CS amps')

# Optional curve from the normalised common-slopes parameters; requires
# ls_est_rirs_norm, which only exists when plot_norm_cs was True in the
# synthesis cell
if plot_norm_cs:
    h_approx_cs_norm = np.squeeze(ls_est_rirs_norm[rec_pos_idx, :len_ir])
    trunc_approx_ir_common_slopes_norm = h_approx_cs_norm[mixing_time_samp: mixing_time_samp + len(trunc_true_ir)]
    synth_edf_common_slopes_norm = schroeder_backward_int(trunc_approx_ir_common_slopes_norm, normalize=norm_flag)
    ax.plot(time, db(synth_edf_common_slopes_norm, is_squared=True), label='Common slopes model Georg params')
    # ax.plot(np.zeros(room_data.num_rooms), db(np.squeeze(room_data.amplitudes_norm[rec_pos_idx, :])), 'x', label='CS amps Georg')

ax.legend()
# ax.set_ylim([-70, -10])
fig.savefig(Path(f'{fig_path}/compare_edf_{pos_to_investigate}_{config_name}_common_slopes.png').resolve())

#### Plot amplitude distribution from common slopes and DiffGFDN

In [None]:
from importlib import reload
# Import the submodule explicitly so that `slope2noise.rooms` is guaranteed to be
# available for reload()
import slope2noise.rooms
reload(slope2noise.rooms)
from slope2noise.rooms import RoomGeometry

# Geometry of the coupled rooms, built from the dataset's room description
room = RoomGeometry(room_data.sample_rate,
                    room_data.num_rooms,
                    np.array(room_data.room_dims),
                    np.array(room_data.room_start_coord),
                    aperture_coords=room_data.aperture_coords)

for epoch in [max_epochs-1]:
    # the original amplitudes are only plotted for the initial (-1) and final epoch
    plot_og_amps = epoch in (-1, max_epochs - 1)
    try:
        # plot the estimated amplitudes as a function of spatial position
        est_amps = plot_amps_in_space(room_data, all_rirs[epoch+1], all_pos[epoch+1], freq_to_plot=None, scatter=False, 
                                      save_path=f'{fig_path}/{config_name}_epoch={epoch}', 
                                      plot_original_amps=plot_og_amps, plot_amp_error=False)
    except Exception as err:
        # Best effort: keep going with the remaining epochs, but report what
        # went wrong. est_amps is not updated for this epoch.
        logger.warning(f'Error in LS estimation for epoch={epoch}: {err!r}')

In [None]:
# Plot the output scalars collected during inference over the receiver positions.
# `room` is the RoomGeometry built in the previous cell; slot `epoch + 1` of the
# buffers holds the data of `epoch`.
for epoch in [max_epochs-1]:
    
    # plot the receiver gains as a function of spatial position
    # (the scalars are transposed before being passed on)
    room.plot_amps_at_receiver_points(
        all_pos[epoch+1],
        np.squeeze(np.array(room_data.source_position)),
        all_output_scalars[epoch+1].T,
        scatter_plot=False,
        save_path=Path(f'{fig_path}/{config_name}_epoch={epoch}_learnt_receiver_gains_in_space.png'
                       ).resolve())
