In [None]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
from copy import deepcopy
from loguru import logger
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from numpy.typing import ArrayLike, NDArray
from typing import List
from scipy.fft import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm


import os
os.chdir('../..')
from spatial_sampling.dataloader import parse_three_room_data, SpatialRoomDataset, load_dataset
from src.sofa_parser import HRIRSOFAReader, SRIRSOFAWriter, convert_srir_to_brir
from src.sound_examples import binaural_dynamic_rendering
from src.dataclass import NAFDatasetUnpickler, NAFDatasetTrain, NAFDatasetInfer

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.inference import infer_all_octave_bands_directional_fdn
from diff_gfdn.plot import plot_edc_error_in_space, plot_edr_error_in_space, order_position_matrices
from diff_gfdn.utils import ms_to_samps
from diff_gfdn.config.config_loader import load_and_validate_config

In [None]:
# Locations used throughout the notebook, all relative to the current working
# directory. `out_path` stays a plain string (with its trailing slash); every
# other entry is resolved to an absolute Path.
out_path = 'output/directional_fdn/'

# folder holding the reference dataset; derived pickles are stored next to it
_dataset_dir = Path('resources/Georg_3room_FDTD')
save_path = _dataset_dir.resolve()
room_data_pkl_path = (_dataset_dir / 'srirs_spatial.pkl').resolve()

# per-band model configurations and figure output folder
config_path = Path('data/config/directional_fdn/').resolve()
fig_path = Path('figures/directional_fdn').resolve()

In [None]:
# Frequency bands (Hz, as encoded in the config file names) and the training
# grid spacing that selects which set of configs is loaded.
freqs_list = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
grid_resolution_m = 0.9

# One validated DiffGFDNConfig per band, stored in the same order as `freqs_list`.
config_dicts = []
for band_freq in freqs_list:
    # the leading '/' lets the name be appended directly to the config folder
    config_name = f'/treble_data_grid_training_{band_freq}Hz_directional_fdn_grid_res={grid_resolution_m:.1f}m.yml'
    config_dicts.append(load_and_validate_config(f'{config_path}{config_name}', DiffGFDNConfig))
    
# SOFA file with the HRIR set that is later handed to the SRIR -> BRIR conversion
hrtf_path = Path('resources/HRTF/48kHz/KEMAR_Knowl_EarSim_SmallEars_FreeFieldComp_48kHz.sofa')

# reference dataset, parsed from the pickled SRIRs
room_data = parse_three_room_data(room_data_pkl_path)

# HRIRs; brought to the dataset's sampling rate whenever the two rates differ
hrtf_reader = HRIRSOFAReader(hrtf_path)
target_fs = room_data.sample_rate
if hrtf_reader.fs != target_fs:
    logger.info(f"Resampling HRTFs to {target_fs:.0f} Hz")
    hrtf_reader.resample_hrirs(target_fs)

### Get BRIRs from trained DFDNs / NAFs

In [None]:
# Head orientations to render. Column 0 holds the four angles available in NAF;
# column 1 is left at zero.
# NOTE(review): presumably the columns are (azimuth, elevation) in degrees - confirm.
head_orientations = np.zeros((4, 2))
head_orientations[:, 0] = np.array([0, 90, 180, 270])
num_ori = head_orientations.shape[0]
num_ears = 2

# Error tables keyed by configuration; each entry is a (num_ori, num_ears) array
# that the error-evaluation cell fills in.
error_edc = {}
error_edr = {}

# Analysis window parameters, converted from milliseconds to samples.
leave_out_samps = ms_to_samps(5, room_data.sample_rate)
mixing_time_samps = ms_to_samps(room_data.mixing_time_ms, room_data.sample_rate)
trunc_at = ms_to_samps(2000, room_data.sample_rate)

# Which model's predictions to evaluate: 'dfdn' or 'naf'.
method = 'dfdn'
logger.info(f"Creating BRIRs for spacing = {grid_resolution_m:.1f}m")
if method == 'dfdn':
    brir_pkl_path = f'{save_path}/diff_dfdn_pred_brirs_test_pos_only_grid_spacing={grid_resolution_m:.1f}m.pkl'
elif method == 'naf':
    brir_pkl_path = f'{save_path}/naf_dataset_infer_grid_spacing={grid_resolution_m:.1f}m.pkl'
else:
    # fail here rather than with a NameError on `brir_pkl_path` in a later cell
    raise ValueError(f"Unknown method '{method}'; expected 'dfdn' or 'naf'")

# Pickle with the reference NAF dataset (ground-truth BRIRs and inference positions).
input_pkl_path = f'{save_path}/naf_dataset_grid_spacing={grid_resolution_m:.1f}m.pkl'

# NOTE(review): the key is labelled 'diff_dfdn' regardless of `method` - confirm
# that this is intended when evaluating NAF predictions.
cur_key = f'diff_dfdn_grid_spacing={grid_resolution_m:.1f}'
error_edc[cur_key] = np.zeros((num_ori, num_ears))
error_edr[cur_key] = np.zeros((num_ori, num_ears))


# NOTE: unpickling can execute arbitrary code; only load pickles from a trusted source.
with open(input_pkl_path, "rb") as f:
    ref_naf_dataset = NAFDatasetUnpickler(f).load()
infer_pos_list = ref_naf_dataset.infer_receiver_pos

In [None]:
# Load the predicted BRIRs from the cache pickle when it exists; otherwise run
# the directional FDN inference, render BRIRs and write the cache.
if os.path.exists(brir_pkl_path):
    logger.info(f'BRIRs already saved for grid spacing = {grid_resolution_m:.1f}m')
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(brir_pkl_path, "rb") as f:
        brir_dataset = NAFDatasetUnpickler(f).load()
    pred_brirs = brir_dataset.infer_brirs
    # NOTE(review): in this branch `pred_cs_room_data` is a copy of the reference
    # room data with only its receiver positions updated; its `rirs` are not
    # replaced by predictions here. Later cells that read
    # `pred_cs_room_data.rirs` may therefore be looking at reference SRIRs - confirm.
    pred_cs_room_data = deepcopy(room_data)
    pred_cs_room_data.update_receiver_pos(infer_pos_list)
elif method != 'dfdn':
    # The fallback below runs the directional FDN. Without this guard its output
    # would be pickled under the file name reserved for another method.
    raise FileNotFoundError(
        f"No cached BRIRs found at {brir_pkl_path} for method '{method}'; "
        "only 'dfdn' predictions can be generated in this notebook")
else:
    save_pkl_path = f'{out_path}/pred_ambi_rirs_test_pos_only_grid_res={grid_resolution_m:.1f}m/'
    pred_cs_room_data = infer_all_octave_bands_directional_fdn(freqs_list,
                                                               config_dicts,
                                                               save_pkl_path,
                                                               room_data,
                                                               rec_pos_list=infer_pos_list,
                                                               )

    # binaural rendering of the predicted SRIRs for every head orientation
    pred_brirs = convert_srir_to_brir(pred_cs_room_data.rirs, hrtf_reader, head_orientations)

    # reorder the predictions so that they follow the reference position ordering
    logger.info("Ordering position data...")
    correct_pos_order = order_position_matrices(infer_pos_list, pred_cs_room_data.receiver_position)
    pred_brirs = pred_brirs[correct_pos_order, ...]

    # cache ground-truth and predicted BRIRs together for later runs
    dfdn_brir_dataset = NAFDatasetInfer(head_orientations[:, 0],
                                        ref_naf_dataset.num_infer_receivers,
                                        infer_pos_list,
                                        gt_brirs=ref_naf_dataset.infer_brirs,
                                        infer_brirs=pred_brirs,
                                        )
    with open(brir_pkl_path, "wb") as f:
        pickle.dump(dfdn_brir_dataset, f)

### Plot an example BRIR for testing

In [None]:
from slope2noise.utils import schroeder_backward_int
from diff_gfdn.utils import db

# Sample window used for the spot check: starts at the mixing time and stops
# `leave_out_samps` before `trunc_at`.
# NOTE(review): the error-evaluation cell further down ends its window at
# mixing_time_samps + trunc_at - leave_out_samps instead - confirm which end
# index is intended.
time_slice_idx = np.arange(mixing_time_samps, trunc_at-leave_out_samps, dtype=np.int32)

# hard-coded spot-check selection: one inference position, first head orientation
example_pos_idx = 53
example_ori_idx = 0

# both slices keep the last axis, so each has one column per ear
true_brir = ref_naf_dataset.infer_brirs[example_pos_idx, example_ori_idx, time_slice_idx, :]
pred_brir = pred_brirs[example_pos_idx, example_ori_idx, time_slice_idx, :]
true_brir_edc = db(schroeder_backward_int(true_brir, time_axis=-2, normalize=True), is_squared=True)
pred_brir_edc = db(schroeder_backward_int(pred_brir, time_axis=-2, normalize=True), is_squared=True)

example_tag = f'position {example_pos_idx}, orientation {example_ori_idx}'
sample_label = 'Sample (relative to window start)'

fig, ax = plt.subplots()
ax.plot(true_brir)
ax.set(title=f'Ground-truth BRIR ({example_tag})', xlabel=sample_label, ylabel='Amplitude')

fig, ax = plt.subplots()
ax.plot(pred_brir)
ax.set(title=f'Predicted BRIR ({example_tag})', xlabel=sample_label, ylabel='Amplitude')

# overlay of both EDCs; predictions are dashed so they can be told apart
fig, ax = plt.subplots()
true_lines = ax.plot(true_brir_edc)
pred_lines = ax.plot(pred_brir_edc, linestyle='--')
ax.legend(true_lines + pred_lines,
          [f'ground truth, ear {i}' for i in range(len(true_lines))]
          + [f'predicted, ear {i}' for i in range(len(pred_lines))])
ax.set(title=f'Normalised EDC ({example_tag})', xlabel=sample_label, ylabel='EDC (dB)')
plt.show()

In [None]:
def _find_receiver_index(receiver_positions, target_pos, decimals=2):
    """Locate the first receiver whose rounded coordinates equal `target_pos`.

    Args:
        receiver_positions: 2-D array with one receiver position per row.
        target_pos: coordinates to look for, one value per column.
        decimals: number of decimals the positions are rounded to before comparing.

    Returns:
        The first row of `np.argwhere` over the matching mask, i.e. a
        one-element index array.

    Raises:
        ValueError: if no receiver matches `target_pos`.
    """
    matches = np.argwhere(
        np.all(np.round(receiver_positions, decimals) == target_pos, axis=1))
    if matches.shape[0] == 0:
        raise ValueError(f"No receiver found at position {target_pos}")
    return matches[0]


pos_to_investigate = [6.4, 3.2, 1.5]

# index of that receiver in the reference data and in the predicted data
rec_pos_idx = _find_receiver_index(room_data.receiver_position, pos_to_investigate)
rec_pos_idx_pred = _find_receiver_index(pred_cs_room_data.receiver_position, pos_to_investigate)

# `time_slice_idx` comes from whichever earlier cell defined it last.
# Assuming `rirs` is a NumPy array indexed as (position, channel, time): the two
# index arrays are separated by a slice, so the broadcast index axis comes first
# and the result is (time, channel), which is why the EDC uses time_axis=0.
true_srir = room_data.rirs[rec_pos_idx, :, time_slice_idx]
pred_srir = pred_cs_room_data.rirs[rec_pos_idx_pred, :, time_slice_idx]
true_srir_edc = db(schroeder_backward_int(true_srir.copy(), time_axis=0, normalize=True), is_squared=True)
pred_srir_edc = db(schroeder_backward_int(pred_srir.copy(), time_axis=0, normalize=True), is_squared=True)

In [None]:
# Overlay the reference and predicted EDC of a single SRIR channel.
channel_idx = 6
true_channel_edc = true_srir_edc[:, channel_idx]
pred_channel_edc = pred_srir_edc[:, channel_idx]

plt.figure()
plt.plot(true_channel_edc)
plt.plot(pred_channel_edc)

### Calculate overall EDC and EDR errors

In [None]:
# Compute and plot the EDC and EDR errors for every head orientation and ear.
# NOTE(review): this window ends at mixing_time_samps + trunc_at - leave_out_samps,
# whereas the example-plot cell earlier ends at trunc_at - leave_out_samps -
# confirm which end index is intended.
time_slice_idx = np.arange(mixing_time_samps, mixing_time_samps + trunc_at-leave_out_samps, dtype=np.int32)

# the save paths handed to the plotting helpers point into this folder
fig_path.mkdir(parents=True, exist_ok=True)

for ori in range(num_ori):
    ori_deg = int(head_orientations[ori, 0])
    for ear in range(num_ears):
        # reference data restricted to the inference positions, carrying the
        # ground-truth BRIRs of this orientation / ear inside the window
        cur_room_data = deepcopy(room_data)
        cur_room_data.update_receiver_pos(infer_pos_list)
        cur_room_data.update_rirs(ref_naf_dataset.infer_brirs[:, ori, time_slice_idx, ear])
        cur_brirs = np.squeeze(pred_brirs[:, ori, time_slice_idx, ear])

        save_path_edc = f'{fig_path}/edc_error_{method}_brir_ori={ori_deg}_ear={ear}_grid_spacing={grid_resolution_m:.1f}m.png'
        save_path_edr = f'{fig_path}/edr_error_{method}_brir_ori={ori_deg}_ear={ear}_grid_spacing={grid_resolution_m:.1f}m.png'

        # NOTE: it is very important to set norm_edc=True.
        # The MLPs were trained on EDCs created from the common-slope parameters
        # (this reduced the amount of data to be loaded during training and sped
        # it up considerably). The common-slope amplitudes have been normalised
        # in the downloaded dataset, so the scale of the predicted and true EDCs
        # will not match unless the EDCs are normalised.
        err_edc = plot_edc_error_in_space(cur_room_data, cur_brirs, infer_pos_list, scatter=True,
                                          pos_sorted=True, save_path=save_path_edc, norm_edc=True)
        error_edc[cur_key][ori, ear] = err_edc

        err_edr = plot_edr_error_in_space(cur_room_data, cur_brirs, infer_pos_list,
                                          scatter=True, pos_sorted=True, save_path=save_path_edr)
        error_edr[cur_key][ori, ear] = err_edr