In [None]:
import torch
import numpy as np
from pathlib import Path
import pickle
from loguru import logger
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from numpy.typing import ArrayLike, NDArray
from typing import List
from scipy.signal import stft
from scipy.fft import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm


import os
# Move the kernel's working directory two levels up so that the relative
# 'resources/...' and 'data/...' paths used below resolve from there.
# NOTE(review): this is relative to the current working directory and is not
# idempotent -- re-running this cell without restarting the kernel climbs two
# more directories. The project imports that follow presumably also rely on
# this location; confirm.
os.chdir('../..')
from spatial_sampling.dataloader import parse_room_data, SpatialRoomDataset, load_dataset
from spatial_sampling.config import SpatialSamplingConfig
from diff_gfdn.utils import db
from src.sofa_parser import HRIRSOFAReader, SRIRSOFAWriter, convert_srir_to_brir
from src.run_model import load_and_validate_config
from src.dataclass import NAFDataset

### Notebook to convert the coupled-room dataset to a NAF-compatible format

NAF takes in BRIRs at spatial locations in the room for head orientations [0, 90, 180, 270]. To train NAF with different subsets of receivers, as we do in the WASPAA paper, we create different training and inference datasets containing SRIRs at different locations, and then convert them to BRIRs for the 4 given orientations. Finally, we save them in a pickle file.

In [None]:
# Locations used throughout the notebook. The relative paths are resolved
# against the current working directory at the time this cell runs.

# Directory that receives the generated NAF dataset pickles.
save_path = Path('resources/Georg_3room_FDTD').resolve()

# Pickle holding the spatial room impulse responses of the room dataset.
room_data_pkl_path = Path(
    'resources/Georg_3room_FDTD/srirs_spatial.pkl'
).resolve()

# Directory containing the spatial-sampling YAML configuration files.
config_path = Path('data/config/spatial_sampling/').resolve()

# Output directory name for sound examples (kept as a plain string).
out_path = 'output/spatial_sampling/sound_examples'

### Get the true room dataset for different grid spacings

In [None]:
# Load the experiment configuration and validate it against SpatialSamplingConfig.
config_file = f'{config_path}/treble_data_grid_training_1000Hz_directional_spatial_sampling_test.yml'
config_dict = load_and_validate_config(config_file, SpatialSamplingConfig)
hrtf_path = Path('resources/HRTF/48kHz/KEMAR_Knowl_EarSim_SmallEars_FreeFieldComp_48kHz.sofa')

# Original (full) room dataset.
room_data = parse_room_data(room_data_pkl_path)

# HRIR set read from the SOFA file.
hrtf_reader = HRIRSOFAReader(hrtf_path)

# Bring the HRIRs to the room dataset's sample rate when the two differ.
target_fs = room_data.sample_rate
if hrtf_reader.fs != target_fs:
    logger.info(f"Resampling HRTFs to {target_fs:.0f} Hz")
    hrtf_reader.resample_hrirs(target_fs)

# Sanity plot 1: spherical-harmonic representation of the HRIRs.
# The argument 2 is presumably the SH order -- confirm against HRIRSOFAReader.
plt.figure()
hrir_sh = hrtf_reader.get_spherical_harmonic_representation(2)
print(hrir_sh.shape)
sh_axes = plt.gca()
sh_axes.plot(hrir_sh[0, ...].T)  # index 0 along the first axis: left ear
sh_axes.set_title("Raw SH HRIRs")

# Sanity plot 2: the HRIRs as stored by the reader.
plt.figure()
hrirs = hrtf_reader.ir_data
print(hrirs.shape)
hrir_axes = plt.gca()
hrir_axes.plot(hrirs[0, ...].T)  # index 0 along the first axis: left ear
hrir_axes.set_title("Raw HRIRs")

In [None]:
### Build (or load from cache) the NAF dataset for the selected grid spacings

# Candidate grid spacings: integer multiples of the dataset's native spacing,
# ordered from the largest multiple down to 1x.
grid_resolution_m = np.arange(config_dict.num_grid_spacing, 0,
                              -1) * room_data.grid_spacing_m

# One row per head orientation. Column 0 holds the four angles used by NAF,
# column 1 stays zero.
# NOTE(review): columns are presumably (azimuth, elevation) in degrees --
# confirm against convert_srir_to_brir.
head_orientations = np.zeros((4, 2))
# these are the only directions in NAF
head_orientations[:, 0] = np.array([0, 90, 180, 270])


def _gather_positions_and_srirs(dataset, room_dataset):
    """Collect listener positions and their SRIRs for every batch of ``dataset``.

    Args:
        dataset: iterable of batches; each batch is indexed with
            ``'listener_position'`` and yields a torch tensor.
        room_dataset: room dataset exposing ``find_rec_idx_in_room_dataset``
            and ``rirs``.

    Returns:
        Tuple ``(positions, srirs)`` with the per-batch arrays stacked along
        the first axis using ``np.vstack``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    positions = []
    srirs = []
    for batch in dataset:
        batch_pos = batch['listener_position'].detach().cpu().numpy()
        rec_idx = room_dataset.find_rec_idx_in_room_dataset(batch_pos)
        positions.append(batch_pos)
        srirs.append(room_dataset.rirs[rec_idx, ...])

    if not positions:
        # np.vstack would fail with a less helpful message on an empty list.
        raise ValueError("Dataset is empty: no receiver positions to collect")

    return np.vstack(positions), np.vstack(srirs)


# Only index 1 of the spacing list is processed here; extend the list to
# generate datasets for other grid spacings.
for k in [1]:
    spacing_m = grid_resolution_m[k]
    logger.info(f'Creating NAF dataset for grid spacing = {np.round(spacing_m, 1)}m')

    # File name is unchanged so that previously cached datasets are still found.
    pkl_path = f'{save_path}/naf_dataset_grid_spacing={spacing_m:.1f}m.pkl'

    if os.path.exists(pkl_path):
        logger.info("File already exists!")
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # dataset files that were produced locally by this notebook.
        with open(pkl_path, "rb") as f:
            naf_dataset = pickle.load(f)
        continue

    # prepare the training and validation data for DiffGFDN
    train_dataset, valid_dataset, dataset_ref = load_dataset(
        room_data,
        config_dict.device,
        grid_resolution_m=np.round(spacing_m, 1),
        network_type=config_dict.network_type,
        batch_size=config_dict.batch_size)

    logger.info("Creating training BRIRs")
    train_pos, train_srir = _gather_positions_and_srirs(train_dataset, room_data)
    train_brirs = convert_srir_to_brir(train_srir, hrtf_reader, head_orientations)

    logger.info("Creating inference BRIRs")
    # At the native grid spacing no inference split is built. The spacing is
    # compared with a tolerance instead of exact float equality.
    if not np.isclose(spacing_m, room_data.grid_spacing_m):
        valid_pos, valid_srir = _gather_positions_and_srirs(valid_dataset, room_data)
        valid_brirs = convert_srir_to_brir(valid_srir, hrtf_reader, head_orientations)
        num_valid_receivers = valid_pos.shape[0]
    else:
        valid_srir = None
        valid_pos = None
        valid_brirs = None
        num_valid_receivers = None

    logger.info("Creating NAF dataset")
    # NOTE(review): only the first row of head_orientations is stored in
    # `orientation`, although BRIRs are computed for all four orientations --
    # confirm this is what NAFDataset expects.
    naf_dataset = NAFDataset(num_train_receivers=train_pos.shape[0],
                             num_infer_receivers=num_valid_receivers,
                             train_receiver_pos=train_pos,
                             infer_receiver_pos=valid_pos,
                             train_brirs=train_brirs,
                             infer_brirs=valid_brirs,
                             orientation=head_orientations[0, :],
                             )

    # Cache the dataset on disk so later runs can skip the conversion.
    with open(pkl_path, "wb") as f:
        pickle.dump(naf_dataset, f)

    

### Plot a BRIR for sanity check

In [None]:
# Pick one BRIR from the training set as a visual sanity check.
# The leading indices are presumably (receiver, head orientation) -- confirm
# against convert_srir_to_brir.
# The matching SRIR, train_srir[52, 0, ...], can be plotted the same way, but
# train_srir only exists when the dataset was built in this session rather
# than loaded from the cached pickle.
brir = naf_dataset.train_brirs[52, 0, ...]

brir_axes = plt.gca()
brir_axes.plot(brir)
brir_axes.set_title("Example BRIR")
plt.show()

### Plot the STFT and IF that are used by NAF

In [None]:
# STFT of the example BRIR taken along axis 0 of `brir`. The result is
# indexed below as S[frequency, ear, frame].
f, t, S = stft(
    brir,
    fs=room_data.sample_rate,
    window='hann',
    nperseg=256,
    noverlap=None,
    nfft=2**10,
    return_onesided=True,
    axis=0,
)

plt.figure(figsize=(8, 6))

# (subplot position, ear index, y-axis label, panel title)
spectrogram_panels = (
    (211, 0, 'Freq (Hz)', 'Spectrogram for left ear'),
    (212, 1, 'Freq(Hz)', 'Spectrogram for right ear'),
)
for subplot_code, ear_idx, freq_label, panel_title in spectrogram_panels:
    plt.subplot(subplot_code)
    ear_magnitude_db = db(np.squeeze(S[:, ear_idx, :]))
    plt.pcolormesh(t, f, ear_magnitude_db, shading='gouraud', cmap='viridis')
    plt.yscale('log')
    plt.xlabel('Time (s)')
    plt.ylabel(freq_label)
    plt.ylim([20, 16000])
    plt.title(panel_title)
    plt.colorbar(label='Magnitude(dB)')

plt.tight_layout()

In [None]:
print(t.shape, S.shape)

# Unwrap the STFT phase along the frame axis (the last axis of S), then take
# the first difference between consecutive frames. The result is a phase
# increment per STFT frame, so its unit is rad/frame, not rad/s; dividing by
# the frame hop (t[1] - t[0]) would convert it to rad/s.
phase = np.unwrap(np.angle(S), axis=-1)
instant_freq = np.diff(phase, n=1, axis=-1)

plt.figure(figsize=(8, 6))

# (subplot position, ear index, panel title)
if_panels = (
    (211, 0, 'Instantaneous frequency left ear'),
    (212, 1, 'Instantaneous frequency right ear'),
)
for subplot_code, ear_idx, panel_title in if_panels:
    plt.subplot(subplot_code)
    # The difference drops one frame, hence t[:-1] on the time axis.
    plt.pcolormesh(t[:-1], f, np.squeeze(instant_freq[:, ear_idx, :]),
                   shading='gouraud', cmap='viridis')
    plt.yscale('log')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.ylim([20, 16000])
    plt.title(panel_title)
    plt.colorbar(label='rad/frame')

plt.tight_layout()