### Note
This notebook contains the code for the binaural rendering demo. It uses Ambisonics A-format recordings, so the results here are not exactly the same as the results in the paper, which used the omnidirectional channel only. Also, this is just one way of working with Ambisonics (multi-channel) data; there may be better methods.

In [None]:
import os
import sys
import numpy as np 
from sklearn.cluster import KMeans
import scipy
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import IPython.display as ipd
import pickle
import mat73
from collections import defaultdict
from pathlib import Path 
import soundfile as sf 
import spaudiopy

sys.path.append("../")
from DecayFitNet.python.toolbox.DecayFitNetToolbox import DecayFitNetToolbox
from DecayFitNet.python.toolbox.BayesianDecayAnalysis import BayesianDecayAnalysis
from DecayFitNet.python.toolbox.core import  decay_kernel, schroeder_to_envelope, PreprocessRIR, decay_model, discard_last_n_percent, FilterByOctaves
import yaml


from fade_in_reverb.data import load_simulation_dataset, load_measurement_dataset
from fade_in_reverb.analysis import load_common_decay_times, load_envelope_fit_result
# Shared analysis settings; later cells read config['sample_rate'] and
# config['f_bands'].  The context manager closes the file handle, which the
# bare open() call used before left open.
with open("../fade_in_reverb/config.yaml") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Load sofa data

In [None]:
# Machine-specific location of the SOFA files, one file per loudspeaker.
sofa_dir = "/Users/kyungyunlee/dev/datasets/2024_blind_multi_room/hallway-lecturehall_sofa"
ls_sofa_files = [f"{sofa_dir}/hallways-lecturehall_zoom_ls_{ls_number}.sofa" for ls_number in range(1, 5)]

rir_length = 48000           # every RIR is zero-padded or cropped to this many samples
onset_search_length = 8000   # the onset is only searched for in the first samples
pre_onset_samples = 200      # samples kept in front of the detected onset
dither_std = 1e-8            # standard deviation of the noise added to the raw RIRs
smoothing_kernel = np.ones(10) / 10   # 10-sample moving average applied to the squared RIR


def _find_start_index(omni_rir):
    """Return the sample index at which the trimmed RIR should start.

    The smoothed log-energy is compared against the mean level of samples
    500-1500 (assumes the direct sound has not arrived yet in that window --
    TODO confirm for every recording).  The first sample that is 20 dB above
    that level is taken as the onset; if there is none, the first sample above
    the level itself is used; if there is none either, 0 is returned instead
    of raising an IndexError.
    """
    log_energy = 10 * np.log10(np.convolve(omni_rir**2, smoothing_kernel))
    noise_floor_level = np.mean(log_energy[500:1500])
    search_region = log_energy[:onset_search_length]

    for threshold in (noise_floor_level + 20, noise_floor_level):
        candidates = np.where(search_region > threshold)[0]
        if len(candidates) > 0:
            return max(0, candidates[0] - pre_onset_samples)
    return 0


def _fix_length(signal, n_samples):
    """Zero-pad or crop the last axis of `signal` to exactly `n_samples`."""
    fixed = np.zeros(signal.shape[:-1] + (n_samples,))
    n_keep = min(n_samples, signal.shape[-1])
    fixed[..., :n_keep] = signal[..., :n_keep]
    return fixed


# Read the SOFA file of each loudspeaker and collect the aligned RIRs.
a_format_rirs = []
omni_rirs = []
for ls_sofa_file in ls_sofa_files:
    ls_sofa_data = spaudiopy.io.load_sofa_data(ls_sofa_file)
    ls_rir_a_format_tmp = ls_sofa_data['Data.IR']
    # Omni approximation: sum over axis 1 (the A-format channels).
    ls_rir_omni_tmp = np.sum(ls_rir_a_format_tmp, axis=1)

    # Add a tiny amount of noise (presumably so that the log-energy below never
    # sees exact zeros).  The omni and A-format signals get independent noise.
    ls_rir_omni_tmp = ls_rir_omni_tmp + np.random.randn(*ls_rir_omni_tmp.shape) * dither_std
    ls_rir_a_format_tmp = ls_rir_a_format_tmp + np.random.randn(*ls_rir_a_format_tmp.shape) * dither_std

    ls_rir_omni = []
    ls_rir_a_format = []
    # One entry per measurement: the omni RIR is 1-D, the A-format RIR is
    # (channels, samples).  The omni RIR decides where both are cut.
    for curr_omni_rir, curr_aformat_rir in zip(ls_rir_omni_tmp, ls_rir_a_format_tmp):
        front_index = _find_start_index(curr_omni_rir)
        ls_rir_omni.append(_fix_length(curr_omni_rir[front_index:], rir_length))
        ls_rir_a_format.append(_fix_length(curr_aformat_rir[:, front_index:], rir_length))

    a_format_rirs.append(ls_rir_a_format)
    omni_rirs.append(ls_rir_omni)

a_format_rirs = np.array(a_format_rirs)
omni_rirs = np.array(omni_rirs)
print (a_format_rirs.shape, omni_rirs.shape)

In [None]:
# Combine all omni RIRs and compute the common decay times

In [None]:
# Pool the omni RIRs of all loudspeakers and positions into one 2-D array
# (one RIR per row) and obtain the decay times shared by all of them.
pooled_omni_rirs = omni_rirs.reshape(-1, omni_rirs.shape[-1])
common_decay_times = load_common_decay_times(
    "../data/hallway-lectureHall_combined_common_decay_times_sofa0722.npy",
    pooled_omni_rirs,
    n_slopes=4,
)
print (common_decay_times)

In [None]:
# Perform estimate for each A-format channel 

In [None]:
# Envelope fit for every A-format channel, one call per loudspeaker.
# The loop runs over the loudspeakers actually present in a_format_rirs
# instead of a hard-coded count of 4; result files are numbered from 1.
a_format_pos_fit = []
a_format_neg_fit = []
a_format_orig_env = []
for loudspk_idx, curr_rirs in enumerate(a_format_rirs):
    pos_fit_result, neg_fit_result, all_original_env = load_envelope_fit_result(
        f"../data/hallway-lectureHall_model_fit_result_ls_{loudspk_idx+1}.pkl",
        curr_rirs,
        common_decay_times,
        multichannel=True,
        plot=False,
    )
    a_format_pos_fit.append(pos_fit_result)
    a_format_neg_fit.append(neg_fit_result)
    a_format_orig_env.append(all_original_env)

# Stack the per-loudspeaker results so they can be indexed as
# [loudspeaker][position][channel] in the reconstruction cell.
a_format_pos_fit = np.array(a_format_pos_fit)
a_format_neg_fit = np.array(a_format_neg_fit)
a_format_orig_env = np.array(a_format_orig_env)

In [None]:
# Shape check of the stacked fit results; the reconstruction cell unpacks this
# as (loudspeaker, trajectory position, channel, band, estimate).
a_format_neg_fit.shape

In [None]:
# Reconstruct the result 

In [None]:
# Number of RIR samples that are reconstructed.
L = 40000

# The loudspeaker count gets its own name; previously the unpack overwrote
# `loudspk_idx` and the loop then used a literal 4.
n_loudspk, n_traj, n_channels, n_bands, n_est = a_format_neg_fit.shape

# One reconstructed L-sample RIR per (loudspeaker, position, channel).
neg_fit_recon = np.zeros(a_format_rirs.shape[:3] + (L,), dtype=a_format_rirs.dtype)
pos_fit_recon = np.zeros_like(neg_fit_recon)

# Everything that depends only on the frequency band is built once here
# instead of once per RIR.
# NOTE(review): assumes FilterByOctaves keeps no state between calls -- confirm.
timeAxis_fullLength = np.linspace(0, L / config['sample_rate'], L)
band_filterbanks = []
band_envelopes = []
for bIdx, center_frequency in enumerate(config['f_bands']):
    band_filterbanks.append(
        FilterByOctaves(order=6, sample_rate=config['sample_rate'], backend='scipy',
                        center_frequencies=[center_frequency])
    )
    # Exponentials for this band's common decay times (times are doubled, as
    # in the original code).  The last column is the noise term; it is left
    # out because the rendered RIR uses the decaying part only.
    envelopeTimes = 2 * common_decay_times[bIdx]
    band_envelopes.append(decay_kernel(envelopeTimes, timeAxis_fullLength)[:, :-1])


def _shaped_band_noise(envelope, filterbank, n_samples, target_rms):
    """Return band-limited Gaussian noise shaped by `envelope` at `target_rms`.

    Twice as many noise samples as needed are filtered and only the middle
    `n_samples` are kept, so the start and end of the filtered noise are
    discarded.
    """
    noise = np.random.randn(n_samples * 2)
    octave_filtered_noise = filterbank(torch.FloatTensor(noise)).numpy()[0]
    octave_filtered_noise = octave_filtered_noise[n_samples // 2 : -n_samples // 2]

    shaped_noise = envelope * octave_filtered_noise
    shaped_noise *= target_rms / np.sqrt(np.mean(shaped_noise**2))
    return shaped_noise


for loudspk_idx in range(n_loudspk):
    for traj_idx in range(n_traj):
        for c_idx in range(n_channels):
            rir = a_format_rirs[loudspk_idx][traj_idx][c_idx][:L]
            rir_tensor = torch.FloatTensor(rir)

            curr_rir_neg_fit_result = a_format_neg_fit[loudspk_idx][traj_idx][c_idx]
            curr_rir_pos_fit_result = a_format_pos_fit[loudspk_idx][traj_idx][c_idx]

            neg_recon_rir = np.zeros_like(rir)
            pos_recon_rir = np.zeros_like(rir)

            for bIdx, filterbank in enumerate(band_filterbanks):
                # Octave-band version of the measured RIR; its peak and RMS
                # set the gain of the synthesized band.
                octave_filtered_rir = filterbank(rir_tensor)[0].numpy()
                octave_filtered_rir_norm_factor = np.max(np.abs(octave_filtered_rir))
                octave_filtered_rir = octave_filtered_rir / octave_filtered_rir_norm_factor
                original_rms = np.sqrt(np.mean(octave_filtered_rir**2))

                # Amplitude-weighted sum of the exponentials (noise term excluded).
                neg_envelopes = np.sum(band_envelopes[bIdx] * curr_rir_neg_fit_result[bIdx][:-1], 1)
                pos_envelopes = np.sum(band_envelopes[bIdx] * curr_rir_pos_fit_result[bIdx][:-1], 1)

                # Independent noise draws for the two fits, in the same order
                # as before (neg first, then pos).
                neg_shaped_noise = _shaped_band_noise(neg_envelopes, filterbank, L, original_rms)
                pos_shaped_noise = _shaped_band_noise(pos_envelopes, filterbank, L, original_rms)

                # Undo the peak normalization and accumulate the band.
                neg_recon_rir += neg_shaped_noise * octave_filtered_rir_norm_factor
                pos_recon_rir += pos_shaped_noise * octave_filtered_rir_norm_factor

            # Store once all bands have been summed.
            neg_fit_recon[loudspk_idx][traj_idx][c_idx] = neg_recon_rir
            pos_fit_recon[loudspk_idx][traj_idx][c_idx] = pos_recon_rir

In [None]:
# Now convert this to spherical harmonics 

In [None]:
# Shape check: the reconstruction is indexed as
# [loudspeaker][trajectory position][channel] with L samples on the last axis.
neg_fit_recon.shape

In [None]:
def _a_format_to_sh(a_format_signals):
    """Convert one set of A-format signals to spherical harmonics via B-format.

    Both spaudiopy calls are made with W_weight=None, as before.
    NOTE(review): soundfield_to_b expects a specific capsule order -- confirm
    that the receiver order in the SOFA files matches it.
    """
    b_format = spaudiopy.sph.soundfield_to_b(a_format_signals, W_weight=None)
    return spaudiopy.sph.b_to_sh(b_format, W_weight=None)


neg_fit_recon_sh = np.zeros_like(neg_fit_recon)
pos_fit_recon_sh = np.zeros_like(pos_fit_recon)
original_sh = np.zeros_like(a_format_rirs)

# Loop bounds come from the arrays themselves rather than a literal 4 and an
# `n_traj` left over from an earlier cell.
for loudspk_idx in range(neg_fit_recon.shape[0]):
    for traj_idx in range(neg_fit_recon.shape[1]):
        neg_fit_recon_sh[loudspk_idx][traj_idx] = _a_format_to_sh(neg_fit_recon[loudspk_idx][traj_idx])
        pos_fit_recon_sh[loudspk_idx][traj_idx] = _a_format_to_sh(pos_fit_recon[loudspk_idx][traj_idx])
        # The measured RIRs get the same conversion (full, uncropped length).
        original_sh[loudspk_idx][traj_idx] = _a_format_to_sh(a_format_rirs[loudspk_idx][traj_idx])

In [None]:
# Save as sofa file 

In [None]:
# Load hrirs

In [None]:
# HRIR set as returned by spaudiopy for 48 kHz (no file name is given).
hrirs = spaudiopy.io.load_hrirs(fs=48000)

# Encode each ear separately into real spherical harmonics of order 1.
hrirs_nm_left, hrirs_nm_right = (
    spaudiopy.sph.src_to_sh(ear_hrirs, hrirs.azi, hrirs.zen, 1, sh_type='real')
    for ear_hrirs in (hrirs.left, hrirs.right)
)

# New leading axis: index 0 is the left ear, index 1 the right ear.
hrirs_nm = np.stack([hrirs_nm_left, hrirs_nm_right], axis=0)
hrirs_nm.shape

In [None]:
# Speech signal that is rendered through the moving-listener RIRs.
speech_file = "/Users/kyungyunlee/Downloads/Speech - Core Take - shortened.wav"
source, fs = sf.read(speech_file, dtype='float64')
print (source.shape, fs)

# The block-wise rendering cell slices `source` along its first axis and
# treats each slice as a 1-D mono signal.  A multi-channel file (read as
# (frames, channels)) would break that step, so keep the first channel only.
if source.ndim > 1:
    source = source[:, 0]

# The HRIRs are loaded at 48 kHz and the rendered files are written at 48 kHz,
# so a different source rate would silently change pitch and speed.
if fs != 48000:
    print (f"WARNING: source sample rate is {fs} Hz, but the rendering assumes 48000 Hz")

# Optional: shorten the source, e.g. source = source[:60 * fs]
ipd.Audio(source, rate=fs)


In [None]:
# Choose loudspeaker number
ls = 0

loudspk_neg_fit_recon_sh = neg_fit_recon_sh[ls]
loudspk_pos_fit_recon_sh = pos_fit_recon_sh[ls]
# Take the measured reference from the spherical-harmonic version of the
# measured RIRs.  The rendering cell decodes it with sh2bin exactly like the
# reconstructed RIRs, so feeding it the raw A-format channels (as before)
# decoded capsule signals as if they were SH signals.
original_rir = original_sh[ls]
print (loudspk_neg_fit_recon_sh.shape, original_rir.shape)

In [None]:
# Convolve the signal with each spherical harmonics 

In [None]:
# Samples of source rendered with one RIR before moving to the next position.
block_size = 28800
source_len = source.shape[0]
print(source_len)
# Round up so that the last, shorter block of the source is rendered too
# (floor division silently dropped it).
num_blocks = (source_len + block_size - 1) // block_size
print (num_blocks)

# Make the trajectory longer to match the number of blocks: every position
# index is held for n_repeat consecutive blocks.
traj_start, traj_end = 0, 50
trajectory_idx = np.arange(traj_start, traj_end)
print (trajectory_idx)
n_repeat = num_blocks // len(trajectory_idx) + 1
print (n_repeat)
trajectory_list = np.repeat(trajectory_idx, n_repeat)
print (trajectory_list)


print (len(trajectory_list), num_blocks)
# Every block needs a position index.
assert len(trajectory_list) >= num_blocks

In [None]:
# scipy.signal is imported explicitly; `import scipy` alone does not guarantee
# the submodule is loaded on every SciPy version.
import scipy.signal

fBands = [125, 250, 500, 1000, 2000, 4000, 8000]   # not used in this cell
L = 40000                     # RIR samples used for rendering
target_loudness_level = -30   # target RMS level in dB for the rendered files


def _render_block(dry_block, rir_sh):
    """Convolve a mono block with a multi-channel RIR and decode it to binaural.

    The mono block is copied once per RIR channel, each copy is convolved with
    its channel along the time axis, and the result is decoded with `hrirs_nm`.
    """
    multichannel_block = np.repeat(dry_block[np.newaxis, :], rir_sh.shape[0], axis=0)
    convolved_sh = scipy.signal.fftconvolve(multichannel_block, rir_sh, axes=1)
    return spaudiopy.decoder.sh2bin(convolved_sh, hrirs_nm)


def _normalize_rms(signal, target_level_db):
    """Scale `signal` so that its overall RMS equals `target_level_db` dB.

    The RMS is taken over the whole array, i.e. both channels and any
    trailing silence.
    """
    curr_level = np.sqrt(np.mean(signal**2))
    adj_level = np.sqrt(10**(target_level_db/10)) / curr_level
    return signal * adj_level


# Two-channel output buffers with 48000 extra samples for the tail of the last
# block.  The 4-channel SH buffers allocated here before were never written
# and have been dropped.
n_output_samples = len(source) + 48000
output_meas_bin = np.zeros((2, n_output_samples))
output_neg_bin = np.zeros((2, n_output_samples))
output_pos_bin = np.zeros((2, n_output_samples))

for i in range(num_blocks):
    # Current block of the source
    curr_source = source[i * block_size : (i+1) * block_size]
    if len(curr_source) == 0:
        break

    # Position on the trajectory for this block
    curr_pos_idx = trajectory_list[i]

    # Measured reference, cropped to the length of the reconstructions
    curr_rir_loudspeaker = original_rir[curr_pos_idx][:, :L]
    neg_recon_rir = loudspk_neg_fit_recon_sh[curr_pos_idx]
    pos_recon_rir = loudspk_pos_fit_recon_sh[curr_pos_idx]
    assert curr_rir_loudspeaker.shape == neg_recon_rir.shape
    assert curr_rir_loudspeaker.shape[0] == 4 and neg_recon_rir.shape[0] == 4

    # Overlap-add: each rendered block is longer than block_size, so its tail
    # is added on top of the following blocks.
    block_start = i * block_size
    for output_buffer, block_rir in (
        (output_meas_bin, curr_rir_loudspeaker),
        (output_neg_bin, neg_recon_rir),
        (output_pos_bin, pos_recon_rir),
    ):
        rendered_block = _render_block(curr_source, block_rir)
        output_buffer[:, block_start : block_start + rendered_block.shape[-1]] += rendered_block

    # Remember the last position that was rendered; it goes into the file names.
    traj_end = curr_pos_idx

output_meas_bin = _normalize_rms(output_meas_bin, target_loudness_level)
output_neg_bin = _normalize_rms(output_neg_bin, target_loudness_level)
output_pos_bin = _normalize_rms(output_pos_bin, target_loudness_level)


sf.write(f"rendered_original_ls{ls}_pos{traj_start}_{traj_end}.wav", output_meas_bin.T, 48000)
sf.write(f"rendered_neg_ls{ls}_pos{traj_start}_{traj_end}.wav", output_neg_bin.T, 48000)
sf.write(f"rendered_pos_ls{ls}_pos{traj_start}_{traj_end}.wav", output_pos_bin.T, 48000)

In [None]:
# Listen to the render that used the measured RIRs (two rows = two ear signals).
ipd.Audio(output_meas_bin, rate=48000)

In [None]:
# Listen to the render that used the RIRs rebuilt from the "neg" fit results.
ipd.Audio(output_neg_bin, rate=48000)

In [None]:
# Listen to the render that used the RIRs rebuilt from the "pos" fit results.
ipd.Audio(output_pos_bin, rate=48000)