In [None]:
import torch
import numpy as np
from pathlib import Path
import pickle
from copy import deepcopy
from loguru import logger
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from numpy.typing import ArrayLike, NDArray
from typing import List
from scipy.fft import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm


import os
# Work from two levels above the kernel's starting directory (presumably the repository
# root) so that the project imports and relative resource paths below resolve.
# The starting directory is remembered on the first execution, so re-running this cell
# keeps the same working directory instead of climbing two more levels each time.
_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', Path.cwd())
os.chdir(_NOTEBOOK_DIR.parents[1])
from spatial_sampling.dataloader import parse_three_room_data, SpatialRoomDataset, load_dataset
from spatial_sampling.config import SpatialSamplingConfig
from src.sofa_parser import HRIRSOFAReader, SRIRSOFAWriter, convert_srir_to_brir
from src.sound_examples import binaural_dynamic_rendering
from spatial_sampling.inference import get_ambisonic_rirs

from diff_gfdn.plot import plot_edc_error_in_space, plot_edr_error_in_space
from diff_gfdn.utils import ms_to_samps

from diff_gfdn.config.config_loader import load_and_validate_config
from src.dataclass import NAFDatasetUnpickler, NAFDatasetTrain, NAFDatasetInfer

In [None]:
# Input/output locations. Relative paths are interpreted against the current working
# directory set in the import cell; everything except out_path is made absolute.
out_path = 'output/spatial_sampling/sound_examples'
save_path = Path('resources/Georg_3room_FDTD').resolve()
room_data_pkl_path = (save_path / 'srirs_spatial.pkl').resolve()
config_path = Path('data/config/spatial_sampling/').resolve()
fig_path = Path('figures/spatial_sampling').resolve()

In [None]:
# Experiment configuration for the spatial-sampling test run.
config_file = f'{config_path}/treble_data_grid_training_1000Hz_directional_spatial_sampling_test.yml'
config_dict = load_and_validate_config(config_file, SpatialSamplingConfig)
hrtf_path = Path('resources/HRTF/48kHz/KEMAR_Knowl_EarSim_SmallEars_FreeFieldComp_48kHz.sofa')

# Reference (ground-truth) dataset parsed from the SRIR pickle.
room_data = parse_three_room_data(room_data_pkl_path)

# HRIRs used later for the SRIR -> BRIR conversion; bring them to the dataset's sample rate.
hrtf_reader = HRIRSOFAReader(hrtf_path)
dataset_fs = room_data.sample_rate
needs_resampling = hrtf_reader.fs != dataset_fs
if needs_resampling:
    logger.info(f"Resampling HRTFs to {dataset_fs:.0f} Hz")
    hrtf_reader.resample_hrirs(dataset_fs)

### Get BRIRs from trained MLPs / NAFs

In [None]:
### get train dataset for different grid spacings
# Candidate grid spacings in metres, from num_grid_spacing * base spacing down to the base spacing.
grid_resolution_m = np.arange(config_dict.num_grid_spacing, 0,
                              -1) * room_data.grid_spacing_m
# One row per head orientation; column 0 is the azimuth (logged as such below),
# column 1 is left at zero (presumably elevation -- TODO confirm).
head_orientations = np.zeros((4, 2))
# these are the only directions in NAF
head_orientations[:, 0] = np.array([0, 90, 180, 270])
num_ori = head_orientations.shape[0]
num_ears = 2
error_edc = {}
error_edr = {}
# Time-window parameters (in samples) used by the error cells below.
leave_out_samps = ms_to_samps(5, room_data.sample_rate)
mixing_time_samps = ms_to_samps(room_data.mixing_time_ms, room_data.sample_rate)
trunc_at = ms_to_samps(2000, room_data.sample_rate)
# Index into grid_resolution_m of the spacing evaluated in this run
# (originally intended to loop over range(config_dict.num_grid_spacing - 1)).
k = 0

# Which predictions to evaluate: 'proposed' (MLP) or 'naf'.
method = 'proposed'
logger.info(f"Creating BRIRs for spacing = {grid_resolution_m[k]:.1f}m")
if method == 'proposed':
    brir_pkl_path = f'{save_path}/mlp_pred_brirs_test_pos_only_grid_spacing={grid_resolution_m[k]:.1f}m.pkl'
elif method == 'naf':
    brir_pkl_path = f'{save_path}/naf_dataset_infer_grid_spacing={grid_resolution_m[k]:.1f}m.pkl'
else:
    # Fail here rather than with a NameError on brir_pkl_path in a later cell.
    raise ValueError(f"Unknown method {method!r}; expected 'proposed' or 'naf'")
input_pkl_path = f'{save_path}/naf_dataset_grid_spacing={grid_resolution_m[k]:.1f}m.pkl'
cur_key = f'mlp_grid_spacing={grid_resolution_m[k]:.1f}'
# Broadband errors, indexed [orientation, ear].
error_edc[cur_key] = np.zeros((num_ori, num_ears))
error_edr[cur_key] = np.zeros((num_ori, num_ears))


# NOTE: unpickling executes arbitrary code -- only load locally generated, trusted files.
with open(input_pkl_path, "rb") as f:
    ref_naf_dataset = NAFDatasetUnpickler(f).load()
infer_pos_list = ref_naf_dataset.infer_receiver_pos

In [None]:
if os.path.exists(brir_pkl_path):
    # Reuse cached predicted BRIRs instead of re-running inference.
    logger.info(f'BRIRs already saved for grid spacing = {grid_resolution_m[k]:.1f}m')
    with open(brir_pkl_path, "rb") as f:
        brir_dataset = NAFDatasetUnpickler(f).load()
    pred_brirs = brir_dataset.infer_brirs
    pred_cs_room_data = deepcopy(room_data)
    pred_cs_room_data.update_receiver_pos(infer_pos_list)
elif method != 'proposed':
    # Only the proposed MLP predictions can be generated in this notebook. Without this
    # guard, MLP predictions would be pickled under the NAF file name and silently reused
    # as "NAF" results on the next run.
    raise FileNotFoundError(f'No cached {method} BRIRs found at {brir_pkl_path}')
else:
    # Predict SRIRs at the inference positions with the trained MLPs, binauralise them for
    # every head orientation and cache the result next to the ground truth.
    pred_cs_room_data = get_ambisonic_rirs(infer_pos_list, room_data,
                                           use_trained_model=True, config_path=config_path,
                                           grid_resolution_m=grid_resolution_m[k])

    pred_brirs = convert_srir_to_brir(pred_cs_room_data.rirs, hrtf_reader, head_orientations)
    mlp_brir_dataset = NAFDatasetInfer(head_orientations[:, 0],
                                       ref_naf_dataset.num_infer_receivers,
                                       ref_naf_dataset.infer_receiver_pos,
                                       gt_brirs=ref_naf_dataset.infer_brirs,
                                       infer_brirs=pred_brirs,
                                       )
    with open(brir_pkl_path, "wb") as f:
        pickle.dump(mlp_brir_dataset, f)

### Calculate overall EDC and EDR errors

In [None]:
# now plot the EDC and EDR errors per orientation and ear
# The reference BRIRs are read from the mixing time onwards, while the predictions are read
# from sample 0, both with the same length (presumably because the predictions start at the
# mixing time -- TODO confirm).
num_eval_samps = trunc_at - leave_out_samps
time_slice_idx = np.arange(mixing_time_samps, mixing_time_samps + num_eval_samps, dtype=np.int32)

# The receiver positions are identical for every orientation/ear, so copy the dataset once
# and only swap its RIRs inside the loop (instead of a full deepcopy per iteration).
cur_room_data = deepcopy(room_data)
cur_room_data.update_receiver_pos(infer_pos_list)

for ori in range(num_ori):
    ori_deg = int(head_orientations[ori, 0])
    for ear in range(num_ears):
        cur_room_data.update_rirs(ref_naf_dataset.infer_brirs[:, ori, time_slice_idx, ear])
        # Integer indexing already yields (receivers, time), assuming pred_brirs is
        # (receivers, orientations, time, ears). No np.squeeze: it would drop the
        # receiver axis when there is only one receiver.
        cur_brirs = pred_brirs[:, ori, :num_eval_samps, ear]

        save_path_edc = f'{fig_path}/edc_error_{method}_brir_ori={ori_deg}_ear={ear}_grid_spacing={grid_resolution_m[k]:.1f}m.png'
        save_path_edr = f'{fig_path}/edr_error_{method}_brir_ori={ori_deg}_ear={ear}_grid_spacing={grid_resolution_m[k]:.1f}m.png'

        ## NOTE- IT IS VERY IMPORTANT TO SET NORM_EDC = TRUE.
        # The MLPs were trained on EDCs created from the common slope parameters (this reduced the
        # amount of data loaded during training and sped it up considerably). The common slope
        # amplitudes have been normalised in the downloaded dataset, so the scales of the
        # predicted and true EDCs won't match unless the EDCs are normalised.
        err_edc = plot_edc_error_in_space(cur_room_data, cur_brirs, infer_pos_list, scatter=True,
                                          pos_sorted=True, save_path=save_path_edc, norm_edc=True)
        error_edc[cur_key][ori, ear] = err_edc
        err_edr = plot_edr_error_in_space(cur_room_data, cur_brirs, infer_pos_list,
                                          scatter=True, pos_sorted=True, save_path=save_path_edr)
        error_edr[cur_key][ori, ear] = err_edr

### Calculate octave band EDC errors

In [None]:
from slope2noise.utils import octave_filtering

# Octave-band EDC errors per orientation and ear.
f_bands = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
num_bands = len(f_bands)
# Indexed [band, orientation, ear].
band_error_edc = {cur_key: np.zeros((num_bands, num_ori, num_ears))}

# Same time alignment as the broadband evaluation: the reference is read from the mixing
# time onwards, the prediction from sample 0, both with the same number of samples.
num_eval_samps = trunc_at - leave_out_samps
time_slice_idx = np.arange(mixing_time_samps, mixing_time_samps + num_eval_samps, dtype=np.int32)

# Receiver positions do not change inside the loops: copy the dataset once, swap RIRs per band.
cur_room_data = deepcopy(room_data)
cur_room_data.update_receiver_pos(infer_pos_list)

for ori in range(num_ori):
    ori_deg = int(head_orientations[ori, 0])
    for ear in range(num_ears):
        cur_ref_brirs = ref_naf_dataset.infer_brirs[:, ori, time_slice_idx, ear]
        # No np.squeeze: the slice is already (receivers, time), assuming pred_brirs is
        # (receivers, orientations, time, ears); squeezing would break the single-receiver case.
        cur_pred_brirs = pred_brirs[:, ori, :num_eval_samps, ear]

        # Assumes octave_filtering appends the band axis last, matching the [..., b_idx] indexing below.
        filtered_ref_brirs = octave_filtering(cur_ref_brirs, room_data.sample_rate, f_bands)
        filtered_pred_brirs = octave_filtering(cur_pred_brirs, room_data.sample_rate, f_bands)

        for b_idx, f_band in enumerate(f_bands):
            logger.info(f'Plotting EDC error for freq = {f_band}Hz, azimuth = {head_orientations[ori, 0]:.1f}, ear = {ear}')

            cur_room_data.update_rirs(filtered_ref_brirs[..., b_idx])
            save_path_edc = f'{fig_path}/edc_error_{method}_brir_ori={ori_deg}_ear={ear}_freq={f_band}Hz_grid_spacing={grid_resolution_m[k]:.1f}m.png'

            ## NOTE- IT IS VERY IMPORTANT TO SET NORM_EDC = TRUE.
            # The MLPs were trained on EDCs created from the common slope parameters (this reduced
            # the amount of data loaded during training and sped it up considerably). The common
            # slope amplitudes have been normalised in the downloaded dataset, so the scales of the
            # predicted and true EDCs won't match unless the EDCs are normalised.
            err_edc = plot_edc_error_in_space(cur_room_data, filtered_pred_brirs[..., b_idx], infer_pos_list, scatter=True,
                                              pos_sorted=True, save_path=save_path_edc, norm_edc=True)
            band_error_edc[cur_key][b_idx, ori, ear] = err_edc

### Print EDC errors

In [None]:
# Broadband EDC errors indexed [orientation, ear], rounded for display.
edc_error_table = np.round(error_edc[cur_key], 3)
logger.info(f'Overall EDC errors for grid spacing = {grid_resolution_m[k]:.1f}m are {edc_error_table}')

In [None]:
# Single scalar: broadband EDC error averaged over all orientations and ears.
overall_mean_edc_error = np.round(np.mean(error_edc[cur_key]), 3)
logger.info(f'Overall mean EDC error for grid spacing = {grid_resolution_m[k]:.1f}m are {overall_mean_edc_error}')

In [None]:
# band_error_edc[cur_key] is indexed [band, orientation, ear]; average over head orientations
# so the printed table has one row per octave band and one column per ear.
band_edc_error_mean = np.round(np.mean(band_error_edc[cur_key], axis=1), 3)
logger.info(f'Bandwise BRIR EDC errors (mean over head orientations; rows = bands {f_bands} Hz, columns = ears) '
            f'for grid spacing = {grid_resolution_m[k]:.1f}m are \n {band_edc_error_mean}')

### Plot for sanity check

In [None]:
from diff_gfdn.utils import db
from slope2noise.utils import schroeder_backward_int

# Explicit selection instead of silently reusing the loop variable `ori` left over from the
# cells above (which made this cell depend on execution history).
sanity_rec_idx = 52             # inference receiver to inspect (arbitrary choice)
sanity_ori_idx = num_ori - 1    # last head orientation, the one the loops above ended on
num_eval_samps = trunc_at - leave_out_samps

# Same alignment as the error cells: reference from the mixing time, prediction from sample 0.
# The last axis holds both ears; time is the second-to-last axis.
edc_true = schroeder_backward_int(ref_naf_dataset.infer_brirs[sanity_rec_idx, sanity_ori_idx, time_slice_idx, :],
                                  time_axis=-2, normalize=True)
edc_pred = schroeder_backward_int(pred_brirs[sanity_rec_idx, sanity_ori_idx, :num_eval_samps, :],
                                  time_axis=-2, normalize=True)

fig, ax = plt.subplots()
ax.plot(db(edc_true, is_squared=True))
ax.plot(db(edc_pred, is_squared=True))
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Magnitude (dB)')
ax.set_title(f'EDC, receiver {sanity_rec_idx}, azimuth {int(head_orientations[sanity_ori_idx, 0])}')
ax.legend(['Ref L', 'Ref R', 'Pred L', 'Pred R'])
plt.show()
