In [None]:
import numpy as np
import torch
import re
import os
import matplotlib.pyplot as plt

from pathlib import Path
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List
from IPython import display
from loguru import logger
from copy import deepcopy
# Global matplotlib styling: every font size below is a multiple of `scale`.
scale = 2
plt.rcParams.update({
    'font.size': scale * 8,  # base font size
    'axes.labelsize': scale * 9,  # x/y label
    'xtick.labelsize': scale * 8,
    'ytick.labelsize': scale * 8,
    'legend.fontsize': scale * 8,
    'axes.titlesize': scale * 10,  # usually unused in journal figures
})

# Change the working directory to DiffGFDN (two levels above the notebook) so that
# the relative `data/...` / `figures/...` paths used later resolve (and, presumably,
# the `diff_gfdn` / `spatial_sampling` imports below).
# The notebook directory is remembered on the first run so that re-running this cell
# does not keep climbing the directory tree ('../..' relative to an already-changed cwd).
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR / '..' / '..')

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.config.config_loader import load_and_validate_config
from diff_gfdn.model import DiffDirectionalFDNVarReceiverPos
from diff_gfdn.inference import InferDiffDirectionalFDN
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params

from spatial_sampling.dataloader import parse_room_data

In [None]:
# Choose the directional-FDN config for one frequency and load it as a validated
# DiffGFDNConfig (paths are relative to the DiffGFDN root set up in the first cell).
freq_to_plot = 4000
config_path = 'data/config/directional_fdn/'
fig_path = 'figures/directional_fdn/'
config_name = f'treble_data_grid_training_{freq_to_plot}Hz_directional_fdn_grid_res=0.6m.yml'
config_file = f'{config_path}{config_name}'

config_dict = load_and_validate_config(config_file, DiffGFDNConfig)
trainer_config = config_dict.trainer_config

#### Parse the room data, update the config accordingly and instantiate the model

In [None]:
# Parse the room dataset referenced by the config. Only the 3-room FDTD dataset is
# handled; anything else stops here instead of leaving `room_data` undefined (which
# would otherwise surface a few lines later as a confusing NameError).
if "3room_FDTD" not in config_dict.room_dataset_path:
    logger.error("Other room data not supported currently")
    raise NotImplementedError(
        f"Unsupported room dataset path: {config_dict.room_dataset_path}")
room_data = parse_room_data(Path(config_dict.room_dataset_path).resolve())

# one group per room; num_delay_lines must split evenly across the groups
config_dict = config_dict.model_copy(
    update={"num_groups": room_data.num_rooms})
if config_dict.num_delay_lines % config_dict.num_groups != 0:
    raise ValueError(
        "Delay lines must be divisible by number of groups in network")

# update ambisonics order
config_dict = config_dict.model_copy(
    update={"ambi_order": room_data.ambi_order})

# the dataset's sample rate takes precedence; override it via model_copy like the
# other config overrides above instead of mutating the model in place
if config_dict.sample_rate != room_data.sample_rate:
    logger.warning("Config sample rate does not match data, altering it")
    config_dict = config_dict.model_copy(
        update={"sample_rate": room_data.sample_rate})

# get the training config and update num_freq_bins in the pydantic class
# NOTE(review): only this local `trainer_config` receives the new num_freq_bins;
# `config_dict.trainer_config` (config_dict is later passed to
# InferDiffDirectionalFDN) keeps the value from the YAML file — confirm intended.
trainer_config = config_dict.trainer_config
trainer_config = trainer_config.model_copy(
    update={"num_freq_bins": room_data.num_freq_bins})

if config_dict.colorless_fdn_config.use_colorless_prototype and trainer_config.use_colorless_loss:
    raise ValueError(
        "Cannot use optimised colorless FDN parameters and colorless FDN loss together"
    )

# are we using a colorless FDN to get the feedback matrix?
if config_dict.colorless_fdn_config.use_colorless_prototype:
    colorless_fdn_params = get_colorless_fdn_params(config_dict)
else:
    colorless_fdn_params = None

# initialise the model; the dataset's common decay times are only passed in when the
# decay filter config asks to initialise with optimised values
model = DiffDirectionalFDNVarReceiverPos(
    room_data.sample_rate,
    room_data.num_rooms,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    ambi_order=config_dict.ambi_order,
    desired_directions=room_data.sph_directions,
    common_decay_times=(room_data.common_decay_times
                        if config_dict.decay_filter_config.initialise_with_opt_values
                        else None),
    band_centre_hz=room_data.band_centre_hz,
    colorless_fdn_params=colorless_fdn_params,
    use_colorless_loss=trainer_config.use_colorless_loss,
)


#### Instantiate the inference and plotting object, then plot the EDC error in space for the selected epoch(s) (currently only epoch `max_epochs`)

In [None]:
# Wrap the trained model in the inference/plotting helper and plot the spatial EDC
# error for each epoch listed below (only the last configured epoch here).
num_epochs = trainer_config.max_epochs
plot_obj = InferDiffDirectionalFDN(room_data,
                                   config_dict,
                                   model,
                                   apply_filter_norm=True)

epochs_to_plot = [num_epochs]
for epoch in epochs_to_plot:
    logger.info(f'Plotting EDC error for epoch number {epoch}')
    est_points, est_dir_rirs = plot_obj.get_model_output(epoch)
    # detach and move to CPU before converting to numpy for the plotting routine
    rirs_np = est_dir_rirs.detach().cpu().numpy()
    points_np = est_points.detach().cpu().numpy()
    original_edc, est_edc = plot_obj.plot_edc_error_in_space(rirs_np, points_np, epoch)

#### Investigate EDC fit at a single position

In [None]:
# Compare the original and estimated EDCs of every direction at a single position.
pos = 43  # index of the position to inspect (first axis of the EDC arrays)

# squeeze=False keeps the axes array 2-D even when there is only one direction,
# so indexing below works for any num_directions; take the single column.
fig, axes = plt.subplots(room_data.num_directions, 1, figsize=(8, 20),
                         sharey=True, squeeze=False)  # rows, cols
ax = axes[:, 0]

for j, direction_ax in enumerate(ax):
    direction_ax.plot(original_edc[pos, j, :], label='Original')
    direction_ax.plot(est_edc[pos, j, :], label='Estimated')
    direction_ax.set_title(f'Direction = {j+1}')

ax[0].legend()
ax[-1].set_xlabel('Time (samples)')
fig.text(0.04, 0.5, 'EDC (dB)', va='center', rotation='vertical')
# increase vertical spacing between subplots so titles do not overlap
fig.subplots_adjust(hspace=1.5)

### Investigate the learned output gains in the SH domain

In [None]:
# Parameters collected by the inference object; the output SH gains are stored as a
# sequence whose first/last entries are compared below.
all_learned_params = plot_obj.all_learned_params
learned_sh_gains = plot_obj.all_output_sh_gains

# SH gains at position `pos` from the first and the last stored entry
first_entry, last_entry = learned_sh_gains[0], learned_sh_gains[-1]
init_sh_gains = first_entry[pos]
final_sh_gains = last_entry[pos]

In [None]:
# Show the initial (first stored) and final (last stored) output SH gains at
# position `pos` next to the dataset's `amplitudes` for the same position.
# Each array is labelled so the three outputs can be told apart.
print(f'Initial output SH gains at position {pos}:\n{init_sh_gains}')
print(f'Final output SH gains at position {pos}:\n{final_sh_gains}')
print(f'room_data.amplitudes at position {pos}:\n{room_data.amplitudes[pos, ...]}')