In [None]:
import torch
import numpy as np
import pandas as pd
import spaudiopy as sp
from pathlib import Path
import pickle
import soundfile as sf
from copy import deepcopy
from loguru import logger
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from numpy.typing import ArrayLike, NDArray
from typing import List
from scipy.fft import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm


import os
# NOTE(review): this chdir is relative to the current working directory, so it
# is not idempotent -- re-running this cell moves up two more levels each time.
# The relative paths resolved in later cells depend on the resulting cwd;
# restart the kernel before re-running this cell.
os.chdir('../..')
from spatial_sampling.dataloader import SpatialThreeRoomDataset, SpatialRoomDataset, load_dataset
from src.sofa_parser import HRIRSOFAReader, SRIRSOFAWriter, convert_srir_to_brir
from src.sound_examples import binaural_dynamic_rendering
from src.convert_mat_to_pkl_ambi import process_srirs
from src.dataclass import NAFDatasetUnpickler, NAFDatasetTrain, NAFDatasetInfer

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.inference import infer_all_octave_bands_directional_fdn
from diff_gfdn.plot import plot_edc_error_in_space, plot_edr_error_in_space, order_position_matrices
from diff_gfdn.utils import ms_to_samps, db
from diff_gfdn.config.config_loader import load_and_validate_config

from slope2noise.utils import schroeder_backward_int
from slope2noise.generate import shaped_wgn

In [None]:
# --- inputs -----------------------------------------------------------------
# Folder holding the dataset pickles (also where the BRIR pickles are read/written).
save_path = Path('resources/Georg_3room_FDTD').resolve()
# Pickle passed to SpatialThreeRoomDataset.
room_data_pkl_path = Path('resources/Georg_3room_FDTD/srirs_spatial.pkl').resolve()
# Folder containing the per-band YAML configs.
config_path = Path('data/config/directional_fdn/').resolve()

# --- outputs ----------------------------------------------------------------
# Kept as a plain string (with trailing slash) because it is used in f-strings.
out_path = 'output/directional_fdn/'
fig_path = Path('figures/directional_fdn').resolve()

In [None]:
# Octave-band centre frequencies (Hz) for which a trained config exists.
freqs_list = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
grid_resolution_m = 0.6

# One validated config per band, in the same order as freqs_list.
config_dicts = [
    load_and_validate_config(
        str(config_path)
        + f'/treble_data_grid_training_{band_hz}Hz_directional_fdn_grid_res={grid_resolution_m:.1f}m.yml',
        DiffGFDNConfig,
    )
    for band_hz in freqs_list
]

# Original dataset and the HRTF set used for binaural rendering.
hrtf_path = Path('resources/HRTF/48kHz/KEMAR_Knowl_EarSim_SmallEars_FreeFieldComp_48kHz.sofa')
room_data = SpatialThreeRoomDataset(room_data_pkl_path)
hrtf_reader = HRIRSOFAReader(hrtf_path)

# Bring the HRIRs to the dataset's sample rate if the two differ.
if hrtf_reader.fs != room_data.sample_rate:
    logger.info(f"Resampling HRTFs to {room_data.sample_rate:.0f} Hz")
    hrtf_reader.resample_hrirs(room_data.sample_rate)

### Get BRIRs from trained DFDNs / NAFs

In [None]:
# these are the only directions in NAF
# Column 0 holds 0/90/180/270 (presumably azimuth in degrees -- TODO confirm);
# column 1 is left at zero.
head_orientations = np.zeros((4, 2))
head_orientations[:, 0] = np.array([0, 90, 180, 270])
num_ori = head_orientations.shape[0]
num_ears = 2
error_edc = {}
error_edr = {}

# Analysis window boundaries, converted from milliseconds to samples.
leave_out_samps = ms_to_samps(5, room_data.sample_rate)
mixing_time_samps = ms_to_samps(room_data.mixing_time_ms, room_data.sample_rate)
trunc_at = ms_to_samps(2000, room_data.sample_rate)

# Which model's BRIRs to evaluate: 'dfdn' or 'naf'.
method = 'dfdn'
logger.info(f"Creating BRIRs for spacing = {grid_resolution_m:.1f}m")
if method == 'dfdn':
    brir_pkl_path = f'{save_path}/diff_dfdn_pred_brirs_test_pos_only_dir_gains_no_gain_comp_grid_spacing={grid_resolution_m:.1f}m.pkl'
elif method == 'naf':
    brir_pkl_path = f'{save_path}/naf_dataset_infer_grid_spacing={grid_resolution_m:.1f}m.pkl'
else:
    # Fail here rather than with a NameError on brir_pkl_path in a later cell.
    raise ValueError(f"Unknown method '{method}'; expected 'dfdn' or 'naf'")

# ref_pkl_path = f'{save_path}/cs_pred_brirs_test_pos_only_grid_spacing={grid_resolution_m:.1f}m.pkl'
ref_pkl_path = f'{save_path}/naf_dataset_grid_spacing={grid_resolution_m:.1f}m.pkl'
# NOTE(review): the key is fixed to 'diff_dfdn...' even when method == 'naf'.
cur_key = f'diff_dfdn_grid_spacing={grid_resolution_m:.1f}'
error_edc[cur_key] = np.zeros((num_ori, num_ears))

# Reference dataset with the ground-truth BRIRs and the inference positions.
# Unpickling can execute code stored in the file -- only load trusted pickles.
with open(ref_pkl_path, "rb") as f:
    ref_naf_dataset = NAFDatasetUnpickler(f).load()
infer_pos_list = ref_naf_dataset.infer_receiver_pos

In [None]:
# Path handed to the inference routine for the predicted ambisonic RIRs.
save_pkl_path = f'{out_path}/pred_ambi_rirs_test_pos_only_no_gain_comp_grid_res={grid_resolution_m:.1f}m/'

# From here on the band centres are taken from the dataset instead of the
# hard-coded list used to load config_dicts.
# NOTE(review): confirm both lists have the same bands in the same order.
freqs_list = room_data.band_centre_hz

# Run the trained directional FDNs at the reference inference positions.
pred_cs_room_data = infer_all_octave_bands_directional_fdn(
    freqs_list,
    config_dicts,
    save_pkl_path,
    room_data,
    rec_pos_list=ref_naf_dataset.infer_receiver_pos,
)

In [None]:
# Reuse the BRIR pickle when it exists; otherwise render and save it.
brirs_cached = os.path.exists(brir_pkl_path)

if not brirs_cached:
    # Render the predicted SRIRs binaurally for every head orientation.
    pred_brirs = convert_srir_to_brir(pred_cs_room_data.rirs, hrtf_reader, head_orientations)

    # ensure the position ordering is correct
    logger.info("Ordering position data...")
    correct_pos_order = order_position_matrices(
        ref_naf_dataset.infer_receiver_pos,
        pred_cs_room_data.receiver_position,
    )
    pred_brirs = pred_brirs[correct_pos_order, ...]

    # Bundle the prediction with the ground truth and save the dataset.
    dfdn_brir_dataset = NAFDatasetInfer(
        head_orientations[:, 0],
        ref_naf_dataset.num_infer_receivers,
        ref_naf_dataset.infer_receiver_pos,
        gt_brirs=ref_naf_dataset.infer_brirs,
        infer_brirs=pred_brirs,
    )
    with open(brir_pkl_path, "wb") as pkl_file:
        pickle.dump(dfdn_brir_dataset, pkl_file)
else:
    logger.info(f'BRIRs already saved for grid spacing = {grid_resolution_m:.1f}m')
    with open(brir_pkl_path, "rb") as pkl_file:
        brir_dataset = NAFDatasetUnpickler(pkl_file).load()
    pred_brirs = brir_dataset.infer_brirs

### Plot an example BRIR for testing

In [None]:
# Sample indices analysed here: from the mixing time up to
# (trunc_at - leave_out_samps).
time_slice_idx = np.arange(mixing_time_samps, trunc_at - leave_out_samps, dtype=np.int32)

# Receiver index used for this spot check.
example_pos_idx = 22

for ori in range(num_ori):
    # Indexing is [position, orientation, time, ear]; both ears are kept.
    true_brir = ref_naf_dataset.infer_brirs[example_pos_idx, ori, time_slice_idx, :]
    pred_brir = pred_brirs[example_pos_idx, ori, time_slice_idx, :]

    # Normalised EDCs in dB; time is the second-to-last axis after slicing.
    true_brir_edc = db(schroeder_backward_int(true_brir, time_axis=-2, normalize=True), is_squared=True)
    pred_brir_edc = db(schroeder_backward_int(pred_brir, time_axis=-2, normalize=True), is_squared=True)

    # One figure per head orientation; each plot call draws one line per ear,
    # which is why the legend lists the reference pair first.
    fig, ax = plt.subplots()
    ax.plot(true_brir_edc)
    ax.plot(pred_brir_edc)
    ax.set_xlabel('Time (samples)')
    ax.set_ylabel('Magnitude (dB)')
    ax.set_title(f'EDC for orientation = {head_orientations[ori, 0]:.0f}')
    ax.legend(['Ref L', 'Ref R', 'Pred L', 'Pred R'])

# sf.write(Path(f'{out_path}sound_examples/true_brir_grid_res={grid_resolution_m}m.wav').resolve(), true_brir, room_data.sample_rate)
# sf.write(Path(f'{out_path}sound_examples/pred_brir_directional_fdn_grid_res={grid_resolution_m}m.wav').resolve(), pred_brir, room_data.sample_rate)

### Plot the ground truth, CS model predicted and DiffGFDN predicted SRIR EDCs at a particular position

In [None]:
def get_srir_envelopes_from_cs_model(room_data: SpatialRoomDataset, amps: NDArray):
    """
    For a particular position, generate SRIR EDCs predicted by the common slopes model.

    The second output of ``shaped_wgn`` (built from the room's common decay times
    and the given amplitudes) is projected with the synthesis matrix of a
    spherical filterbank, and the result is backward integrated without
    normalisation.

    Args:
        room_data: dataset providing ``common_decay_times``, ``num_directions``,
            ``sample_rate``, ``rir_length``, ``band_centre_hz``, ``ambi_order``
            and ``sph_directions``.
        amps: common-slope amplitudes for one receiver position. They are
            transposed with axes (1, 0, -1), so a 3-D array is required
            (exact axis meaning not visible here -- TODO confirm with caller).

    Returns:
        The output of ``schroeder_backward_int`` applied to the synthesised
        2-D SRIR (presumably channels x time -- TODO confirm).
    """
    # Repeat the decay times once per direction, then reorder the axes to the
    # layout passed to shaped_wgn.
    tvals_exp = np.repeat(room_data.common_decay_times,
                          room_data.num_directions,
                          axis=1).transpose(1, -1, 0)
    amps_reshape = amps.transpose(1, 0, -1)

    # Only the second return value of shaped_wgn is used.
    # NOTE(review): presumably stochastic (no seed is set) -- confirm.
    _, drir = shaped_wgn(tvals_exp,
                         amps_reshape,
                         room_data.sample_rate,
                         room_data.rir_length,
                         f_bands=room_data.band_centre_hz)

    modal_weights = sp.sph.cardioid_modal_weights(room_data.ambi_order)
    # sph_directions row 0 is passed as azimuth; row 1 is converted with
    # (pi/2 - value) before being passed as the zenith/colatitude argument.
    # design_sph_filterbank requires `sh_type`; the original call omitted it.
    # NOTE(review): 'real' is assumed here -- confirm the SH convention.
    _unused_analysis_matrix, sph_syn_matrix = sp.sph.design_sph_filterbank(
        room_data.ambi_order,
        room_data.sph_directions[0, :],
        np.pi / 2 - room_data.sph_directions[1, :],
        modal_weights,
        sh_type='real',
        mode='energy')

    # Project drir onto the output channels with the synthesis matrix.
    srir = np.einsum('kn, nt -> kt', sph_syn_matrix, drir)
    srir_envelopes = schroeder_backward_int(srir, normalize=False)
    return srir_envelopes

In [None]:
pos_to_investigate = [6.4, 3.8, 1.5]
num_channels = (room_data.ambi_order + 1)**2

# Index of the receiver position in the full dataset and in the predicted
# dataset. These differ in general because the predicted dataset was
# inferred at ref_naf_dataset.infer_receiver_pos only.
rec_pos_idx = np.argwhere(
    np.all(np.round(room_data.receiver_position, 2) == pos_to_investigate, axis=1))[0]
rec_pos_idx_pred = np.argwhere(
    np.all(np.round(pred_cs_room_data.receiver_position, 2) == pos_to_investigate, axis=1))[0]

# Amplitudes come from room_data, so they must be looked up with the index
# into room_data (the original used the index into the predicted dataset).
cur_amps = room_data.amplitudes[rec_pos_idx].squeeze()

# NOTE(review): time_slice_idx comes from an earlier cell, and the EDC-error
# cell further down redefines it with a different end point, so this cell's
# result depends on execution order.
# The two index arrays are separated by a slice, so the indexed (time)
# dimension comes first: true_srir is (time, channel) after the squeeze.
true_srir = room_data.rirs[rec_pos_idx, :, time_slice_idx].squeeze()
pred_drir = pred_cs_room_data.rirs[rec_pos_idx_pred, :, mixing_time_samps:]
# Use the same attribute as get_srir_envelopes_from_cs_model
# (`sph_directions`); the original read `sph_direction`.
pred_srir = process_srirs(pred_drir.transpose(1, -1, 0), room_data.ambi_order,
                          np.rad2deg(room_data.sph_directions), mode='synthesis')
pred_srir = pred_srir.transpose(-1, 0, 1).squeeze()

# Un-normalised EDCs in dB. The ground truth has time on axis 0, the other
# two on the last axis.
true_srir_edc = db(schroeder_backward_int(true_srir.copy(), time_axis=0, normalize=False), is_squared=True)
pred_srir_edc = db(schroeder_backward_int(pred_srir.copy(), time_axis=-1, normalize=False), is_squared=True)
ref_srir_cs_edc = db(
    get_srir_envelopes_from_cs_model(room_data, cur_amps.copy())[..., :trunc_at - leave_out_samps],
    is_squared=True)

# One subplot per channel.
fig, ax = plt.subplots(num_channels, 1, figsize=(8, 20), sharey=True)  # rows, cols
for j in range(num_channels):
    ax[j].plot(true_srir_edc[:, j], label='GT')
    ax[j].plot(ref_srir_cs_edc[j, :], label='CS model')
    ax[j].plot(pred_srir_edc[j, :], label='DiffGFDN')
    ax[j].set_title(f'Channel = {j+1}')

ax[-1].set_xlabel('Time (samples)')
ax[-1].legend()
fig.text(0.04, 0.5, 'EDC (dB)', va='center', rotation='vertical')
# increase space between subplots
fig.subplots_adjust(hspace=1.5)  # increase vertical spacing

### Calculate overall EDC error

In [None]:
# now plot the EDC error per orientation and ear
# NOTE(review): this window ends at mixing_time_samps + trunc_at - leave_out_samps,
# whereas the example-BRIR cell ends at trunc_at - leave_out_samps. Confirm
# which is intended; re-running earlier cells after this one changes their result.
time_slice_idx = np.arange(mixing_time_samps, mixing_time_samps + trunc_at-leave_out_samps, dtype=np.int32)

# Iterate over the same (num_ori, num_ears) grid that sizes error_edc[cur_key],
# so the loops and the result array cannot drift apart.
for ori in range(num_ori):
    for ear in range(num_ears):
        # Work on a copy so room_data itself is left untouched: the reference
        # BRIRs for this orientation/ear replace the RIRs at the inference positions.
        cur_room_data = deepcopy(room_data)
        cur_room_data.update_receiver_pos(infer_pos_list)
        cur_room_data.update_rirs(ref_naf_dataset.infer_brirs[:, ori, time_slice_idx, ear])
        cur_brirs = np.squeeze(pred_brirs[:, ori, time_slice_idx, ear])

        save_path_edc = f'{fig_path}/edc_error_{method}_brir_ori={int(head_orientations[ori, 0])}_ear={ear}_grid_spacing={grid_resolution_m:.1f}m.png'
        # Not used within this cell.
        save_path_edr = f'{fig_path}/edr_error_{method}_brir_ori={int(head_orientations[ori, 0])}_ear={ear}_grid_spacing={grid_resolution_m:.1f}m.png'

        # NOTE: it is very important to set norm_edc=True.
        # The MLPs were trained on EDCs created from the common slope parameters
        # (this reduced the amount of data loaded during training and sped it up
        # considerably). The common slope amplitudes are normalised in the
        # downloaded dataset, so the scales of the predicted and true EDCs will
        # not match unless the EDCs are normalised.
        err_edc = plot_edc_error_in_space(cur_room_data, cur_brirs, infer_pos_list, scatter=True,
                                          pos_sorted=True, save_path=save_path_edc, norm_edc=True)
        error_edc[cur_key][ori, ear] = err_edc

In [None]:
# Summarise the EDC errors (rows: orientations, columns: ears) to 3 decimals.
rounded_edc_errors = np.round(error_edc[cur_key], 3)
logger.info(f'Overall EDC errors for grid spacing = {grid_resolution_m:.1f}m are {rounded_edc_errors}')