In [None]:
import os 
import sys 
import numpy as np 
from sklearn.cluster import KMeans
import scipy
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import pickle
import mat73
from collections import defaultdict
from pathlib import Path 
import soundfile as sf 

sys.path.append("../")
from DecayFitNet.python.toolbox.DecayFitNetToolbox import DecayFitNetToolbox
from DecayFitNet.python.toolbox.BayesianDecayAnalysis import BayesianDecayAnalysis
from DecayFitNet.python.toolbox.core import  decay_kernel, schroeder_to_envelope, PreprocessRIR, decay_model, discard_last_n_percent, FilterByOctaves
import yaml

from fade_in_reverb.data import load_simulation_dataset, load_measurement_dataset
from fade_in_reverb.analysis import load_common_decay_times, load_envelope_fit_result
# import seaborn as sns 
# Analysis configuration (e.g. `sample_rate`, `f_bands`) loaded from the package's YAML file.
# A context manager closes the file handle instead of leaking it.
with open("../fade_in_reverb/config.yaml") as config_file:
    config = yaml.safe_load(config_file)


# Figure width in inches used for all saved figures
# (presumably the text width of the target paper layout — TODO confirm).
text_width = 7.16


font = {'family': 'Times New Roman',
        'size': 9}
params = {'text.usetex': False, 'mathtext.fontset': 'cm'}
plt.rcParams.update(params)
mpl.rc('font', **font)
# NOTE: the LaTeX preamble only takes effect when 'text.usetex' is True (it is False above).
plt.rcParams['text.latex.preamble'] = r"\usepackage{siunitx} \sisetup{detect-all} \usepackage{helvet} \usepackage{sansmath} \sansmath"

# Displayed as the cell output to confirm "../" was appended to the module search path.
sys.path

In [None]:
# Load the measured RIRs together with the trajectory array returned by the loader.
measured_rirs_and_trajectory = load_measurement_dataset()
rirs, trajectory = measured_rirs_and_trajectory
print(f"{rirs.shape} {trajectory.shape}")

In [None]:
# Common decay times for the combined hallway / lecture-hall data, using 4 slopes.
common_decay_times_path = "../data/hallway-lectureHall_combined_common_decay_times.npy"
combined_common_decay_times = load_common_decay_times(
    common_decay_times_path,
    rirs,
    n_slopes=4,
)
print(combined_common_decay_times)

In [None]:
# Envelope-fit results for the combined data (4 slopes): the positive-only fit, the
# fade-in (negative-weight) fit, and the original envelopes, as returned by the loader.
# Only the first 40000 columns of `rirs` are passed in.
combined_fit_path = "../data/hallway-lectureHall_combined_model_fit_result.pkl"
(
    combined_pos_fit_result,
    combined_neg_fit_result,
    combined_all_original_env,
) = load_envelope_fit_result(
    combined_fit_path,
    rirs[:, :40000],
    combined_common_decay_times,
    multichannel=False,
    plot=False,
)


In [None]:
# Candidate RIRs / octave bands worth inspecting (noted while browsing the data):
#   RIR 41 (250 Hz, 1000 Hz); RIR 121 (in the meeting room)
#   Second half: RIR 88 (2000 Hz), 108 (2000 Hz), 124 (1000 Hz), 20
# NOTE(review): `160 + 20` presumably selects RIR 20 of the second half, i.e. the second
# half starts at index 160 — TODO confirm.
rir_number = 160 + 20

# Analysis length in samples; matches the rirs[:, :40000] slice the envelope fit was run on.
L = 40000
rir = rirs[rir_number][:L]

curr_rir_neg_fit_result = combined_neg_fit_result[rir_number]
curr_rir_pos_fit_result = combined_pos_fit_result[rir_number]
curr_rir_original_env = combined_all_original_env[rir_number]

fs = config['sample_rate']

# Time axes for the decay basis: a downsampled one (first `ds_start_index` frames dropped)
# used with the fitted envelopes, and a full-length one used for synthesis.
ds_start_index = 2
downSampleLength = 200
timeAxis_ds = np.linspace(0, L / fs, downSampleLength - ds_start_index)
timeAxis_fullLength = np.linspace(0, L / fs, L)

# Plot window: the first n_plot samples (0.79 s at 48 kHz). Integer-based time axes avoid the
# length ambiguity of np.arange with float steps and follow config['sample_rate'].
n_plot = 38000
ds_hop = L // downSampleLength  # samples per downsampled envelope frame
t_full = np.arange(n_plot) / fs
t_ds = (np.arange(n_plot // ds_hop) * ds_hop / fs)[ds_start_index:]
n_ds_plot = len(t_ds)
tick_positions = [0, 0.2, 0.4, 0.6, 0.8]

output_folder = "../figures/measurement_fitting_examples/"
os.makedirs(output_folder, exist_ok=True)

# Seeded generator so the synthesized noise (and thus the saved figures) is reproducible.
rng = np.random.default_rng(0)


def _band_limited_noise(filterbank, rng, length):
    """Return `length` samples of white noise filtered by `filterbank`.

    Twice the needed amount is generated and `length // 2` samples are trimmed from the
    start (and about as many from the end), presumably to discard filter transients.
    """
    noise = rng.standard_normal(2 * length)
    filtered = filterbank(torch.FloatTensor(noise)).numpy()[0]
    return filtered[length // 2 : -length // 2]


def _shape_and_match(envelope, carrier, target_rms):
    """Modulate `carrier` by `envelope` and rescale the product to RMS `target_rms`.

    An all-zero product (e.g. all fitted weights zero) is returned unscaled instead of
    producing NaNs from a 0/0 division.
    """
    shaped = envelope * carrier
    shaped_rms = np.sqrt(np.mean(shaped ** 2))
    if shaped_rms == 0:
        return shaped
    return shaped * (target_rms / shaped_rms)


def _normalized_edc(signal):
    """Backward-integrated (Schroeder) energy curve of `signal`, normalized to a maximum of 1."""
    edc = np.flipud(np.cumsum(np.flipud(signal ** 2)))
    return edc / np.max(edc)


def _synthesize_band(weights, envelopes_ds, envelopes_full, carrier, target_rms):
    """Synthesize one octave band from fitted decay weights.

    Expects the last basis column to be the noise term (set to ones by the caller).
    Returns a tuple of:
      - the downsampled model envelope including the noise term,
      - the shaped noise without the noise term, RMS-matched to `target_rms`,
      - the normalized EDC of the shaped noise including the noise term.
    """
    envelope_ds = np.sum(envelopes_ds * weights, 1)
    shaped = _shape_and_match(np.sum(envelopes_full[:, :-1] * weights[:-1], 1), carrier, target_rms)
    shaped_with_noise = _shape_and_match(np.sum(envelopes_full * weights, 1), carrier, target_rms)
    return envelope_ds, shaped, _normalized_edc(shaped_with_noise)


neg_recon_rir = np.zeros_like(rir)
pos_recon_rir = np.zeros_like(rir)

for bIdx, f_band in enumerate(config['f_bands']):
    filterbank = FilterByOctaves(order=6, sample_rate=fs, backend='scipy',
                                 center_frequencies=[f_band])

    # Measured octave-band RIR, kept at its true level. The synthesized bands are RMS-matched
    # to it, so the broadband reconstructions carry the measured band levels. (Previously the
    # band was peak-normalized and the normalization was undone only after the shaped noise
    # had already been accumulated into the reconstructions.)
    octave_filtered_rir = filterbank(torch.FloatTensor(rir))[0].numpy()
    original_rms = np.sqrt(np.mean(octave_filtered_rir ** 2))

    # Normalized EDF of the measured band, for the energy-decay panel.
    rir_preprocessing = PreprocessRIR(sample_rate=fs, filter_frequencies=[f_band])
    edf, __ = rir_preprocessing.schroeder(rir, analyse_full_rir=True)
    edf = edf.squeeze(0).squeeze(0).numpy()

    # Decay basis for this band. NOTE(review): doubling the common decay times presumably
    # converts energy decays into amplitude envelopes — TODO confirm against decay_kernel.
    envelopeTimes = 2 * combined_common_decay_times[bIdx]
    envelopes = decay_kernel(envelopeTimes, timeAxis_ds)
    envelopes_fullLength = decay_kernel(envelopeTimes, timeAxis_fullLength)
    # The noise term is modelled as a constant.
    envelopes[:, -1] = 1
    envelopes_fullLength[:, -1] = 1

    print(curr_rir_pos_fit_result[bIdx])
    print(curr_rir_neg_fit_result[bIdx])

    # Independent noise carriers for the fade-in (neg) and pos-only syntheses.
    neg_envelopes_ds, neg_shaped_noise, fittedEDF_neg = _synthesize_band(
        curr_rir_neg_fit_result[bIdx], envelopes, envelopes_fullLength,
        _band_limited_noise(filterbank, rng, L), original_rms)
    pos_envelopes_ds, pos_shaped_noise, fittedEDF_pos = _synthesize_band(
        curr_rir_pos_fit_result[bIdx], envelopes, envelopes_fullLength,
        _band_limited_noise(filterbank, rng, L), original_rms)

    neg_recon_rir += neg_shaped_noise
    pos_recon_rir += pos_shaped_noise

    fig = plt.figure(figsize=(text_width, text_width * 0.6))
    gs = plt.GridSpec(nrows=2, ncols=1, height_ratios=[1.3, 1], hspace=0.5)
    gp0 = gs[0].subgridspec(1, 2, wspace=0.25)
    gp1 = gs[1].subgridspec(1, 3, wspace=0.05)
    ax1 = fig.add_subplot(gp0[0])
    ax2 = fig.add_subplot(gp0[1])
    ax3 = fig.add_subplot(gp1[0])
    ax4 = fig.add_subplot(gp1[1])
    ax5 = fig.add_subplot(gp1[2])

    # Top left: measured vs. modelled envelopes (square-root scaled).
    ax1.plot(t_ds, np.sqrt(curr_rir_original_env[bIdx, :n_ds_plot]), label='measured', color='C7')
    ax1.plot(t_ds, np.sqrt(neg_envelopes_ds[:n_ds_plot]), label='fade-in', color='C3')
    ax1.plot(t_ds, np.sqrt(pos_envelopes_ds[:n_ds_plot]), label='pos-only', color='C10')
    ax1.set_title("Power-law Scaled Envelopes")
    ax1.set_xlabel("Time in second")
    ax1.set_ylabel("Amplitude")

    # Top right: energy decay curves in dB.
    ax2.plot(t_full, 10 * np.log10(edf[:n_plot]), label='measured', color='C7')
    ax2.plot(t_full, 10 * np.log10(fittedEDF_neg[:n_plot]), label='fade-in', color='C3')
    ax2.plot(t_full, 10 * np.log10(fittedEDF_pos[:n_plot]), label='pos-only', color='C10')
    ax2.set_title("Energy decay curves")
    ax2.set_xlabel("Time in second")
    ax2.set_ylabel("Decay level in dB")

    # Bottom row: time signals on a shared symmetric y-range (20 % headroom).
    vmax = 1.2 * max(np.max(np.abs(s)) for s in (octave_filtered_rir, neg_shaped_noise, pos_shaped_noise))
    for ax, trace, label, color, title in (
        (ax3, octave_filtered_rir, 'measured', 'C7', "Measured RIR"),
        (ax4, neg_shaped_noise, 'fade-in', 'C3', "Synthesized w/ fade-in"),
        (ax5, pos_shaped_noise, 'pos-only', 'C10', "Synthesized w/ pos-only"),
    ):
        ax.plot(t_full, trace[:n_plot], label=label, color=color)
        ax.set_ylim([-vmax, vmax])
        ax.set_title(title)
    ax3.set_ylabel("Signal Value")
    ax4.set_xlabel("Time in second")
    ax4.tick_params(labelleft=False)
    ax5.tick_params(labelleft=False)

    for ax in (ax1, ax2, ax3, ax4, ax5):
        ax.set_xticks(tick_positions, tick_positions)
        ax.grid()

    print(f"RIR {rir_number}, {f_band} Hz")
    fig.legend(*ax1.get_legend_handles_labels(), loc='upper center', ncol=3, labelspacing=0)

    plt.savefig(fname=os.path.join(output_folder, f"RIR{rir_number}_{f_band}Hz.pdf"), bbox_inches="tight")

In [None]:
# Compare the measured broadband RIR with the two broadband reconstructions
# (sums of the per-band shaped noise built in the synthesis cell).
for trace, title in (
    (rir, "Measured RIR"),
    (neg_recon_rir, "Reconstruction w/ fade-in"),
    (pos_recon_rir, "Reconstruction w/ pos-only"),
):
    fig, ax = plt.subplots()
    ax.plot(trace)
    ax.set(title=title, xlabel="Sample", ylabel="Signal Value")
    plt.show()