In [None]:
import torch
import numpy as np
from pathlib import Path
import pickle
import librosa
from loguru import logger
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from numpy.typing import ArrayLike, NDArray
from typing import List
from scipy.signal import stft
from scipy.spatial import cKDTree
from scipy.fft import rfft, irfft
import spaudiopy as spa
from tqdm import tqdm

import os
# Move to the parent directory so the project imports below (e.g. `src.*`) and the
# relative paths used later resolve. Guarded so that re-running this cell does not
# climb one more directory each time.
# NOTE(review): assumes `src/` exists in the target (parent) directory, since it is
# imported as `src.*` below, and not in the notebook's own directory — confirm.
if not Path('src').is_dir():
    os.chdir('..')
from spatial_sampling.dataloader import parse_room_data, SpatialRoomDataset, load_dataset
from spatial_sampling.config import SpatialSamplingConfig
from diff_gfdn.utils import db, db2lin, ms_to_samps

from src.sound_examples import binaural_dynamic_rendering
from src.convert_mat_to_pkl_ambi import process_ambi_srirs
from src.run_model import load_and_validate_config
from src.dataclass import BarycentricInterpolatedDataset

from slope2noise.utils import octave_filtering, schroeder_backward_int
from slope2noise.rooms import RoomGeometry

In [None]:
# Input/output locations, relative to the working directory set in the import cell
out_path = 'output/spatial_sampling/sound_examples'
room_data_pkl_path = Path('resources', 'Georg_3room_FDTD', 'srirs_spatial.pkl').resolve()
config_path = Path('data', 'config', 'spatial_sampling').resolve()
save_path = Path('resources', 'Georg_3room_FDTD').resolve()
fig_path = Path('figures', 'spatial_sampling').resolve()

# Reference (ground-truth) room dataset
room_data = parse_room_data(room_data_pkl_path)

### Helper functions

In [None]:
def find_closest_neighbours_in_2D_grid(test_pos: NDArray, train_pos: NDArray, num_neighbours: int = 4):
    """Find the closest neighbours in train_pos for each position in test_pos.

    Only the (x, y) coordinates are used, so this is a 2D nearest-neighbour search;
    any further columns (e.g. z) are ignored.

    Args:
        test_pos (np.ndarray): Array of shape (N, D) with D >= 2, or a single position of shape (D,).
        train_pos (np.ndarray): Array of shape (M, D) with D >= 2.
        num_neighbours (int): Number of nearest neighbours to find (default = 4).

    Returns:
        indices (np.ndarray): Array of shape (N, num_neighbours) with indices into train_pos,
            ordered from nearest to farthest.
        distances (np.ndarray): Array of shape (N, num_neighbours) with the corresponding
            2D Euclidean distances.

    Raises:
        ValueError: if num_neighbours is smaller than 1 or larger than the number of
            training positions (cKDTree would otherwise pad with index M and inf distance).
    """
    # Use only x and y for the 2D distance; atleast_2d lets a single position be passed
    train_xy = np.atleast_2d(np.asarray(train_pos))[:, :2]
    test_xy = np.atleast_2d(np.asarray(test_pos))[:, :2]

    num_train = train_xy.shape[0]
    if not 1 <= num_neighbours <= num_train:
        raise ValueError(
            f'num_neighbours must be between 1 and the number of training positions ({num_train}), '
            f'got {num_neighbours}')

    # Build KDTree for fast neighbour search and query the k nearest neighbours
    tree = cKDTree(train_xy)
    distances, indices = tree.query(test_xy, k=num_neighbours)

    # For k=1 cKDTree.query drops the neighbour axis; restore it so the shape is always (N, k)
    num_test = test_xy.shape[0]
    return indices.reshape(num_test, num_neighbours), distances.reshape(num_test, num_neighbours)


def find_mixing_time(test_srirs : NDArray, fs: float, er_target: float = 0.1,
                     f_bands: ArrayLike = (63, 125, 250, 500, 1000, 2000, 4000, 8000),
                     ref_band_hz: float = 1000.0, round_to_samps: int = 1000) -> NDArray:
    """Estimate the cutoff between early reflections and late reverb for each SRIR.

    The first channel (W) of each SRIR is octave-band filtered, and the energy decay
    curve (EDC, normalised) of the band closest to ``ref_band_hz`` is computed. The
    mixing time is the first sample at which that EDC drops below ``er_target``. It is
    then rounded up to a multiple of ``round_to_samps`` samples.

    Args:
        test_srirs: SRIRs indexed as (receiver, channel, time); channel 0 is used.
        fs (float): sampling frequency.
        er_target (float): EDC threshold defining the mixing time.
        f_bands: octave band centre frequencies passed to the filterbank. The default
            matches the notebook's ``f_bands`` (previously read from a global that was
            defined in a later cell).
        ref_band_hz (float): band used for the estimate (closest entry of ``f_bands``).
        round_to_samps (int): granularity to which the mixing time is rounded up.

    Returns:
        1-D int32 array with the mixing-time sample index of each SRIR. If the EDC never
        falls below ``er_target``, ``argmax`` yields 0 for that SRIR.
    """
    band_centres = list(f_bands)
    # Index of the reference band (4, i.e. 1 kHz, for the default bands)
    ref_band_idx = int(np.argmin(np.abs(np.asarray(band_centres, dtype=float) - ref_band_hz)))

    # get the reference band of the W channel
    rirs_filtered = octave_filtering(np.squeeze(test_srirs[:, 0, :]), fs, band_centres,
                                     use_amp_preserving_filterbank=True)[..., ref_band_idx]
    # get the EDC
    rirs_edc = schroeder_backward_int(rirs_filtered, time_axis=-1, normalize=True)
    # argmax on a boolean array gives the first True index
    ind_at_mixing_time = (rirs_edc < er_target).argmax(axis=-1)
    mixing_time_samp = (np.ceil(ind_at_mixing_time / round_to_samps) * round_to_samps).astype(np.int32)
    # np.squeeze collapses a single-receiver input; always return one entry per receiver
    return np.atleast_1d(mixing_time_samp)


def interpolate_late_reverb(closest_late_rev : NDArray, fs: float, gainsSH : NDArray,
                            n_bandsLR: int= 48, erb_low_freq: float = 10) -> NDArray:
    """
    Interpolate late reverberation based on closest neighbours, according to McKenzie et al.

    The procedure has three steps:
    1. Combine the neighbours' spectra with the gains.
    2. Apply a per-band magnitude correction to all channels. The correction is computed
       on channel 0.
    3. Rescale each channel's RMS to the gain-weighted RMS of the neighbours.

    Args:
        closest_late_rev: late reverb tail of the closest neighbours,
            of size num_time_samps x num_ambi_channels x num_neighbours
        fs (float): sampling frequency
        gainsSH: gains to be applied based on distance, of size (1, 1, num_neighbours)
        n_bandsLR (int): number of band edges used for the magnitude correction
        erb_low_freq (float): lowest band edge in Hz
    Returns:
        interpolated late tail of size num_out_samps x num_ambi_channels.
        NOTE: irfft is called with its default output length 2 * (num_freq_bins - 1).
        So num_out_samps == num_time_samps for even-length input, and
        num_time_samps - 1 for odd-length input.
    """
    num_time_samps = closest_late_rev.shape[0]

    # FFT of isolated late reverb (T, C, K) -> (F, C, K)
    XLRnearestMeas = rfft(closest_late_rev, axis=0)
    nfftLR = XLRnearestMeas.shape[0]

    # Weighted interpolation (gains broadcast as (1, 1, K))
    XLRnearestMeas_gain = XLRnearestMeas * gainsSH
    XLRnearestMeas_interp = np.sum(XLRnearestMeas_gain, axis=2)  # (F, C)

    # Band edges: mel-spaced frequencies (librosa.mel_frequencies), used as a perceptual
    # approximation of ERB spacing
    ERBfreqs = librosa.mel_frequencies(n_mels=n_bandsLR, fmin=erb_low_freq, fmax=fs/2)

    # Exact frequency of every rfft bin
    freqs = np.fft.rfftfreq(num_time_samps, d=1.0 / fs)

    # Closest FFT bin index for each band edge
    freqsInd = np.searchsorted(freqs, ERBfreqs)
    # First band starts at DC and the last band runs to the final bin, so every bin gets a gain
    freqsInd[0] = 0
    freqsInd[-1] = nfftLR
    # Drop repeated edges: an empty band has NaN RMS and would spread NaNs into the gains
    freqsInd = np.unique(freqsInd)
    num_bands = len(freqsInd) - 1

    def _band_rms(spectrum: NDArray) -> NDArray:
        # RMS magnitude in each band (mean over bins and any remaining axes)
        return np.array([
            np.sqrt(np.mean(np.abs(spectrum[freqsInd[j]:freqsInd[j + 1]]) ** 2))
            for j in range(num_bands)
        ])

    # NOTE(review): the target is the RMS over bins *and* neighbours of the gain-weighted
    # spectra; confirm this matches the reference method (a gain-weighted sum of per-neighbour
    # band RMS is another common choice). Only the spectral shape matters here because of the
    # final energy normalisation.
    rmsLR_target = _band_rms(XLRnearestMeas_gain[:, 0, :])
    rmsLR_current = _band_rms(XLRnearestMeas_interp[:, 0])
    rmsDiffLR = rmsLR_target / (rmsLR_current + 1e-8)  # avoid division by zero

    # Piecewise-linear gain over each band, ramping towards the next band's correction
    rmsGainsLR = np.zeros(nfftLR, dtype=np.float32)
    for j in range(num_bands):
        start_gain = rmsDiffLR[j]
        end_gain = rmsDiffLR[j + 1] if j < num_bands - 1 else rmsDiffLR[j]
        rmsGainsLR[freqsInd[j]:freqsInd[j + 1]] = np.linspace(start_gain, end_gain,
                                                              freqsInd[j + 1] - freqsInd[j])

    # Apply gain and IFFT
    hLR_interp = irfft(XLRnearestMeas_interp * rmsGainsLR[:, None], axis=0)  # (T_out, C)

    # Normalize energy per channel to the gain-weighted RMS of the neighbours
    rms_orig = np.sqrt(np.mean(closest_late_rev**2, axis=0))  # (C, K)
    rms_interp = np.sqrt(np.mean(hLR_interp**2, axis=0))     # (C,)
    scaling = rms_orig @ gainsSH[0, 0, :]
    scaling /= (rms_interp + 1e-8)
    hLR_interp *= scaling

    return hLR_interp


### Get the true dataset and interpolated dataset for different grid spacings

In [None]:
config_file = f'{config_path}/treble_data_grid_training_1000Hz_directional_spatial_sampling_test.yml'
config_dict = load_and_validate_config(config_file,
                                       SpatialSamplingConfig)


### get train dataset for different grid spacings
# Candidate grid spacings (multiples of the dataset spacing), largest first
grid_resolution_m = np.arange(config_dict.num_grid_spacing, 0,
                                  -1) * room_data.grid_spacing_m
# Octave band centre frequencies (Hz) for filtering and per-band EDC errors
f_bands = [63, 125, 250, 500, 1000, 2000, 4000, 8000]
# Number of nearest training receivers combined per interpolated receiver
num_neighbours = 4
ambi_order = room_data.ambi_order
num_ambi_channels = (ambi_order+1)**2

# Window applied to the isolated late-reverb segments in the interpolation loop.
# The rising half of a Hann window of twice the length is a fade-in from 0 to ~1. A full
# symmetric Hann window would fade back down to 0 at its end and carve a dip into the tail.
win_len_ms = 5
win_len_samps = ms_to_samps(win_len_ms, room_data.sample_rate)
fade_in = np.hanning(2 * win_len_samps)[:win_len_samps]
window = np.broadcast_to(fade_in, (num_neighbours, num_ambi_channels) + (win_len_samps, ))

In [None]:
# Indices into grid_resolution_m to process; use range(config_dict.num_grid_spacing - 1) for all spacings
grid_spacing_idxs = [0]


def _collect_positions_and_srirs(dataset):
    """Stack listener positions of every batch and the matching SRIRs from room_data."""
    all_rec_pos = []
    all_srir = []
    for data in dataset:
        cur_list_pos = data['listener_position'].detach().cpu().numpy()
        all_rec_pos.append(cur_list_pos)
        indx = room_data.find_rec_idx_in_room_dataset(cur_list_pos)
        all_srir.append(room_data.rirs[indx, ...])
    return np.vstack(all_rec_pos), np.vstack(all_srir)


for k in grid_spacing_idxs:
    logger.info(f'Creating barycentric interpolated dataset for grid spacing = {np.round(grid_resolution_m[k], 1)}m')

    # Cached result. Delete stale pickles after changing the interpolation so the change takes effect.
    pkl_path = f'{save_path}/bary_interp_dataset_grid_spacing={grid_resolution_m[k]:.1f}m.pkl'

    if not os.path.exists(pkl_path):
        # prepare the training and validation data for DiffGFDN
        train_dataset, valid_dataset, dataset_ref = load_dataset(
            room_data,
            config_dict.device,
            grid_resolution_m=np.round(grid_resolution_m[k], 1),
            network_type=config_dict.network_type,
            batch_size=config_dict.batch_size)

        logger.info("Creating training SRIRs")
        train_pos, train_srir = _collect_positions_and_srirs(train_dataset)

        logger.info("Creating test SRIRs")
        valid_pos, valid_srir = _collect_positions_and_srirs(valid_dataset)
        num_valid_receivers = valid_pos.shape[0]

        logger.info("Calculating KDTree of neighbours")
        closest_neighbor_idxs, closest_neighbor_dist = find_closest_neighbours_in_2D_grid(valid_pos, train_pos)
        logger.info("Calculating mixing time")
        # Mixing times of the test receivers: stored with the dataset, used as the EDC-error start
        mixing_time_samp = find_mixing_time(valid_srir, room_data.sample_rate)
        # Mixing times of the training receivers: idx_nearest below indexes train_srir, so the
        # late-reverb start must come from a train-indexed array (not from mixing_time_samp)
        train_mixing_time_samp = find_mixing_time(train_srir, room_data.sample_rate)
        # Only the late part is predicted; samples before the late-reverb start stay zero
        interp_srir = np.zeros_like(valid_srir)

        logger.info("Interpolating test SRIRs using train SRIRs")
        for rec_idx in tqdm(range(num_valid_receivers)):
            neighbour_idxs = closest_neighbor_idxs[rec_idx, :]
            # kdtree output is already sorted, so column 0 is the closest neighbour
            idx_nearest = neighbour_idxs[0]
            # Inverse-distance weights normalised to sum to 1 (eps avoids division by zero)
            gains = 1.0 / (closest_neighbor_dist[rec_idx, :] + np.finfo(float).eps)
            gains = gains / np.sum(np.abs(gains))
            # shape (1, 1, num_neighbours) as expected by interpolate_late_reverb
            gainsSH = gains.reshape(1, 1, num_neighbours)

            # Neighbours' late reverb from the nearest neighbour's mixing time onwards.
            # Fancy indexing returns a copy, so the in-place windowing leaves train_srir untouched.
            late_rev_start_samp = train_mixing_time_samp[idx_nearest]
            cur_late_rev_isolated = train_srir[neighbour_idxs, :, late_rev_start_samp:]
            # The segment already starts at the mixing time, so window its first samples
            cur_late_rev_isolated[..., :win_len_samps] *= window

            # interpolate: (K, C, T) -> (T, C, K)
            cur_late_rev_isolated = np.transpose(cur_late_rev_isolated, (2, 1, 0))
            cur_late_reverb_interp = interpolate_late_reverb(cur_late_rev_isolated, room_data.sample_rate, gainsSH)
            # Place the tail at the late-reverb start. The output may be one sample shorter than
            # the segment (irfft of odd-length input), so size the slice from the output.
            num_out_samps = cur_late_reverb_interp.shape[0]
            interp_srir[rec_idx, :, late_rev_start_samp:late_rev_start_samp + num_out_samps] = cur_late_reverb_interp.T

        logger.info("Creating interpolated dataset")
        interp_dataset = BarycentricInterpolatedDataset(
                                 num_infer_receivers = num_valid_receivers,
                                 infer_receiver_pos = valid_pos,
                                 ref_srir = valid_srir,
                                 pred_srir = interp_srir,
                                 mixing_time_samp = mixing_time_samp,
                                )

        # Save the instance to a pickle file so re-running skips the interpolation
        with open(pkl_path, "wb") as f:
            pickle.dump(interp_dataset, f)
    else:
        logger.info("File already exists!")
        # pickle.load can execute arbitrary code: only load caches written by this notebook
        with open(pkl_path, "rb") as f:
            interp_dataset = pickle.load(f)

### Convert true and interpolated SRIRs to DRIRs and calculate EDC error for each direction and frequency band

In [None]:
room = RoomGeometry(
    room_data.sample_rate,
    room_data.num_rooms,
    np.array(room_data.room_dims),
    np.array(room_data.room_start_coord),
    aperture_coords=room_data.aperture_coords,
)


def _srirs_to_drirs(srirs):
    """Decode SRIRs into directional RIRs via process_ambi_srirs.

    The input is reordered with transpose(1, -1, 0) before decoding, and the decoded
    result is reordered back with transpose(2, 0, 1).
    """
    decoded = process_ambi_srirs(srirs.transpose(1, -1, 0), room_data.ambi_order,
                                 room_data.sph_directions)
    return decoded.transpose(2, 0, 1)


logger.info("Converting SRIRs to DRIRs")
ref_drirs = _srirs_to_drirs(interp_dataset.ref_srir)
pred_drirs = _srirs_to_drirs(interp_dataset.pred_srir)

# Mean EDC error (dB) per direction and octave band, filled by the next cell
error_db = np.zeros((room_data.num_directions, len(f_bands)))
# Per-receiver start of the evaluation window, and the truncation length (2000 ms)
mixing_time_samp = interp_dataset.mixing_time_samp
trunc_at = ms_to_samps(2000, room_data.sample_rate)

In [None]:
# Grid spacing label of the evaluated dataset (k is the loop index left by the interpolation cell)
grid_res_label = np.round(grid_resolution_m[k], 3)

for j in range(room_data.num_directions):
    logger.info(f"Calculating EDC error for direction {j}")

    ref_drirs_filtered = octave_filtering(ref_drirs[:, j, :trunc_at], room_data.sample_rate, f_bands, use_amp_preserving_filterbank=True)
    pred_drirs_filtered = octave_filtering(pred_drirs[:, j, :trunc_at], room_data.sample_rate, f_bands, use_amp_preserving_filterbank=True)

    for band_idx in range(len(f_bands)):
        cur_ref_drirs = ref_drirs_filtered[..., band_idx]
        cur_pred_drirs = pred_drirs_filtered[..., band_idx]

        # EDCs in dB
        ref_edc = db(schroeder_backward_int(cur_ref_drirs, time_axis=-1), is_squared=True)
        pred_edc = db(schroeder_backward_int(cur_pred_drirs, time_axis=-1), is_squared=True)

        # Mean absolute EDC difference per receiver, from that receiver's mixing time onwards
        cur_error_db = np.array([np.mean(np.abs(ref_edc[i, mixing_time_samp[i]:] - pred_edc[i, mixing_time_samp[i]:]),
                                         axis=-1)
                        for i in range(interp_dataset.num_infer_receivers)])

        error_db[j, band_idx] = cur_error_db.mean()

        logger.info(f'Mean EDC error for frequency = {f_bands[band_idx]}Hz and direction {j} is {cur_error_db.mean():.3f} dB')

        room.plot_edc_error_at_receiver_points(
                interp_dataset.infer_receiver_pos,
                np.array(room_data.source_position).squeeze(),
                db2lin(cur_error_db),
                scatter_plot=True,
                cur_freq_hz=None,
                # The band is part of the file name; without it every band overwrote the same figure
                save_path=f'{fig_path}/bary_interp_edc_error_in_space_direction={j+1}_'
                          f'freq={f_bands[band_idx]}Hz_'
                          f'grid_resolution_m={grid_res_label}.png',
                title=f'az = {np.degrees(room_data.sph_directions[0, j]):.2f} deg,'
                +
                f' pol = {np.degrees(room_data.sph_directions[1, j]):.2f} deg'
        )

In [None]:
# Average over directions (axis 0) -> one value per octave band
avg_error_per_band_db = np.round(error_db.mean(axis=0), 3)
logger.info(f'Avg EDC over all directions for each band = {avg_error_per_band_db} dB')

In [None]:
# Inspect one receiver. ref_edc / pred_edc hold only the LAST direction and band
# computed in the EDC-error loop above.
rec_idx = 252
num_receivers = ref_edc.shape[0]
if not 0 <= rec_idx < num_receivers:
    raise IndexError(f'rec_idx={rec_idx} is out of range for {num_receivers} receivers')

eval_slice = slice(mixing_time_samp[rec_idx], trunc_at)
ref_edc_rec = ref_edc[rec_idx, eval_slice]
pred_edc_rec = pred_edc[rec_idx, eval_slice]
err = np.abs(ref_edc_rec - pred_edc_rec)

fig, ax = plt.subplots()
ax.plot(ref_edc_rec, label='reference EDC')
ax.plot(pred_edc_rec, label='interpolated EDC')
ax.plot(err, label='absolute error')
ax.set(xlabel='samples after mixing time', ylabel='dB', title=f'Receiver {rec_idx}')
ax.legend()
plt.show()

print(np.mean(err))