In [None]:
import os
import pickle
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyfar as pf
import soundfile as sf
import torch
from numpy.typing import ArrayLike
from scipy.signal import fftconvolve, sosfilt
from tqdm import tqdm

os.chdir('..')  # This changes the working directory to DiffGFDN
from slope2noise.generate import shaped_wgn
from slope2noise.utils import schroeder_backward_int
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset
from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import ms_to_samps, db, get_response
from diff_gfdn.plot import plot_edc_error_in_space, plot_edr, plot_edr_error_in_space

from src.run_model import load_and_validate_config
from src.run_subband_training_treble import sum_arrays

In [None]:
# Location of the YAML file describing the full-band DiffGFDN training run
config_path = 'data/config/'
config_name = 'treble_data_grid_training_full_band_colorless_loss'
config_file = f'{config_path}{config_name}.yml'

# Parse the file against the DiffGFDNConfig schema, then build the dataset
# from the path stored in the parsed configuration
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)
room_data = ThreeRoomDataset(
    Path(config_dict.room_dataset_path).resolve(),
    config_dict,
)

### Pick a position

In [None]:
# Output locations, relative to the working directory set in the import cell
audio_path  = "audio/"
fig_path = "figures/"
out_path = "output/"
plot_ir = True
use_fixed_pos = True
if use_fixed_pos:
    # other positions that were inspected: [2.3, 2.6, 1.5] [2.0, 6.8, 1.5] [6.4, 3.8, 1.5]
    pos_to_investigate = [9.3, 6.6, 1.5]
else:
    # draw one receiver at random (unseeded, so the pick changes between runs)
    rec_idx = np.random.randint(0, high=room_data.num_rec, size=1, dtype=int)
    pos_to_investigate = np.round(np.squeeze(room_data.receiver_position[rec_idx,:]), 2)

decay_times = np.squeeze(room_data.common_decay_times)
band_centre_hz = room_data.band_centre_hz

# find amplitudes corresponding to the receiver position; positions are
# compared after rounding to two decimals
matching_rec_idx = np.argwhere(
    np.all(np.round(room_data.receiver_position, 2) == pos_to_investigate, axis=1))
if matching_rec_idx.size == 0:
    # fail with a readable message instead of a bare IndexError on the [0] below
    raise ValueError(f'No receiver found at position {pos_to_investigate}')
# keep the length-1 index array (not a scalar) so the leading receiver axis
# survives in the slices taken with it
rec_pos_idx = matching_rec_idx[0]

amps_at_pos = np.squeeze(room_data.amplitudes[rec_pos_idx, ...])
desired_filename = f'ir_({pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}).wav'

# reference RIR at the picked position, limited to the first 2 * sample_rate samples
h_true = np.squeeze(room_data.rirs[rec_pos_idx, :int(2*config_dict.sample_rate)])

### Get the full band DiffGFDN solution at all positions

In [None]:
# add number of groups to the config dictionary
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

if config_dict.sample_rate != room_data.sample_rate:
    # the dataset is the source of truth for the sample rate
    warnings.warn("Config sample rate does not match data, altering it")
    config_dict = config_dict.model_copy(update={"sample_rate": room_data.sample_rate})

# get the training config
trainer_config = config_dict.trainer_config

# force the trainer config device to be CPU
if trainer_config.device != 'cpu':
    trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

# prepare the training and validation data for DiffGFDN
# (split ratio 1.0 and shuffle=False: iterate over the dataset in a fixed order)
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size, shuffle=False)

# initialise the model
model = DiffGFDNVarReceiverPos(room_data.sample_rate, room_data.num_rooms,
                 config_dict.delay_length_samps,
                 trainer_config.device, 
                 config_dict.feedback_loop_config,
                 config_dict.output_filter_config,
                 config_dict.decay_filter_config.use_absorption_filters,
                 common_decay_times=room_data.common_decay_times,
                 band_centre_hz=room_data.band_centre_hz,
                )

# load the trained weights of the last epoch
max_epochs = trainer_config.max_epochs
checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
checkpoint = torch.load(f'{checkpoint_dir}/model_e{max_epochs-1}.pt', weights_only=True, map_location=torch.device('cpu'))
# Load the trained model state
model.load_state_dict(checkpoint)
# eval() only switches the modules to inference behaviour; gradient tracking
# is disabled separately with torch.no_grad() below
model.eval()
all_fullband_pos = []
all_fullband_rirs = []
# stays None if the picked position never shows up in the dataset
h_full_gfdn = None

# no_grad keeps the collected RIRs free of autograd graphs
with torch.no_grad():
    for data in train_dataset:
        position = data['listener_position']
        H, h = get_response(data, model)

        for num_pos in range(position.shape[0]):
            filename = f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, {position[num_pos, 2]:.2f}).wav'

            # collate all RIRs at all positions
            all_fullband_pos.append(position[num_pos])
            all_fullband_rirs.append(h[num_pos, ...])

            # remember the RIR at the picked position for the single-position plots
            if plot_ir and filename == desired_filename:
                h_full_gfdn = h[num_pos, ...]

if plot_ir and h_full_gfdn is None:
    raise ValueError(f'No full band DiffGFDN RIR found for {desired_filename}')

### Get the subband GFDN solution at all positions

In [None]:
# read the saved DiffGFDN RIRs
# NOTE(review): read_pickle unpickles arbitrary objects - only load files you trust
subband_file_path = Path('output/treble_data_grid_training_final_rirs_colorless_loss_learnt_decay_times_opt_init.pkl')
synth_subband_rirs = pd.read_pickle(subband_file_path)

# octave-band reconstructing filterbank spanning the dataset's band centres
subband_filters, _ = pf.dsp.filter.reconstructing_fractional_octave_bands(
    None,
    num_fractions=1,
    frequency_range=(room_data.band_centre_hz[0], room_data.band_centre_hz[-1]),
    sampling_rate=room_data.sample_rate,
)

# collapse the per-band rows into one entry per position using sum_arrays
synth_rirs = synth_subband_rirs.groupby('position').apply(sum_arrays)
synth_rirs_df = synth_rirs.reset_index()
synth_rirs_df.columns = ['position', 'filtered_time_samples']

# stack everything for the error-in-space plots
all_subband_rirs = np.vstack(synth_rirs_df['filtered_time_samples'])
all_subband_pos = np.array(synth_rirs_df['position'].tolist())

# pull out the RIR at the picked position
is_picked_pos = synth_rirs_df['position'] == tuple(pos_to_investigate)
subband_rec_pos_idx = synth_rirs_df[is_picked_pos].index[0]
data_at_pos = synth_rirs_df.iloc[subband_rec_pos_idx]
h_subband_gfdn = data_at_pos['filtered_time_samples']

# trim the three RIRs to the shortest one so they can be compared sample by sample
ir_len_samps = min(len(h_subband_gfdn), h_full_gfdn.shape[0], len(h_true))
h_true, h_subband_gfdn, h_full_gfdn = (
    rir[:ir_len_samps] for rir in (h_true, h_subband_gfdn, h_full_gfdn)
)

### Get the common slopes solution at all positions

In [None]:
from importlib import reload
import slope2noise
# reload so that local edits to slope2noise are picked up without a kernel restart
reload(slope2noise.utils)
reload(slope2noise.generate)
from slope2noise.generate import shaped_wgn

ir_len_samps = len(h_true)
# True: synthesise (or load from cache) RIRs at every receiver; False: picked position only
synth_all_pos = False
# True: additionally synthesise with room_data.amplitudes_norm
plot_cs_norm = False

if synth_all_pos:
    cs_pickle_path = Path(f'{out_path}/treble_data_grid_common_slopes.pkl').resolve()

    if not os.path.exists(cs_pickle_path):
        # synthesise for all positions - this is slow
        t_vals_expanded = np.repeat(np.array(decay_times.T)[np.newaxis, ...],
                                    room_data.num_rec,
                                    axis=0)
        batch_size = room_data.num_rec
        num_batches = int(np.ceil(float(room_data.num_rec) / batch_size))
        ls_est_rirs = np.zeros((room_data.num_rec, ir_len_samps))
        if plot_cs_norm:
            # allocate the output for the normalised-amplitude variant as well
            ls_est_rirs_norm = np.zeros_like(ls_est_rirs)

        for n in tqdm(range(num_batches)):
            # clip the last batch at num_rec so the indices never run past the data
            batch_idx = np.arange(n * batch_size,
                                  min(room_data.num_rec, (n + 1) * batch_size),
                                  dtype=np.int32)
            _, ls_est_rirs[batch_idx, :] = shaped_wgn(t_vals_expanded[batch_idx, ...],
                                                      room_data.amplitudes[batch_idx, ...],
                                                      room_data.sample_rate,
                                                      ir_len_samps,
                                                      band_centre_hz,
                                                      # n_vals=np.squeeze(room_data.noise_floor[batch_idx, ...])
                                                     )
            if plot_cs_norm:
                _, ls_est_rirs_norm[batch_idx, :] = shaped_wgn(t_vals_expanded[batch_idx, ...],
                                                               room_data.amplitudes_norm[batch_idx, ...],
                                                               room_data.sample_rate,
                                                               ir_len_samps, band_centre_hz)

        h_ls = np.squeeze(ls_est_rirs[rec_pos_idx, :])
        if plot_cs_norm:
            h_ls_norm = np.squeeze(ls_est_rirs_norm[rec_pos_idx, :])
    else:
        # NOTE(review): pickle.load can execute arbitrary code - only load trusted files
        # NOTE(review): h_ls_norm is not available on this path even if plot_cs_norm is set
        with open(cs_pickle_path, "rb") as f:
            cs_room_data = pickle.load(f)
        ls_est_rirs = cs_room_data.rirs
        h_ls = np.squeeze(ls_est_rirs[rec_pos_idx, :])
else:
    # synthesise for one position
    t_vals_ls = np.array(decay_times.T)[np.newaxis, ...]
    _, h_ls = shaped_wgn(t_vals_ls, room_data.amplitudes[rec_pos_idx, ...],
                         room_data.sample_rate,
                         ir_len_samps,
                         f_bands=band_centre_hz,
                         # n_vals=np.squeeze(room_data.noise_floor[rec_pos_idx], axis=0),
                         use_pyfar_filterbank=True)
    h_ls = np.squeeze(h_ls)

    if plot_cs_norm:
        _, h_ls_norm = shaped_wgn(t_vals_ls, room_data.amplitudes_norm[rec_pos_idx, ...],
                                  room_data.sample_rate,
                                  ir_len_samps,
                                  f_bands=band_centre_hz,
                                  # n_vals=np.squeeze(room_data.noise_floor_norm[rec_pos_idx], axis=0),
                                  use_pyfar_filterbank=True)
        h_ls_norm = np.squeeze(h_ls_norm)

    # NOTE(review): desired_filename already ends in '.wav', so the file is
    # written as '...wav_pyfar.wav'; name kept as is for existing consumers
    cs_wav_path = Path(f'{audio_path}/common_slopes_full_band/{desired_filename}_pyfar.wav').resolve()
    # make sure the target folder exists before writing
    cs_wav_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(cs_wav_path, h_ls, room_data.sample_rate)

### Plot the EDCs at picked position

In [None]:
fs = room_data.sample_rate
# number of samples dropped from the start (20 ms) and the end (5 ms) of each RIR
mixing_time_samp = ms_to_samps(20.0, fs)
crop_end_samp = ms_to_samps(5.0, fs)
norm_flag = False

# EDF of the cropped reference RIR
trunc_true_ir = h_true[mixing_time_samp:-crop_end_samp]
true_edf = schroeder_backward_int(
    trunc_true_ir,
    discard_last_zeros=False,
    normalize=norm_flag,
)

# time axis in seconds, one value per sample of the cropped RIR
time = np.linspace(0, (len(trunc_true_ir) - 1) / fs, len(trunc_true_ir))

#### Use DecayFitNet to get the EDC fit from the subband GFDN RIRs, and compare it to the true subband RIRs

In [None]:
def plot_edc_in_each_subband(h_sub_true: ArrayLike, h_sub_gfdn: ArrayLike, h_broad_gfdn: ArrayLike, 
                            fs: float, freq: int, normalize: bool=True):
    """Plot the truncated EDFs of a reference RIR and two DiffGFDN RIRs for one band.

    Each RIR is cropped by 20 ms at the start and 5 ms at the end before
    `schroeder_backward_int` is applied. The three curves are drawn in dB on a
    new figure whose title uses the notebook-level `pos_to_investigate`.

    Args:
        h_sub_true: reference RIR for this band.
        h_sub_gfdn: RIR plotted as 'Subband DiffGFDN'.
        h_broad_gfdn: RIR plotted as 'Broadband DiffGFDN'.
        fs: sampling rate, used for the crop lengths and the time axis.
        freq: band centre frequency in Hz, only used in the legend labels.
        normalize: forwarded to `schroeder_backward_int`.

    Returns:
        Tuple (true_subband_edf, synth_edf_subband, synth_edf_broadband).
    """
    mixing_time_samp = ms_to_samps(20.0, fs)
    crop_end_samp = ms_to_samps(5.0, fs)

    def _truncated_edf(ir):
        # crop the RIR, then integrate it backwards
        return schroeder_backward_int(ir[mixing_time_samp:-crop_end_samp], normalize=normalize)

    trunc_true_ir = h_sub_true[mixing_time_samp:-crop_end_samp]
    true_subband_edf = schroeder_backward_int(trunc_true_ir, normalize=normalize)
    synth_edf_subband = _truncated_edf(h_sub_gfdn)
    synth_edf_broadband = _truncated_edf(h_broad_gfdn)

    # time axis in seconds; all curves are plotted against the reference's axis
    time = np.linspace(0, (len(trunc_true_ir) - 1) / fs,
                       len(trunc_true_ir))

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(time, db(true_subband_edf, is_squared=True), label=f'True EDF, freq = {freq} Hz')
    ax.plot(time, db(synth_edf_broadband, is_squared=True), label=f'Broadband DiffGFDN freq = {freq} Hz')
    ax.plot(time, db(synth_edf_subband, is_squared=True), label=f'Subband DiffGFDN freq = {freq} Hz')
    ax.set_title(
        f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Energy (dB)')
    ax.legend(loc='upper center', bbox_to_anchor=(0, 1.5))
    return true_subband_edf, synth_edf_subband, synth_edf_broadband


# containers for the per-band results at the picked position
freq_bands = room_data.band_centre_hz
n_bands = len(freq_bands)
t_vals = np.zeros((n_bands, room_data.num_rooms))
a_vals = np.zeros_like(t_vals)
n_vals = np.zeros(n_bands)
synth_edf_subband_dfn = np.zeros((n_bands, len(time)))
true_edf_subband = np.zeros((n_bands, len(time)))
synth_edf_subband_gfdn = np.zeros_like(true_edf_subband)
synth_edf_subband_dfn_norm = np.zeros_like(true_edf_subband)
synth_edf_broadband_gfdn = np.zeros_like(true_edf_subband)
h_broadband_gfdn = np.zeros(ir_len_samps)

# octave-band filterbank used to split the reference RIR into subbands
subband_filters = pf.dsp.filter.fractional_octave_bands(
    None,
    num_fractions=1,
    frequency_range=(room_data.band_centre_hz[0], room_data.band_centre_hz[-1]),
    sampling_rate=fs,
)

for idx, freq in enumerate(freq_bands):
    # rows of the saved table belonging to this band at the picked position
    rows_at_band_and_pos = ((synth_subband_rirs['frequency'] == freq)
                            & (synth_subband_rirs['position'] == tuple(pos_to_investigate)))
    broadband_gfdn_rir = synth_subband_rirs.loc[rows_at_band_and_pos, 'time_samples'].values[0]
    subband_gfdn_rir = synth_subband_rirs.loc[rows_at_band_and_pos, 'filtered_time_samples'].values[0]

    # running sum of the unfiltered per-band RIRs
    h_broadband_gfdn += broadband_gfdn_rir[:ir_len_samps]

    # reference RIR filtered into the current band
    subband_ref_rir = sosfilt(subband_filters.coefficients[idx, ...], h_true)

    band_edfs = plot_edc_in_each_subband(subband_ref_rir[:ir_len_samps],
                                         subband_gfdn_rir[:ir_len_samps],
                                         broadband_gfdn_rir[:ir_len_samps],
                                         fs, freq, normalize=False)
    true_edf_subband[idx, :], synth_edf_subband_gfdn[idx, :], synth_edf_broadband_gfdn[idx, :] = band_edfs

#### Compare the broadband RIR's EDC

In [None]:
# Overlay the truncated EDFs of the reference RIR and of every model at the picked position
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time, db(true_edf, is_squared=True), label='True EDC')
ax.set_title(
    f'Truncated EDF at position {pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f} m'
)

# every RIR is cropped exactly like the reference before the backward integration
trunc_approx_ir_subband_gfdn = h_subband_gfdn[mixing_time_samp:-crop_end_samp]
synth_edf_subband = schroeder_backward_int(trunc_approx_ir_subband_gfdn, normalize=norm_flag, discard_last_zeros=False)
ax.plot(time, db(synth_edf_subband, is_squared= True), label='Subband DiffGFDN')

# the full band RIR is a torch tensor, convert it to numpy first
trunc_approx_ir_fullband_gfdn = h_full_gfdn[mixing_time_samp:-crop_end_samp].detach().numpy()
synth_edf_fullband = schroeder_backward_int(trunc_approx_ir_fullband_gfdn, normalize=norm_flag, discard_last_zeros=False)
ax.plot(time, db(synth_edf_fullband, is_squared=True), label='Full band DiffGFDN')

trunc_approx_ir_common_slopes = h_ls[mixing_time_samp:-crop_end_samp]
synth_edf_common_slopes = schroeder_backward_int(trunc_approx_ir_common_slopes, normalize=norm_flag, discard_last_zeros=False)
ax.plot(time, db(synth_edf_common_slopes, is_squared=True), label='Common slopes model')

if plot_cs_norm:
    trunc_approx_ir_common_slopes_norm = h_ls_norm[mixing_time_samp:-crop_end_samp]
    synth_edf_common_slopes_norm = schroeder_backward_int(trunc_approx_ir_common_slopes_norm, normalize=norm_flag, discard_last_zeros=False)
    ax.plot(time, db(synth_edf_common_slopes_norm, is_squared=True), label='Common slopes model Georg params')

# label the axes so the saved figure can be read on its own
ax.set_xlabel('Time (s)')
ax.set_ylabel('Energy (dB)')
ax.legend()
fig.savefig(Path(f'{fig_path}/compare_edf_{pos_to_investigate}_treble_data_all_models.png').resolve())

### Plot EDR at picked position

In [None]:
import diff_gfdn
# reload so that local edits to the plotting module are picked up
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_spectrogram, plot_edr

# position string shared by all file names below
pos_tag = f'{pos_to_investigate[0]:.2f}, {pos_to_investigate[1]:.2f}, {pos_to_investigate[2]:.2f}'

# (signal, figure title, file name stem) for each EDR plot
edr_plots = [
    (torch.tensor(h_true), 'True EDR', 'true_edr_treble_fullband'),
    (torch.tensor(h_subband_gfdn), 'Estimated EDR, Subband GFDN', 'est_edr_treble_subband_gfdn'),
    # h_full_gfdn is already a tensor: copy it with detach().clone() rather than
    # torch.tensor(), which warns when given an existing tensor
    (h_full_gfdn.detach().clone(), 'Estimated EDR, Fullband GFDN', 'est_edr_treble_fullband_gfdn'),
    (torch.tensor(h_ls), 'Estimated EDR, Common slopes', 'est_edr_treble_common_slopes'),
]

for edr_signal, edr_title, edr_file_stem in edr_plots:
    plot_edr(edr_signal, model.sample_rate,
             title=edr_title, log_freq_axis=True,
             save_path=f'{fig_path}/{edr_file_stem}_{pos_tag}.png')

### Plot the mean EDC/EDR error

In [None]:
# EDC error over the receiver positions for the full band DiffGFDN RIRs (EDC normalisation off)
plot_edc_error_in_space(
    room_data,
    all_fullband_rirs,
    all_fullband_pos,
    freq_to_plot=None,
    norm_edc=False,
    save_path=f'{fig_path}/avg_edc_error_treble_fullband_training',
)

In [None]:
# EDC error over the receiver positions for the subband DiffGFDN RIRs (EDC normalisation off)
plot_edc_error_in_space(
    room_data,
    all_subband_rirs,
    all_subband_pos,
    freq_to_plot=None,
    norm_edc=False,
    save_path=f'{fig_path}/avg_edc_error_treble_subband_training_learnt_decay_times',
)

In [None]:
# ls_est_rirs is only created when the common slopes cell runs with
# synth_all_pos = True; skip the plot otherwise instead of failing with a NameError
if synth_all_pos:
    plot_edc_error_in_space(room_data, ls_est_rirs, room_data.receiver_position, freq_to_plot=None, 
                            save_path=f'{fig_path}/avg_edc_error_treble_common_slopes_model', pos_sorted=True, norm_edc=False)
else:
    print('Skipping the common slopes EDC error plot: set synth_all_pos = True to synthesise RIRs at all positions')

In [None]:
import slope2noise
import diff_gfdn

# reload the modules so that local edits are picked up, then re-import the plotting helper
reload(slope2noise.utils)
reload(slope2noise.rooms)
reload(diff_gfdn.plot)
from diff_gfdn.plot import plot_edr_error_in_space

# EDR error over the receiver positions for the full band DiffGFDN RIRs
plot_edr_error_in_space(
    room_data,
    all_fullband_rirs,
    all_fullband_pos,
    save_path=f'{fig_path}/avg_edr_error_treble_fullband_training',
)

In [None]:
# EDR error over the receiver positions for the subband DiffGFDN RIRs
plot_edr_error_in_space(
    room_data,
    all_subband_rirs,
    all_subband_pos,
    save_path=f'{fig_path}/avg_edr_error_treble_subband_training',
)

In [None]:
# ls_est_rirs is only created when the common slopes cell runs with
# synth_all_pos = True; skip the plot otherwise instead of failing with a NameError
if synth_all_pos:
    plot_edr_error_in_space(room_data, ls_est_rirs, room_data.receiver_position,
                            save_path=f'{fig_path}/avg_edr_error_treble_common_slopes_model', pos_sorted=True)
else:
    print('Skipping the common slopes EDR error plot: set synth_all_pos = True to synthesise RIRs at all positions')