In [None]:
import os 
import sys
import numpy as np 
from sklearn.cluster import KMeans
import scipy
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle
import mat73
from collections import defaultdict

sys.path.append("../")
from DecayFitNet.python.toolbox.DecayFitNetToolbox import DecayFitNetToolbox
from DecayFitNet.python.toolbox.BayesianDecayAnalysis import BayesianDecayAnalysis
from DecayFitNet.python.toolbox.core import  decay_kernel, schroeder_to_envelope, PreprocessRIR, decay_model, discard_last_n_percent, FilterByOctaves
import yaml

from fade_in_reverb.data import load_simulation_dataset
from fade_in_reverb.analysis import load_common_decay_times, load_envelope_fit_result, get_envelope

# Load the shared experiment configuration. The `with` block closes the file
# handle (the bare `open()` passed straight to `safe_load` was never closed).
with open("../fade_in_reverb/config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Figure widths in inches used for the two figure sizes below
# (NOTE(review): presumably the paper's text and column widths).
text_width = 7.16
column_width = 3.15

font = {'family': 'Times New Roman',
        'size': 9}
params = {'text.usetex': False, 'mathtext.fontset': 'cm'}
plt.rcParams.update(params)
mpl.rc('font', **font)
# Only used by Matplotlib when text.usetex is True; kept for that case.
plt.rcParams['text.latex.preamble'] = r"\usepackage{siunitx} \sisetup{detect-all} \usepackage{helvet} \usepackage{sansmath} \sansmath"


# Load data

In [None]:
# Load the simulated dataset: RIR arrays plus receiver and source positions.
total_rirs, omni_rirs, rcvPos, srcPos = load_simulation_dataset()
for rir_array in (total_rirs, omni_rirs):
    print(rir_array.shape)

In [None]:
# Preprocess omni RIRs only: drop everything before the onset (keeping a
# 200-sample pre-roll), then pad/truncate each RIR to 96000 samples.
omni_rirs_processed = []
mask = np.ones(10) / 10  # 10-tap moving average applied to the squared RIR

direct_thresh = -30  # dB; direct-sound onset, searched within the first 700 samples
noise_thresh = -70   # dB; fallback onset, searched within the first 2000 samples

for i in range(len(omni_rirs)):
    curr_rir = omni_rirs[i]
    log_energy = 10 * np.log10(np.convolve(curr_rir**2, mask))

    front_index = np.where(log_energy[:700] > direct_thresh)
    if len(front_index[0]) > 0:
        onset = front_index[0][0]
    else:
        # No direct sound found: use the first sample above the noise threshold.
        front_index = np.where(log_energy[:2000] > noise_thresh)
        onset = front_index[0][0] if len(front_index[0]) > 0 else 2000
    cut_rir = curr_rir[max(0, onset - 200):]

    if len(cut_rir) < 96000:
        # Zero-pad short RIRs at the end.
        padded = np.zeros((96000,))
        padded[:len(cut_rir)] = cut_rir
        cut_rir = padded
    else:
        cut_rir = cut_rir[:96000]

    # Keep the trimmed/padded RIR. Appending the raw `curr_rir` silently
    # discarded all of the processing above.
    omni_rirs_processed.append(cut_rir)
omni_rirs_processed = np.array(omni_rirs_processed)

In [None]:
# Quick visual check of a few processed RIRs (first 48000 samples each).
for rir_idx in range(800, 830):
    fig, ax = plt.subplots()
    ax.plot(omni_rirs_processed[rir_idx][:48000])
    ax.set_title(rir_idx)
    plt.show()

# Load common decay times

In [None]:
# Decay times shared by all RIRs; the cells below index them as [band, slope].
common_decay_times = load_common_decay_times(
    "../data/treble_common_decay_times.npy",
    total_rirs,
)
print(common_decay_times)

In [None]:
# omni_rirs = omni_rirs_processed[:, :48000]

# Load fitted result

In [None]:
# Fitted amplitudes of the positive-only (pos) and negative-allowed (neg)
# models, plus `all_original_env` (presumably the envelopes that were fitted).
pos_fit_result, neg_fit_result, all_original_env = load_envelope_fit_result(
    "../data/treble_model_fit_result.pkl",
    omni_rirs,
    common_decay_times,
    plot=False,
)
print(pos_fit_result.shape, neg_fit_result.shape)

In [None]:
# Shape of the envelopes returned alongside the fit results.
print(all_original_env.shape)

# Plot amplitudes of the negative-allowed fit

In [None]:
def mu_law_encoding(x, mu=255):
    """Apply sign-preserving mu-law companding to ``x``.

    Computes ``sign(x) * ln(1 + mu*|x|) / ln(1 + mu)``. Inputs with
    ``|x| <= 1`` map into [-1, 1], with large magnitudes compressed
    logarithmically, which makes small amplitudes visible in the colour maps.

    Args:
        x: Scalar or array-like input.
        mu: Compression strength (default 255, the value previously hard-coded).

    Returns:
        Companded value(s) with the same shape as ``x``.
    """
    # log1p is accurate for small mu*|x|, where log(1 + ...) loses precision.
    val = np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.sign(x) * val

In [None]:
def plot_room_geometry(ax):
    """Draw the wall outline of the simulated three-room layout onto ``ax``.

    Every wall is a straight segment given as ([x_start, x_end], [y_start, y_end])
    in the plot's x/y coordinates and is drawn as a thin black line. The layout
    is hard-coded.

    Args:
        ax: Matplotlib Axes to draw on.
    """
    wall_style = {'color': 'black', 'linewidth': 1}
    walls = [
        # Room spanning x in [0, 4], y in [0, 8]; the gap in its right wall
        # between y = 2.75 and y = 4.25 is left open (presumably a doorway).
        ([0, 0], [0, 8]),
        ([0, 4], [8, 8]),
        ([4, 4], [8, 4.25]),
        ([4, 4], [2.75, 0]),
        ([4, 0], [0, 0]),
        # Room spanning y in [2.2, 5] to the right of x = 4.
        ([4, 8.5], [5, 5]),
        ([10, 10], [5, 2.2]),
        ([10, 4], [2.2, 2.2]),
        # Room spanning x in [6, 10], y in [5, 13].
        ([6, 6], [5, 13]),
        ([6, 10], [13, 13]),
        ([10, 10], [13, 5]),
    ]
    for xs, ys in walls:
        ax.plot(xs, ys, **wall_style)


In [None]:
# Each octave separately: mu-law-compressed amplitudes of the negative-allowed
# fit on the floor plan, one row per octave band and one column per decay slope.
cmap = plt.colormaps.get_cmap('RdYlBu')  # diverging map so the amplitude sign is visible
cmap.set_bad(color='white')  # NaN grid cells (no receiver) are drawn white
fBands = [250, 500, "1k", "2k", "4k", "8k"]
n_slopes = config['n_slopes']
# One column per slope; the column count used to be hard-coded to 3 while the
# loop below iterates over config['n_slopes'].
fig, axes = plt.subplots(len(fBands), n_slopes, figsize=(text_width, text_width * 1.2), squeeze=False)

for j in range(len(fBands)):
    for k in range(n_slopes):
        # 32 x 42 floor-plan grid; cells without a receiver stay NaN.
        X = np.full((32, 42), np.nan)

        for i in range(len(omni_rirs)):
            # Map the receiver position to grid indices via (pos - 0.2) / 0.3
            # (presumably a 0.3 m receiver grid starting at 0.2 m).
            curr_pos = (rcvPos[i] * 10 - 2) / 3
            int_curr_pos0 = int(round(curr_pos[0], 0))
            int_curr_pos1 = int(round(curr_pos[1], 0))

            # Index shifts for receivers beyond x = 4 and in the x > 6, y > 5
            # region. NOTE(review): presumably compensate for per-room grid
            # offsets — confirm against the simulation layout.
            if rcvPos[i][0] > 4:
                int_curr_pos0 -= 1
            if rcvPos[i][0] > 6 and rcvPos[i][1] > 5:
                int_curr_pos1 -= 1

            # Fit-result band j + 1 is plotted under the label fBands[j].
            X[int_curr_pos0][int_curr_pos1] = mu_law_encoding(neg_fit_result[i][j + 1][k])

        # Symmetric colour limits so that zero sits at the centre of the colormap.
        X_remove_nan = X[~np.isnan(X)]
        vmax = max(np.abs(np.min(X_remove_nan)), np.max(X_remove_nan))
        vmin = -vmax

        ax = axes[j, k]
        im = ax.imshow(X.transpose(1, 0), interpolation='none', aspect='equal', cmap=cmap, extent=[0, 10, 0, 13], norm='linear', origin='lower', resample=True, vmin=vmin, vmax=vmax)
        plt.colorbar(im, ax=ax, label=fr'Amplitude')
        ax.set_title(fr"$A_{k}$ at {fBands[j]}Hz - $T_{k}$ = { common_decay_times[j+1][k] :.2f}s ")
        plot_room_geometry(ax)
        ax.scatter(srcPos[0], srcPos[1], marker='x', color='green')
        ax.set_xticks(np.arange(0, 12, 2))
        ax.set_xlabel("x in meters")
        ax.set_ylabel("y in meters")

fig.tight_layout()
fig.savefig(fname='../figures/simulation_room_plots.pdf', format='pdf', bbox_inches="tight")


In [None]:
# Amplitudes summed over fit bands 2-4 (500 Hz, 1 kHz and 2 kHz if the band
# order matches `fBands` in the per-octave cell), one panel per decay slope.
# Note: the slice 2:5 selects three bands, not only 500 Hz and 1 kHz.
cmap = plt.colormaps.get_cmap('RdYlBu')  # diverging map so the amplitude sign is visible
cmap.set_bad(color='white')  # NaN grid cells (no receiver) are drawn white

summed_bands = slice(2, 5)  # used for both the amplitude sum and the mean decay time
n_slopes = config['n_slopes']
# One panel per slope (previously hard-coded to 3 panels / labels).
fig, axes = plt.subplots(1, n_slopes, figsize=(text_width, text_width*0.3), sharey=True, sharex=True, squeeze=False)
axes = axes[0]
plt.subplots_adjust(wspace=0.2)

# Panel labels 'a', 'b', 'c', ... — one per slope.
fig_labels = [chr(ord('a') + k) for k in range(n_slopes)]

for k in range(n_slopes):
    # 32 x 42 floor-plan grid; cells without a receiver stay NaN.
    X = np.full((32, 42), np.nan)

    for i in range(len(omni_rirs)):
        # Map the receiver position to grid indices via (pos - 0.2) / 0.3.
        curr_pos = (rcvPos[i] * 10 - 2) / 3
        int_curr_pos0 = int(round(curr_pos[0], 0))
        int_curr_pos1 = int(round(curr_pos[1], 0))

        # Same empirical index shifts as in the per-octave maps.
        if rcvPos[i][0] > 4:
            int_curr_pos0 -= 1
        if rcvPos[i][0] > 6 and rcvPos[i][1] > 5:
            int_curr_pos1 -= 1

        X[int_curr_pos0][int_curr_pos1] = mu_law_encoding(np.sum(neg_fit_result[i, summed_bands, k], 0))

    # Symmetric colour limits so that zero sits at the centre of the colormap.
    X_remove_nan = X[~np.isnan(X)]
    vmax = max(np.abs(np.min(X_remove_nan)), np.max(X_remove_nan))
    vmin = -vmax

    im = axes[k].imshow(X.transpose(1, 0), interpolation='none', aspect='equal', cmap=cmap, extent=[0, 10, 0, 13], norm='linear', origin='lower', resample=True, vmin=vmin, vmax=vmax)
    # Only the right-most panel carries the colour-bar label.
    if k == n_slopes - 1:
        plt.colorbar(im, ax=axes[k], label=fr'Amplitude')
    else:
        plt.colorbar(im, ax=axes[k])
    axes[k].set_title(fr"({fig_labels[k]}) $A_{k}$, $T_{k}$={ np.mean(common_decay_times[summed_bands,k]) :.2f}s ")
    plot_room_geometry(axes[k])
    axes[k].scatter(srcPos[0], srcPos[1], marker='x', color='green')
    axes[k].set_xticks(np.arange(0, 12, 2))
    axes[k].set_xlabel("x in meters")
axes[0].set_ylabel("y in meters")

fig.savefig(fname='../figures/simulation_room_plots_sum.pdf', bbox_inches="tight")


# Plot envelope-fit errors

In [None]:
# First remove the direct sound: cut each omni RIR `direct_length` samples after
# its direct-sound onset (or at the first sample above the noise threshold when
# no direct sound is found), then pad/truncate it to 96000 samples.
print(omni_rirs.shape)

mask = np.ones(10) / 10             # 10-tap moving average for the energy curve
direct_thresh = -30                 # dB, searched within the first 700 samples
direct_length = int(48000 * 0.002)  # 2 ms at 48 kHz
print(direct_length)
noise_thresh = -70                  # dB, searched within the first 2000 samples

direct_cutted_rirs = []
for curr_rir in omni_rirs:
    log_energy = 10 * np.log10(np.convolve(curr_rir**2, mask))

    direct_hits = np.flatnonzero(log_energy[:700] > direct_thresh)
    if direct_hits.size == 0:
        # No direct sound: start at the first sample above the noise threshold.
        noise_hits = np.flatnonzero(log_energy[:2000] > noise_thresh)
        cut_start = noise_hits[0] if noise_hits.size else 2000
    else:
        cut_start = direct_hits[0] + direct_length
    cut_rir = curr_rir[cut_start:]

    if len(cut_rir) >= 96000:
        cut_rir = cut_rir[:96000]
    else:
        padded = np.zeros((96000,))
        padded[:len(cut_rir)] = cut_rir
        cut_rir = padded

    direct_cutted_rirs.append(cut_rir)
direct_cutted_rirs = np.array(direct_cutted_rirs)

print(direct_cutted_rirs.shape)

In [None]:


# Per-receiver envelope-fit error summed over all octave bands.
# For each band, the fitted envelopes of the negative-allowed (neg) and
# positive-only (pos) models are compared with the band envelope of the RIR
# after the direct sound, as the RMSE between square-rooted envelopes.

# 32 x 42 floor-plan grid; cells without a receiver stay NaN.
X_neg = np.full((32, 42), np.nan)
X_pos = np.full((32, 42), np.nan)

# For cutting away the direct part. `direct_length` is defined here (same value
# as in the direct-sound cutting cell) so this cell does not depend on another one.
mask = np.ones(10) / 10
direct_thresh = -30
direct_length = int(48000 * 0.002)  # 2 ms at 48 kHz
print(direct_length)

downsample_rate = 400

# The direct-sound position does not depend on the octave band, so locate it
# once per RIR (in downsampled frames) instead of once per band.
direct_loc_ds_per_rir = []
for curr_rir in omni_rirs:
    log_energy = 10 * np.log10(np.convolve(curr_rir**2, mask))
    front_index = np.where(log_energy[:700] > direct_thresh)
    if len(front_index[0]) > 0:
        direct_loc = front_index[0][0] + direct_length
        direct_loc_ds_per_rir.append(direct_loc // downsample_rate + 2)
    else:
        direct_loc_ds_per_rir.append(0)

for bIdx in range(len(config['f_bands'])):
    # Downsampled time axis over 2 s (96000 samples at 48 kHz). NOTE(review):
    # 238 presumably equals len(get_envelope(rir, 400)[2:]) for a 96000-sample RIR.
    timeAxis = np.linspace(0, 96000 / 48000, 238)

    # Decay kernels for this band's common decay times; combined linearly with
    # the fitted amplitudes below.
    envelopeTimes = 2 * common_decay_times[bIdx]
    envelopes = decay_kernel(envelopeTimes, timeAxis)

    # Octave filter centred on the current band.
    filterbank = FilterByOctaves(order=6, sample_rate=48000, backend='scipy',
                                 center_frequencies=[config['f_bands'][bIdx]])

    for i in range(len(omni_rirs)):
        direct_loc_ds = direct_loc_ds_per_rir[i]

        octave_filtered_rir = filterbank(torch.FloatTensor(omni_rirs[i]))[0].numpy()

        # Measured and modelled envelopes, all starting after the direct sound.
        original_envs = get_envelope(octave_filtered_rir, downsample_rate)[2 + direct_loc_ds:]
        neg_envs = np.dot(envelopes, neg_fit_result[i][bIdx])[direct_loc_ds:]
        pos_envs = np.dot(envelopes, pos_fit_result[i][bIdx])[direct_loc_ds:]

        neg_diff = np.sqrt(np.mean((original_envs ** 0.5 - neg_envs ** 0.5) ** 2))
        pos_diff = np.sqrt(np.mean((original_envs ** 0.5 - pos_envs ** 0.5) ** 2))

        # Report NaN errors together with the modelled envelope. (Previously this
        # printed `neg_envs_calc` / `pos_envs_calc`, which are never defined, so any
        # NaN raised a NameError.)
        if np.isnan(neg_diff):
            print("nan")
            print(neg_envs)
        if np.isnan(pos_diff):
            print("nan")
            print(pos_envs)

        if 800 < i < 810:
            print(neg_diff, pos_diff)

        # Map the receiver position to grid indices via (pos - 0.2) / 0.3.
        curr_pos = (rcvPos[i] * 10 - 2) / 3
        int_curr_pos0 = int(round(curr_pos[0], 0))
        int_curr_pos1 = int(round(curr_pos[1], 0))

        # Same empirical index shifts as in the amplitude maps.
        if rcvPos[i][0] > 4:
            int_curr_pos0 -= 1
        if rcvPos[i][0] > 6 and rcvPos[i][1] > 5:
            int_curr_pos1 -= 1

        # Band 0 initialises each visited (NaN) cell; later bands add to it, so
        # the maps hold the RMSE summed over bands.
        if bIdx > 0:
            X_neg[int_curr_pos0][int_curr_pos1] += neg_diff
            X_pos[int_curr_pos0][int_curr_pos1] += pos_diff
        else:
            X_neg[int_curr_pos0][int_curr_pos1] = neg_diff
            X_pos[int_curr_pos0][int_curr_pos1] = pos_diff

In [None]:

# Error maps of the negative-allowed (neg) and positive-only (pos) models,
# each saved as its own single-column figure.
cmap = plt.colormaps.get_cmap('viridis')
cmap.set_bad(color='white')  # NaN grid cells (no receiver) are drawn white

# Shared colour limits so the two maps are directly comparable.
X_neg_remove_nan = X_neg[~np.isnan(X_neg)]
X_pos_remove_nan = X_pos[~np.isnan(X_pos)]
vmin = min(np.min(X_neg_remove_nan), np.min(X_pos_remove_nan))
vmax = max(np.max(X_neg_remove_nan), np.max(X_pos_remove_nan))

error_figures = [
    (X_neg, '../figures/simulation_error_plot_neg.pdf'),
    (X_pos, '../figures/simulation_error_plot_pos.pdf'),
]
for error_map, out_path in error_figures:
    fig, ax = plt.subplots(figsize=(column_width, column_width * 0.9))
    im = ax.imshow(error_map.transpose(1, 0), interpolation='none', aspect='equal', cmap=cmap,
                   extent=[0, 10, 0, 13], norm='linear', origin='lower', resample=True,
                   vmin=vmin, vmax=vmax)
    plt.colorbar(im, ax=ax, label=fr'RMSE')
    plot_room_geometry(ax)
    ax.scatter(srcPos[0], srcPos[1], marker='x', color='green')
    ax.set_xticks(np.arange(0, 12, 2))
    ax.set_xlabel("x in meters")
    ax.set_ylabel("y in meters")
    fig.savefig(out_path, bbox_inches="tight")

# Plot EDC and RIRs per room

In [None]:
# Example receivers, one per room (R1-R3); the same three RIRs feed all panels below.
rcvPos1 = rcvPos[197][:2]
rcvPos2 = rcvPos[384][:2]
rcvPos3 = rcvPos[750][:2]

# (RIR index, first sample, label) per example receiver.
# NOTE(review): the start offsets presumably align the RIR onsets — confirm.
example_rirs = [(197, 0, "R1"), (384, 100, "R2"), (750, 1000, "R3")]

# Panel A: room layout with the source and the three example receivers.
fig, ax = plt.subplots(figsize=(text_width*0.25, text_width*0.35))
ax.set_aspect('equal')
plot_room_geometry(ax)
ax.scatter(srcPos[0], srcPos[1], marker='x', linewidth=1, color='black')
for rcv, color in ((rcvPos1, 'C0'), (rcvPos2, 'C1'), (rcvPos3, 'C2')):
    ax.scatter(rcv[0], rcv[1], marker='*', linewidth=1, color=color)
ax.set_xlabel("x in meters")
ax.set_ylabel("y in meters")
ax.annotate("Source",
            xy=(2.5, 2), xycoords='data',
            xytext=(5, 0.5), textcoords='data',
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3"))
ax.text(2, 4, "R1\n(Short\nT60)", horizontalalignment='center')
ax.text(7, 3.2, "R2\n(Long T60)", horizontalalignment='center')
ax.text(8, 7, "R3\n(Mid\nT60)", horizontalalignment='center')
ax.fill_betweenx([0, 8], [0, 0], [4, 4], alpha=0.2)
ax.fill_betweenx([2, 5], [4, 4], [10, 10], alpha=0.2)
ax.fill_betweenx([5, 13], [6, 6], [10, 10], alpha=0.2)
ax.set_xticks([0, 2, 5, 7.5, 10], [0, 2, 5, 7.5, 10])
ax.set_xlim([0, 10])
ax.set_ylim([0, 12])
fig.savefig(fname='../figures/simulation_edc_rirs_A.pdf', bbox_inches="tight")
plt.show()

# Panel B: energy decay curves (backward-integrated energy, in dB) over the first second.
fig, ax0 = plt.subplots(figsize=(text_width*0.42, text_width*0.35))
# Integer-based time axis: np.arange with a float step can return one extra
# sample, which would mismatch the exactly-48000-sample EDC slice.
xaxis = np.arange(48000) / 48000
for rir_idx, start, label in example_rirs:
    test_edf = np.flipud(np.cumsum(np.flipud(omni_rirs[rir_idx, start:start + 96000]**2)))
    ax0.plot(xaxis, (10 * np.log10(test_edf))[:48000], label=f"In {label}")
ax0.legend()
ax0.set_ylabel("Decay level in dB")
ax0.set_xlabel("Time in seconds")
ax0.set_ylim([-60, 0])
ax0.set_xlim([0, 1.0])
fig.savefig(fname='../figures/simulation_edc_rirs_B.pdf', bbox_inches="tight")
plt.show()

# Panel C: first 20000 samples of each RIR, vertically offset for readability.
fig, ax1 = plt.subplots(figsize=(text_width*0.2, text_width*0.35))
xaxis = np.arange(20000) / 48000  # same length-safe construction as above
for (rir_idx, start, label), y_offset in zip(example_rirs, (0.0, 0.1, 0.15)):
    ax1.plot(xaxis, omni_rirs[rir_idx, start:start + 20000] - y_offset, label=label)
ax1.set_yticks([])
ax1.set_xlabel("Time in seconds")
ax1.set_xticks([0, 0.2, 0.4], [0, 0.2, 0.4])
fig.savefig(fname='../figures/simulation_edc_rirs_C.pdf', bbox_inches="tight")


