In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyfar as pf
import soundfile as sf
import os
import torch
import librosa
import pickle
import IPython
from numpy.typing import ArrayLike
from pathlib import Path
from importlib import reload
from scipy.signal import fftconvolve
from copy import deepcopy
from loguru import logger

os.chdir('..')  # This changes the working directory to DiffGFDN
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset
from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import ms_to_samps, db, get_response
from diff_gfdn.plot import plot_spectrogram
from diff_gfdn.losses import get_stft_torch
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from slope2noise.generate import shaped_wgn


from src.run_model import load_and_validate_config
from src.run_subband_training_treble import sum_arrays
from src.sound_examples import dynamic_rendering_moving_receiver

In [None]:
# Output/input roots. They stay plain strings because later cells build paths
# by string concatenation and f-strings.
config_path = 'data/config/'
fig_path = 'figures/'
audio_path = 'audio/sound_examples/'
out_path = 'output/'

# Full-band DiffGFDN configuration and the dataset it points to
config_name = 'treble_data_grid_training_full_band_colorless_loss'
config_file = f'{config_path}{config_name}.yml'
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

dataset_dir = Path(config_dict.room_dataset_path).resolve()
room_data = ThreeRoomDataset(dataset_dir, config_dict)
# NOTE(review): mixing time is overridden after construction — confirm the dataset reads it lazily
room_data.mixing_time_ms = 50

### Get the stimulus and resample it

In [None]:
# Load the dry stimulus, resample it to the dataset rate and pad it with silence.
# (`ms_to_samps` is already imported in the first cell.)
sig_type = 'drums'  # 'drums' or anything else for the pyfar speech example

speech_data = pf.signals.files.drums() if sig_type == 'drums' else pf.signals.files.speech()
speech = np.squeeze(speech_data.time)
fs = speech_data.sampling_rate

if fs != room_data.sample_rate:
    speech = librosa.resample(speech, orig_sr=fs, target_sr=room_data.sample_rate)

# add some silence at the end so the reverberant tail is not cut off
silence = np.zeros(ms_to_samps(500, room_data.sample_rate))
speech_app = np.concatenate((speech, silence))

save_path = Path(f'{audio_path}/stimulus/{sig_type}.wav').resolve()
# sf.write does not create missing directories
save_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(save_path, speech_app, room_data.sample_rate)
IPython.display.Audio(save_path)

### Create a trajectory of a listener moving across the space

In [None]:
def linear_segment(start_xy, end_xy, n_points, height=1.5):
    """Return an (n_points, 3) array of evenly spaced points from start_xy to end_xy at a fixed z."""
    xs = np.linspace(start_xy[0], end_xy[0], n_points)
    ys = np.linspace(start_xy[1], end_xy[1], n_points)
    return np.column_stack((xs, ys, np.full(n_points, height)))


# Listener trajectory made of straight legs, all at z = 1.5
trajectory_legs = [
    # along the x axis, between the three rooms
    ((0.5, 3.5), (9, 3.5), 50),
    # along the y axis, between rooms 2 and 3
    # NOTE(review): x drifts from 9.1 to 9.0 along this leg — confirm intended
    ((9.1, 3.5), (9.0, 12.0), 68),
]
rec_pos_list = np.vstack([linear_segment(start, end, count) for start, end, count in trajectory_legs])


### Get the common slope RIRs for all measured positions in the dataset

In [None]:
cs_pickle_path = Path(f'{out_path}/treble_data_grid_common_slopes.pkl').resolve()

if not cs_pickle_path.exists():
    # synthesise common-slope RIRs for all measured positions - this is slow
    decay_times = np.squeeze(room_data.common_decay_times)
    ir_length_samps = int(2 * room_data.sample_rate)  # 2 s of samples
    # the same common decay times are repeated for every receiver
    t_vals_expanded = np.repeat(np.array(decay_times.T)[np.newaxis, ...],
                                room_data.num_rec,
                                axis=0)
    batch_size = room_data.num_rec
    num_batches = int(np.ceil(float(room_data.num_rec) / batch_size))
    ls_est_rirs = np.zeros((room_data.num_rec, ir_length_samps))

    for n in range(num_batches):
        # clamp the last batch at num_rec; an upper bound beyond num_rec would index past the end
        batch_idx = np.arange(n * batch_size,
                              min(room_data.num_rec, (n + 1) * batch_size),
                              dtype=np.int32)
        # shaped_wgn returns a pair; only the second element (the RIRs) is kept here
        _, ls_est_rirs[batch_idx, :] = shaped_wgn(t_vals_expanded[batch_idx, ...],
                                                  room_data.amplitudes[batch_idx, ...],
                                                  room_data.sample_rate,
                                                  ir_length_samps,
                                                  room_data.band_centre_hz,
                                                  # n_vals=np.squeeze(room_data.noise_floor[batch_idx, ...])
                                                  )
    # copy the dataset and swap in the synthesised RIRs
    cs_room_data = deepcopy(room_data)
    cs_room_data.update_rirs(ls_est_rirs)

    # cache to disk (create the output folder on a fresh checkout)
    cs_pickle_path.parent.mkdir(parents=True, exist_ok=True)
    with open(cs_pickle_path, "wb") as f:
        pickle.dump(cs_room_data, f)
else:
    # only load caches written by this notebook: unpickling can execute arbitrary code
    with open(cs_pickle_path, "rb") as f:
        cs_room_data = pickle.load(f)


### Get the full-band DiffGFDN solution at the trajectory positions (note: the early parts of these RIRs are all wrong)

In [None]:
full_gfdn_pickle_path = Path(f'{out_path}/sound_examples/treble_data_moving_listener_fullband_gfdn.pkl').resolve()

if not full_gfdn_pickle_path.exists():
    # add number of groups to the config dictionary
    config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

    if config_dict.sample_rate != room_data.sample_rate:
        # loguru has no `warn`; update via model_copy, like the num_groups update above
        logger.warning("Config sample rate does not match data, altering it")
        config_dict = config_dict.model_copy(update={"sample_rate": room_data.sample_rate})

    # get the training config, forcing inference onto the CPU
    trainer_config = config_dict.trainer_config
    if trainer_config.device != 'cpu':
        trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

    full_gfdn_room_data = deepcopy(room_data)
    full_gfdn_room_data.update_receiver_pos(rec_pos_list)

    # all trajectory positions go into the "training" split; no shuffling keeps trajectory order
    train_dataset, _ = load_dataset(
        full_gfdn_room_data, trainer_config.device, train_valid_split_ratio=1.0,
        batch_size=trainer_config.batch_size, shuffle=False)

    # initialise the model
    model = DiffGFDNVarReceiverPos(full_gfdn_room_data.sample_rate,
                                   full_gfdn_room_data.num_rooms,
                                   config_dict.delay_length_samps,
                                   trainer_config.device,
                                   config_dict.feedback_loop_config,
                                   config_dict.output_filter_config,
                                   config_dict.decay_filter_config.use_absorption_filters,
                                   common_decay_times=full_gfdn_room_data.common_decay_times,
                                   band_centre_hz=full_gfdn_room_data.band_centre_hz,
                                   )

    # load the trained weights from the last epoch
    checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
    checkpoint_file = checkpoint_dir / f'model_e{trainer_config.max_epochs - 1}.pt'
    checkpoint = torch.load(checkpoint_file, weights_only=True, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    # eval() only switches layer behaviour; torch.no_grad() below is what disables autograd
    model.eval()

    all_fullband_pos = []
    all_fullband_rirs = []
    with torch.no_grad():
        for data in train_dataset:
            position = data['listener_position'].detach().cpu().numpy()
            # get_response returns an extra output when the model uses the colorless loss
            if model.use_colorless_loss:
                _, _, h = get_response(data, model)
            else:
                _, h = get_response(data, model)
            # detach before converting: np.asarray fails on tensors that require grad
            h = h.detach().cpu().numpy()
            for pos_idx in range(position.shape[0]):
                # collate all RIRs at all positions
                all_fullband_pos.append(position[pos_idx])
                all_fullband_rirs.append(h[pos_idx, ...])

    full_gfdn_room_data.update_receiver_pos(np.asarray(all_fullband_pos))
    full_gfdn_room_data.update_rirs(np.asarray(all_fullband_rirs))

    # cache to disk (create the output folder on a fresh checkout)
    full_gfdn_pickle_path.parent.mkdir(parents=True, exist_ok=True)
    with open(full_gfdn_pickle_path, "wb") as f:
        pickle.dump(full_gfdn_room_data, f)

else:
    # only load caches written by this notebook: unpickling can execute arbitrary code
    with open(full_gfdn_pickle_path, "rb") as f:
        full_gfdn_room_data = pickle.load(f)


# some plotting to investigate: one trajectory RIR against its nearest measured RIR
rec_pos_idx = 34
dist = np.linalg.norm(full_gfdn_room_data.receiver_position[rec_pos_idx, :] - room_data.receiver_position, axis=-1)
ref_rir_idx = np.argmin(dist, axis=0)
plt.plot(full_gfdn_room_data.late_rirs[rec_pos_idx, :], label='Fullband DiffGFDN')
plt.plot(room_data.late_rirs[ref_rir_idx, :], label='Ref RIR')
plt.legend()
plt.show()

# the reference spectrogram must use the nearest measured RIR, not the trajectory index
S_true, freqs, time_frames = get_stft_torch(torch.tensor(room_data.late_rirs[ref_rir_idx, :]),
                                            room_data.sample_rate, win_size=2**12, hop_size=2**11, nfft=2**12)
S, freqs, time_frames = get_stft_torch(torch.tensor(full_gfdn_room_data.late_rirs[rec_pos_idx, :]),
                                       room_data.sample_rate, win_size=2**12, hop_size=2**11, nfft=2**12)

plot_spectrogram(db(torch.abs(S_true)), freqs, time_frames, title='Ref RIR', log_freq_axis=True)
plot_spectrogram(db(torch.abs(S)), freqs, time_frames, title='Fullband DiffGFDN', log_freq_axis=True)

### Get the subband DiffGFDN solution at the trajectory positions

In [None]:
subband_gfdn_pickle_path = Path(f'{out_path}/sound_examples/treble_data_moving_listener_subband_gfdn.pkl').resolve()

if not subband_gfdn_pickle_path.exists():
    # pyfar's reconstructing octave filter bank: the band outputs are meant to sum back to full band
    subband_filters, _ = pf.dsp.filter.reconstructing_fractional_octave_bands(
        None,
        num_fractions=1,
        frequency_range=(room_data.band_centre_hz[0], room_data.band_centre_hz[-1]),
        sampling_rate=room_data.sample_rate,
    )

    # one record per (band, position); the DataFrame is built once after the loop
    # instead of being grown with pd.concat (quadratic in the number of rows)
    subband_records = []

    # loop through all subband frequencies
    for k, band_hz in enumerate(room_data.band_centre_hz):
        logger.info(
            f'Running inferencing for subband = {band_hz} Hz')

        # band-specific names so the full-band `config_name` / `config_dict` globals are not overwritten
        sub_config_name = f'treble_data_grid_training_{band_hz}Hz_colorless_loss'
        sub_config_dict = load_and_validate_config(config_path + f'{sub_config_name}.yml', DiffGFDNConfig)
        sub_room_data = ThreeRoomDataset(
            Path(sub_config_dict.room_dataset_path).resolve(), sub_config_dict)

        sub_gfdn_room_data = deepcopy(sub_room_data)
        sub_gfdn_room_data.update_receiver_pos(rec_pos_list)

        sub_config_dict = sub_config_dict.model_copy(
            update={"num_groups": sub_room_data.num_rooms})
        trainer_config = sub_config_dict.trainer_config

        # force the trainer config device to be CPU
        if trainer_config.device != 'cpu':
            trainer_config = trainer_config.model_copy(
                update={"device": 'cpu'})

        # all trajectory positions go into the "training" split; no shuffling keeps trajectory order
        train_dataset, _ = load_dataset(
            sub_gfdn_room_data,
            trainer_config.device,
            train_valid_split_ratio=1.0,
            batch_size=trainer_config.batch_size,
            shuffle=False)

        if sub_config_dict.colorless_fdn_config.use_colorless_prototype:
            colorless_fdn_params = get_colorless_fdn_params(sub_config_dict)
        else:
            colorless_fdn_params = None

        # initialise the model
        model = DiffGFDNVarReceiverPos(
            sub_config_dict.sample_rate,
            sub_config_dict.num_groups,
            sub_config_dict.delay_length_samps,
            trainer_config.device,
            sub_config_dict.feedback_loop_config,
            sub_config_dict.output_filter_config,
            use_absorption_filters=sub_config_dict.decay_filter_config.use_absorption_filters,
            common_decay_times=sub_gfdn_room_data.common_decay_times if
                               sub_config_dict.decay_filter_config.initialise_with_opt_values else None,
            learn_common_decay_times=sub_config_dict.decay_filter_config.learn_common_decay_times,
            use_colorless_loss=trainer_config.use_colorless_loss,
            colorless_fdn_params=colorless_fdn_params)

        # load the trained weights from the last epoch
        checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
        checkpoint = torch.load(
            checkpoint_dir / f'model_e{trainer_config.max_epochs - 1}.pt',
            weights_only=True,
            map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint)
        # eval() only switches layer behaviour; torch.no_grad() below is what disables autograd
        model.eval()

        band_filter = subband_filters.coefficients[k, :]

        # loop through all positions
        with torch.no_grad():
            for data in train_dataset:
                position = data['listener_position'].detach().cpu().numpy()

                # get_response returns an extra output when the model uses the colorless loss
                if model.use_colorless_loss:
                    _, _, h = get_response(data, model)
                else:
                    _, h = get_response(data, model)
                h = h.detach().cpu().numpy()

                # band-limit every RIR of this band's model and record it
                for pos_idx in range(position.shape[0]):
                    cur_rir_filtered = fftconvolve(h[pos_idx, :], band_filter, mode='same')
                    subband_records.append({
                        'frequency': band_hz,
                        # position is stored as a tuple because numpy arrays are unhashable
                        'position': tuple(position[pos_idx, :3]),
                        'filtered_time_samples': cur_rir_filtered,
                    })

    synth_subband_rirs = pd.DataFrame(
        subband_records, columns=['frequency', 'position', 'filtered_time_samples'])

    # combine the per-band RIRs of each position (sum_arrays presumably sums the filtered
    # samples — TODO confirm); sort=False keeps trajectory order, matching the full-band data
    synth_rirs = synth_subband_rirs.groupby('position', sort=False).apply(sum_arrays)

    # Convert to DataFrame if needed
    synth_rirs_df = synth_rirs.reset_index()
    synth_rirs_df.columns = ['position', 'filtered_time_samples']

    subband_gfdn_room_data = deepcopy(room_data)
    subband_gfdn_room_data.update_receiver_pos(np.array(synth_rirs_df['position'].to_list()))
    subband_gfdn_room_data.update_rirs(np.vstack(synth_rirs_df['filtered_time_samples']))

    # cache to disk (create the output folder on a fresh checkout)
    subband_gfdn_pickle_path.parent.mkdir(parents=True, exist_ok=True)
    with open(subband_gfdn_pickle_path, "wb") as f:
        pickle.dump(subband_gfdn_room_data, f)

else:
    # only load caches written by this notebook: unpickling can execute arbitrary code
    with open(subband_gfdn_pickle_path, "rb") as f:
        subband_gfdn_room_data = pickle.load(f)


# some plotting to investigate: one trajectory RIR against its nearest measured RIR
rec_pos_idx = 34
dist = np.linalg.norm(subband_gfdn_room_data.receiver_position[rec_pos_idx, :] - room_data.receiver_position, axis=-1)
ref_rir_idx = np.argmin(dist, axis=0)
plt.plot(subband_gfdn_room_data.late_rirs[rec_pos_idx, :], label='Subband DiffGFDN')
plt.plot(room_data.late_rirs[ref_rir_idx, :], label='Ref RIR')
plt.legend()
plt.show()

# the reference spectrogram must use the nearest measured RIR, not the trajectory index
S_true, freqs, time_frames = get_stft_torch(torch.tensor(room_data.late_rirs[ref_rir_idx, :]),
                                            room_data.sample_rate, win_size=2**12, hop_size=2**11, nfft=2**12)
S, freqs, time_frames = get_stft_torch(torch.tensor(subband_gfdn_room_data.late_rirs[rec_pos_idx, :]),
                                       room_data.sample_rate, win_size=2**12, hop_size=2**11, nfft=2**12)

plot_spectrogram(db(torch.abs(S_true)), freqs, time_frames, title='Ref RIR', log_freq_axis=True)
plot_spectrogram(db(torch.abs(S)), freqs, time_frames, title='Subband DiffGFDN', log_freq_axis=True)

### Animate the trajectory

In [None]:
import src
# re-import so edits to src/sound_examples.py are picked up without restarting the kernel
reload(src.sound_examples)
from src.sound_examples import dynamic_rendering_moving_receiver

update_ms = 250  # renderer update interval in ms; should be a factor of 1 s

dynamic_renderer = dynamic_rendering_moving_receiver(room_data, rec_pos_list, speech_app, update_ms=update_ms)
ani_save_path = Path(f'{fig_path}/sound_examples/treble_data').resolve()
ani_save_path.parent.mkdir(parents=True, exist_ok=True)
dynamic_renderer.animate_moving_listener(ani_save_path)

# cross-fading convolution with the reference set of RIRs
ref_output = dynamic_renderer.filter_overlap_add()
save_path = Path(f'{audio_path}/reference_moving_listener_{sig_type}.wav').resolve()
save_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(save_path, ref_output, room_data.sample_rate)
# only a cell's last expression is auto-displayed, so show the player explicitly
IPython.display.display(IPython.display.Audio(save_path))

dynamic_renderer.combine_animation_and_sound(f'{ani_save_path}_moving_listener.mp4', f'{save_path}', f'{ani_save_path}_reference_{sig_type}')

In [None]:
# cross-fading convolution with CS RIRs
dynamic_renderer = dynamic_rendering_moving_receiver(cs_room_data, rec_pos_list, speech_app, update_ms=update_ms)
cs_output = dynamic_renderer.filter_overlap_add()
save_path = Path(f'{audio_path}/cs_moving_listener_{sig_type}.wav').resolve()
sf.write(save_path, cs_output, room_data.sample_rate)
# only a cell's last expression is auto-displayed, so show the player explicitly
IPython.display.display(IPython.display.Audio(save_path))

# reuse the trajectory animation rendered in the reference cell
dynamic_renderer.combine_animation_and_sound(f'{ani_save_path}_moving_listener.mp4', f'{save_path}', f'{ani_save_path}_cs_{sig_type}')
del dynamic_renderer

In [None]:
# cross-fading convolution with fullband GFDN RIRs
dynamic_renderer = dynamic_rendering_moving_receiver(full_gfdn_room_data, rec_pos_list, speech_app, update_ms=update_ms)
full_gfdn_output = dynamic_renderer.filter_overlap_add()
save_path = Path(f'{audio_path}/fullband_gfdn_moving_listener_{sig_type}.wav').resolve()
sf.write(save_path, full_gfdn_output, room_data.sample_rate)
# only a cell's last expression is auto-displayed, so show the player explicitly
IPython.display.display(IPython.display.Audio(save_path))

# reuse the trajectory animation rendered in the reference cell
dynamic_renderer.combine_animation_and_sound(f'{ani_save_path}_moving_listener.mp4', f'{save_path}', f'{ani_save_path}_fullband_DiffGFDN_{sig_type}')
del dynamic_renderer

In [None]:
# cross-fading convolution with subband GFDN RIRs
dynamic_renderer = dynamic_rendering_moving_receiver(subband_gfdn_room_data, rec_pos_list, speech_app, update_ms=update_ms)
sub_gfdn_output = dynamic_renderer.filter_overlap_add()
save_path = Path(f'{audio_path}/subband_gfdn_moving_listener_{sig_type}.wav').resolve()
sf.write(save_path, sub_gfdn_output, room_data.sample_rate)
# only a cell's last expression is auto-displayed, so show the player explicitly
IPython.display.display(IPython.display.Audio(save_path))

# reuse the trajectory animation rendered in the reference cell
dynamic_renderer.combine_animation_and_sound(f'{ani_save_path}_moving_listener.mp4', f'{save_path}',
                                             f'{ani_save_path}_subband_DiffGFDN_{sig_type}_colorless_prototype')
del dynamic_renderer