In [None]:
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from numpy.typing import ArrayLike, NDArray
from scipy.fftpack import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm


import os
os.chdir('..')
from spatial_sampling.dataloader import parse_room_data, SpatialRoomDataset, load_dataset
from spatial_sampling.config import SpatialSamplingConfig
from src.sofa_parser import HRIRSOFAReader, SRIRSOFAWriter
from src.sound_examples import binaural_dynamic_rendering
from src.run_model import load_and_validate_config

### Notebook to convert the coupled-room dataset into a NAF-compatible format

NAF takes as input BRIRs at spatial locations in the room for the head orientations (yaw) 0, 90, 180 and 270 degrees. To train NAF on different subsets of receivers, as we do in the WASPAA paper, we create separate training and inference datasets containing the SRIRs at different receiver locations, and then convert these SRIRs to BRIRs for the four orientations. Finally, we save each dataset to a pickle file.

In [None]:
# Locations used by this notebook. Relative paths are resolved against the
# current working directory (changed to the parent directory in the first cell).
room_data_pkl_path = (Path('resources') / 'Georg_3room_FDTD' / 'srirs_spatial.pkl').resolve()
save_path = Path('resources', 'Georg_3room_FDTD').resolve()
config_path = Path('data', 'config', 'spatial_sampling').resolve()
out_path = 'output/spatial_sampling/sound_examples'

In [None]:
@dataclass
class NAFDataset:
    """Training / inference BRIRs and receiver positions in the layout used for NAF.

    The inference fields may be ``None``: the conversion cell leaves them unset
    when the requested grid spacing equals the room dataset's native spacing.
    """
    # number of receivers in the training set
    num_train_receivers: int
    # number of receivers in the inference set, or None when there is none
    num_infer_receivers: Optional[int]
    # training receiver positions, shape num_train_receivers x 3
    train_receiver_pos: NDArray
    # inference receiver positions, shape num_infer_receivers x 3, or None
    infer_receiver_pos: Optional[NDArray]
    # shape num_train_receivers x num_orientation x num_time_samples x num_ears
    train_brirs: NDArray
    # shape num_infer_receivers x num_orientation x num_time_samples x num_ears, or None
    infer_brirs: Optional[NDArray]
    # head yaw angles the BRIRs were rendered for, one per orientation (length 4 for NAF)
    orientation: ArrayLike

In [None]:
def find_rec_idx_in_room_dataset(room_data : SpatialRoomDataset, rec_pos_list: NDArray) -> NDArray:
    """Indices of the receivers in the dataset closest to the given positions.

    Args:
        room_data (SpatialRoomDataset): dataset whose ``receiver_position``
            attribute holds one receiver position per row
        rec_pos_list (NDArray): query positions, one per row; a single
            position (1-D array) is treated as one query
    Returns:
        NDArray: integer array of length num_queries; entry i is the row index
        in ``room_data.receiver_position`` with the smallest Euclidean distance
        to query i
    """
    query_pos = np.atleast_2d(np.asarray(rec_pos_list))
    # pairwise distances, shape num_dataset_receivers x num_queries
    distances = np.linalg.norm(
        room_data.receiver_position[:, None, :] - query_pos[None, :, :], axis=-1)
    # nearest dataset receiver for every query (column)
    return np.argmin(distances, axis=0)

In [None]:
def convert_srir_to_brir(srirs: NDArray,
                         sample_rate: float,
                         hrtf_reader: HRIRSOFAReader,
                         head_orientations: ArrayLike,
                         degrees: bool = True) -> NDArray:
    """
    Convert SRIRs to BRIRs for specific head orientations.

    For every receiver and orientation, the SH-domain SRIR is rotated opposite
    to the head orientation, each SH channel is convolved (via FFT) with the
    matching SH-domain HRIR, and the channels are summed per ear.

    Args:
        srirs (NDArray): SRIRs of shape num_pos x num_ambi_channels x num_time_samp,
            with num_ambi_channels = (ambi_order + 1)**2
        sample_rate (float): sample rate of the SRIRs. Not used by this
            function; the HRIRs must already be at this rate.
        hrtf_reader (HRIRSOFAReader): for parsing SOFA file
        head_orientations (ArrayLike): head orientations of shape num_ori x 2;
            column 0 is used as the yaw and column 1 as the pitch of the
            SH rotation
        degrees (bool): if True (default) head_orientations are in degrees and
            are converted to radians, as required by spa.sph.sh_rotation_matrix
    Returns:
        BRIRs of shape num_pos x num_ori x num_time_samples x num_ears, where
        num_time_samples is num_time_samp rounded up to the next power of 2
    Raises:
        ValueError: if num_ambi_channels is not a perfect square
    """
    srirs = np.asarray(srirs)
    head_orientations = np.atleast_2d(np.asarray(head_orientations, dtype=float))
    num_receivers, num_ambi_channels, num_srir_samples = srirs.shape

    ambi_order = int(np.round(np.sqrt(num_ambi_channels))) - 1
    if (ambi_order + 1)**2 != num_ambi_channels:
        raise ValueError(
            f"Expected (order + 1)**2 SH channels, got {num_ambi_channels}")

    # output length stays the next power of two of the SRIR length
    num_out_samples = 2**int(np.ceil(np.log2(num_srir_samples)))

    # SH-domain HRIRs; the einsum below treats them as
    # num_ambi_channels x num_ears x num_hrir_samples
    hrir_sh = hrtf_reader.get_spherical_harmonic_representation(ambi_order)

    # FFT long enough for the full linear convolution, so the HRIR tail does
    # not wrap around (circular aliasing) into the start of the BRIR
    num_fft = 2**int(np.ceil(np.log2(num_srir_samples + hrir_sh.shape[-1] - 1)))

    # complex one-sided spectra (scipy.fftpack.rfft would return packed real
    # values, for which an elementwise product is not a convolution)
    # shape num_pos x num_ambi_channels x num_freqs
    ambi_rtfs = np.fft.rfft(srirs, n=num_fft, axis=-1)
    # shape num_ambi_channels x num_ears x num_freqs
    ambi_hrtfs = np.fft.rfft(hrir_sh, n=num_fft, axis=-1)
    logger.info("Done calculating FFTs")

    angles = np.deg2rad(head_orientations) if degrees else head_orientations
    # rotation matrices depend only on the orientation, so build them once;
    # the sound field is rotated in the opposite direction of the head
    rotation_matrices = [
        spa.sph.sh_rotation_matrix(ambi_order, -yaw, -pitch, 0, sh_type='real')
        for yaw, pitch in angles[:, :2]
    ]

    num_ears = ambi_hrtfs.shape[1]
    brirs = np.zeros((num_receivers, len(rotation_matrices), num_out_samples, num_ears))

    for rec_pos_idx in tqdm(range(num_receivers)):
        # shape num_ambi_channels x num_freqs
        cur_ambi_rtf = ambi_rtfs[rec_pos_idx]

        for ori_idx, rot_mat in enumerate(rotation_matrices):
            # rotated sound field, shape num_ambi_channels x num_freqs
            rotated_ambi_rtf = rot_mat @ cur_ambi_rtf

            # binaural room transfer function, shape num_freqs x num_ears:
            # per-channel product (= convolution in time) summed over SH channels.
            # No conjugate: for real signals it would time-reverse the HRIRs.
            cur_brtf = np.einsum('nrf,nf->fr', ambi_hrtfs, rotated_ambi_rtf)

            cur_brir = np.fft.irfft(cur_brtf, n=num_fft, axis=0)
            brirs[rec_pos_idx, ori_idx] = cur_brir[:num_out_samples]

    return brirs
        

### Get the true room dataset for different grid spacings

In [None]:
# experiment configuration
config_file = f'{config_path}/treble_data_grid_training_1000Hz_directional_spatial_sampling_test.yml'
config_dict = load_and_validate_config(config_file, SpatialSamplingConfig)

# original (full-grid) room dataset
room_data = parse_room_data(room_data_pkl_path)

# HRTF set, brought to the room dataset's sample rate if the two differ
hrtf_path = Path('resources/HRTF/48kHz/KEMAR_Knowl_EarSim_SmallEars_FreeFieldComp_48kHz.sofa')
hrtf_reader = HRIRSOFAReader(hrtf_path)
if hrtf_reader.fs != room_data.sample_rate:
    logger.info(f"Resampling HRTFs to {room_data.sample_rate:.0f} Hz")
    hrtf_reader.resample_hrirs(room_data.sample_rate)


In [None]:
### get train dataset for different grid spacings
grid_resolution_m = np.arange(config_dict.num_grid_spacing, 0,
                              -1) * room_data.grid_spacing_m
# these are the only directions in NAF: yaw in degrees (column 0), pitch 0 (column 1)
head_orientations = np.zeros((4, 2))
head_orientations[:, 0] = np.array([0, 90, 180, 270])


def _collect_positions_and_srirs(dataset, room_dataset):
    """Stack listener positions from all batches of `dataset` and the matching SRIRs of `room_dataset`.

    Returns (positions, srirs), each vertically stacked over all batches.
    """
    all_rec_pos = []
    all_srir = []
    for data in dataset:
        cur_list_pos = data['listener_position'].detach().cpu().numpy()
        all_rec_pos.append(cur_list_pos)
        indx = find_rec_idx_in_room_dataset(room_dataset, cur_list_pos)
        all_srir.append(room_dataset.rirs[indx, ...])
    return np.vstack(all_rec_pos), np.vstack(all_srir)


for k in range(config_dict.num_grid_spacing):
    cur_grid_spacing_m = np.round(grid_resolution_m[k], 1)
    logger.info(f'Creating NAF dataset for grid spacing = {cur_grid_spacing_m}m')

    pkl_path = f'{save_path}/naf_dataset_grid_spacing={grid_resolution_m[k]:.1f}m.pkl'
    if os.path.exists(pkl_path):
        logger.info("File already exists!")
        continue

    # same training / validation split as used for DiffGFDN
    train_dataset, valid_dataset, dataset_ref = load_dataset(
        room_data,
        config_dict.device,
        grid_resolution_m=cur_grid_spacing_m,
        network_type=config_dict.network_type,
        batch_size=config_dict.batch_size)

    logger.info("Creating training BRIRs")
    train_pos, train_srir = _collect_positions_and_srirs(train_dataset, room_data)
    train_brirs = convert_srir_to_brir(train_srir, room_data.sample_rate, hrtf_reader, head_orientations)

    logger.info("Creating inference BRIRs")
    # the inference set is skipped when the requested spacing is the dataset's
    # native spacing; np.isclose avoids relying on exact float equality
    if not np.isclose(grid_resolution_m[k], room_data.grid_spacing_m):
        valid_pos, valid_srir = _collect_positions_and_srirs(valid_dataset, room_data)
        valid_brirs = convert_srir_to_brir(valid_srir, room_data.sample_rate, hrtf_reader, head_orientations)
        num_valid_receivers = valid_pos.shape[0]
    else:
        valid_srir = None
        valid_pos = None
        valid_brirs = None
        num_valid_receivers = None

    logger.info("Creating NAF dataset")
    naf_dataset = NAFDataset(num_train_receivers=train_pos.shape[0],
                             num_infer_receivers=num_valid_receivers,
                             train_receiver_pos=train_pos,
                             infer_receiver_pos=valid_pos,
                             train_brirs=train_brirs,
                             infer_brirs=valid_brirs,
                             # the four yaw angles, one per BRIR orientation
                             orientation=head_orientations[:, 0],
                             )
    logger.info(f"Train BRIRs shape: {naf_dataset.train_brirs.shape}, "
                f"train positions shape: {naf_dataset.train_receiver_pos.shape}")

    # NOTE(review): pickle stores a reference to the class defined in this
    # notebook (__main__.NAFDataset); the loading code must define an identical
    # class. Consider saving a plain dict if the consumer is a separate script.
    with open(pkl_path, "wb") as f:
        pickle.dump(naf_dataset, f)
    