### In this notebook, we want to investigate the colouration when we filter and sum multiple subband GFDNs

In [None]:
import numpy as np
import torch
import re
import pyfar as pf
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
from torch import nn
from numpy.typing import ArrayLike, NDArray
from typing import Optional, List
from IPython import display
import soundfile as sf
from loguru import logger
from copy import deepcopy

# Step one directory up before importing the project packages below; the
# relative 'data/config/' and 'figures/' paths used later in this notebook are
# presumably meant to resolve against that directory as well.
# NOTE: the target is relative, so re-running this cell moves up one more
# level each time. Run it once per kernel session.
os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData, ThreeRoomDataset, RoomDataset
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType, SubbandProcessingConfig
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDNVarReceiverPos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response, spectral_flatness
from diff_gfdn.plot import plot_magnitude_response
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.model import ColorlessFDN
from slope2noise.utils import schroeder_backward_int
from src.run_model import load_and_validate_config

In [None]:
# Centre frequencies (Hz) of the subbands; one config file and one pair of
# checkpoints is loaded per entry.
freqs = [63, 125, 250, 500, 1000, 2000, 4000, 8000]

# Loop-invariant locations of the per-band config files and the output figures.
config_path = 'data/config/'
fig_path = 'figures/'

for k, band_freq in enumerate(freqs):
    config_name = f'treble_data_grid_training_{band_freq}Hz_colorless_loss'
    config_file = config_path + config_name
    config_dict = load_and_validate_config(config_file + '.yml',
                                           DiffGFDNConfig)
    room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), config_dict)
    dataset_has_cs_params = True  # not read anywhere in this cell
    # the number of FDN groups is tied to the number of rooms in the dataset
    config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})
    trainer_config = config_dict.trainer_config

    ## initialise variables
    # Allocated once, on the first band, with shape
    # (num_freq_bins // 2 + 1, num_groups, number of bands). The sizes are taken
    # from the first band's config, so every later band has to produce
    # responses of the same size.
    if k == 0:
        H_fdn_init = torch.zeros((trainer_config.num_freq_bins // 2 + 1, config_dict.num_groups, len(freqs)),
                                 dtype=torch.complex64, requires_grad=False)
        H_fdn_subband_init = torch.zeros_like(H_fdn_init)
        H_fdn_final = torch.zeros_like(H_fdn_init)
        H_fdn_subband_final = torch.zeros_like(H_fdn_final)

    # force the trainer config device to be CPU
    if trainer_config.device != 'cpu':
        trainer_config = trainer_config.model_copy(update={"device": 'cpu'})

    # prepare the training and validation data for DiffGFDN
    # (the two returned datasets are not used further in this cell)
    train_dataset, valid_dataset = load_dataset(
        room_data, trainer_config.device, train_valid_split_ratio=1.0,
        batch_size=trainer_config.batch_size, shuffle=False)

    # get the colorless FDN params
    if config_dict.colorless_fdn_config.use_colorless_prototype:
        colorless_fdn_params = get_colorless_fdn_params(config_dict)
    else:
        colorless_fdn_params = None

    # initialise the model
    model = DiffGFDNVarReceiverPos(config_dict.sample_rate,
                     config_dict.num_groups,
                     config_dict.delay_length_samps,
                     trainer_config.device,
                     config_dict.feedback_loop_config,
                     config_dict.output_filter_config,
                     config_dict.decay_filter_config.use_absorption_filters,
                     common_decay_times=room_data.common_decay_times if config_dict.decay_filter_config.initialise_with_opt_values else None,
                     learn_common_decay_times=config_dict.decay_filter_config.learn_common_decay_times,
                     colorless_fdn_params=colorless_fdn_params,
                     use_colorless_loss=trainer_config.use_colorless_loss
                     )

    ### Get each FDN's magnitude response
    freq_bins_rad = torch.tensor(room_data.freq_bins_rad)
    freq_bins_hz = room_data.freq_bins_hz
    # Evaluation points with unit magnitude and angle freq_bins_rad * 2*pi.
    # NOTE(review): the extra 2*pi factor suggests `freq_bins_rad` holds
    # normalised frequency rather than radians - confirm against RoomDataset.
    z_values = torch.polar(torch.ones_like(freq_bins_rad),
                           freq_bins_rad * 2 * np.pi)

    # load the trained weights for the particular epoch
    max_epochs = trainer_config.max_epochs
    checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()

    # `model_e-1.pt` is treated by this notebook as the state before optimisation
    init_checkpoint = torch.load(f'{checkpoint_dir}/model_e-1.pt',
                                 weights_only=True,
                                 map_location=torch.device('cpu'))
    # strict=False: keys missing from / unexpected in the checkpoint are tolerated
    model.load_state_dict(init_checkpoint, strict=False)
    model.eval()
    # Inference only: without no_grad the autograd graph of every loaded model
    # would stay attached to the accumulator tensors.
    with torch.no_grad():
        H_fdn_init[..., k], _ = model.sub_fdn_output(z_values)

    final_checkpoint = torch.load(f'{checkpoint_dir}/model_e{max_epochs-1}.pt',
                                  weights_only=True,
                                  map_location=torch.device('cpu'))
    # Load the trained model state
    model.load_state_dict(final_checkpoint, strict=False)
    model.eval()
    with torch.no_grad():
        H_fdn_final[..., k], _ = model.sub_fdn_output(z_values)

    ### Filter the sub-FDN outputs with the pyfar filters, sum them and observe colouration
    # Passing None as the signal makes pyfar return the filter bank itself
    # together with the band centre frequencies.
    subband_filters, subband_freqs = pf.dsp.filter.reconstructing_fractional_octave_bands(
                None,
                num_fractions=trainer_config.subband_process_config.num_fraction_octaves,
                frequency_range=trainer_config.subband_process_config.frequency_range,
                sampling_rate=config_dict.sample_rate,
            )
    # pick the band whose centre frequency is closest to the one in the config
    subband_filter_idx = np.argmin(
        np.abs(subband_freqs -
               trainer_config.subband_process_config.centre_frequency))
    subband_filter = torch.tensor(
        subband_filters.coefficients[subband_filter_idx])
    subband_filter_freq_resp = torch.fft.rfft(
        subband_filter, n=trainer_config.num_freq_bins)
    # unsqueeze(1) broadcasts the single band filter response over all groups
    H_fdn_subband_init[..., k] = H_fdn_init[..., k] * subband_filter_freq_resp.unsqueeze(1)
    H_fdn_subband_final[..., k] = H_fdn_final[..., k] * subband_filter_freq_resp.unsqueeze(1)

    # plot colouration of individual FDNs
    # (the model holds the final-checkpoint weights at this point)
    save_path = f'{fig_path}/{config_name}_mag_spectrum.png'
    plot_magnitude_response(room_data, config_dict, model, save_path)

#### Plot colouration of filtered and summed FDNs

In [None]:
def mag_response_plot(room_data: RoomDataset, H_sub_fdn_filtered: torch.Tensor, freq_bins_hz: ArrayLike,
                      freq_labels: List[int], title: Optional[str] = None) -> None:
    """Plot the magnitude response of the band-filtered sub-FDNs and of their sum.

    One subplot is drawn per room: a solid line per frequency band and a dashed
    line for the sum over all bands. The spectral flatness of the summed
    response is logged for every room.

    Args:
        room_data: dataset object; only its `num_rooms` attribute is read here.
        H_sub_fdn_filtered: filtered sub-FDN responses of shape
            num_bins x num_groups x num_freq_bands.
        freq_bins_hz: frequency axis used as the x values of every curve.
        freq_labels: one value per frequency band, rendered as "<value> Hz"
            in the legend.
        title: optional figure-level title.
    """
    num_rooms = room_data.num_rooms
    # squeeze=False always yields a 2D array of axes, so a single-room dataset
    # no longer produces a bare (non-indexable) Axes object.
    fig, axes_grid = plt.subplots(num_rooms,
                                  1,
                                  figsize=(8, 10),
                                  sharex=True,
                                  squeeze=False)
    axes = axes_grid[:, 0]

    band_labels = [f"{freq} Hz" for freq in freq_labels]

    for i, ax in enumerate(axes):
        band_responses = H_sub_fdn_filtered[:, i, :]
        # sum over the band axis; computed once and reused for plot and log
        summed_db = db(torch.sum(band_responses, dim=-1).detach().numpy())

        ax.semilogx(
            freq_bins_hz,
            db(band_responses.detach().numpy()),
            label=band_labels,
            linestyle="-",
            alpha=0.8,
        )
        ax.semilogx(freq_bins_hz, summed_db, label='summed', linestyle='--')

        ax.set_ylabel("Magnitude (dB)")
        ax.set_xlabel('Frequencies (Hz)')
        ax.set_title(f"FDN {i+1}")
        ax.grid(True)
        ax.set_xlim([20, 16000])

        # NOTE(review): flatness is evaluated on the dB-scaled response -
        # confirm that this is the input `spectral_flatness` expects.
        logger.info(
            f'FDN {i+1} spectral flatness is {spectral_flatness(summed_db):.3f}'
        )

    # a single legend, attached to the last subplot and placed outside it
    axes[-1].legend(loc="upper left", bbox_to_anchor=(1, 1))
    if title is not None:
        fig.suptitle(title)

# Colouration of the filtered and summed sub-FDNs before optimisation.
mag_response_plot(
    room_data=room_data,
    H_sub_fdn_filtered=H_fdn_subband_init,
    freq_bins_hz=freq_bins_hz,
    freq_labels=freqs,
    title="Initialisation",
)

In [None]:
# Colouration of the filtered and summed sub-FDNs after optimisation.
mag_response_plot(
    room_data=room_data,
    H_sub_fdn_filtered=H_fdn_subband_final,
    freq_bins_hz=freq_bins_hz,
    freq_labels=freqs,
    title="Post optimisation",
)