For a GFDN with a single group, investigate the loss surfaces as functions of the input gain $g_{in}$ and the output gain $g_{out}$. Each surface is a 3D plot, which gives us more intuition about the training process.

In [None]:
import numpy as np
import torch
import re
import os
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.io import loadmat
import soundfile as sf
import pickle
from torch import nn
from scipy.fft import rfft, rfftfreq
from numpy.typing import ArrayLike
from typing import Optional, List, Dict
from copy import deepcopy
from IPython import display
from loguru import logger
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.dataloader import load_dataset, RIRData
from diff_gfdn.config.config import DiffGFDNConfig, CouplingMatrixType
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset
from diff_gfdn.model import DiffGFDN, DiffGFDNSinglePos
from diff_gfdn.utils import is_unitary, db2lin, db, ms_to_samps, get_response, to_complex
from diff_gfdn.plot import plot_edr, animate_coupled_feedback_matrix
from diff_gfdn.analysis import get_decay_fit_net_params
from diff_gfdn.colorless_fdn.utils import get_colorless_fdn_params
from diff_gfdn.colorless_fdn.losses import amse_loss, sparsity_loss
from diff_gfdn.losses import edc_loss, edr_loss
from src.run_model import load_and_validate_config

In [None]:
# Directories are relative to the working directory selected by os.chdir above
config_path = 'data/config/'
fig_path = 'figures/'

# Configuration of the run whose loss surfaces are inspected in this notebook
config_name = 'single_rir_fit_broadband_two_stage_decay_colorless_prototype_pos2.yml'
config_file = f'{config_path}{config_name}'

# Read the file into a DiffGFDNConfig object
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Build the room dataset from the common-slopes RIR file named in the config
room_data = convert_common_slopes_rir_to_room_dataset(
    config_dict.room_dataset_path,
    num_freq_bins=config_dict.trainer_config.num_freq_bins,
)

# One group per room found in the dataset
config_dict = config_dict.model_copy(
    update={"num_groups": room_data.num_rooms})

# Force the batch size to equal the number of frequency bins in the dataset
trainer_config = config_dict.trainer_config
if trainer_config.batch_size != room_data.num_freq_bins:
    trainer_config = trainer_config.model_copy(
        update={"batch_size": room_data.num_freq_bins})

# Colorless FDN parameters are only fetched when the prototype is enabled
colorless_fdn_params = (
    get_colorless_fdn_params(config_dict)
    if config_dict.colorless_fdn_config.use_colorless_prototype
    else None
)

# Instantiate the single-position DiffGFDN that the checkpoint is loaded into
model = DiffGFDNSinglePos(
    config_dict.sample_rate,
    config_dict.num_groups,
    config_dict.delay_length_samps,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
    use_absorption_filters=False,
    common_decay_times=room_data.common_decay_times,
    colorless_fdn_params=colorless_fdn_params,
    use_colorless_loss=trainer_config.use_colorless_loss,
)

### Find the desired RIR params

In [None]:
# The RIR is identified by the 'ir_(...)' token embedded in the configured path
match = re.search(r'ir_\([^)]+\)', config_dict.ir_path)
dir_name = Path(trainer_config.ir_dir).parts[-1]
ir_name = match.group()

# Load the RIR that was synthesised at the end of training
approx_ir_path = f'{trainer_config.ir_dir}/approx_{ir_name}.wav'
final_approx_ir, fs = sf.read(Path(approx_ir_path).resolve())

# The receiver coordinates are encoded in the RIR name as 'ir_(x, y, z)'
match = re.search(r'ir_\(([^,]+), ([^,]+), ([^,]+)\)', ir_name)
x, y, z = (float(coord) for coord in match.groups())
rec_pos = np.array([x, y, z])

# Pick the dataset row whose receiver position (rounded to 2 decimals)
# equals the parsed position; the first hit is used
position_hits = np.all(
    np.round(room_data.receiver_position, 2) == rec_pos, axis=1)
rec_pos_idx = np.argwhere(position_hits)[0]
amplitudes = np.squeeze(room_data.amplitudes[rec_pos_idx, :])
true_ir = np.squeeze(room_data.rirs[rec_pos_idx, :])

# Overlay both RIRs in the time domain, truncated to the shorter of the two;
# only the first channel of the synthesised RIR is shown
end_samp = min(len(true_ir), len(final_approx_ir))
plt.plot(
    np.stack((true_ir[:end_samp], final_approx_ir[:end_samp, 0]), axis=-1))

plt.xlim([0, int(1.5 * config_dict.sample_rate)])
plt.savefig(f'{fig_path}/compare_{dir_name}_{ir_name}.png')
plt.show()

In [None]:
def get_edc_params(rir: ArrayLike, n_slopes: int, fs:float):
    """
    Fit an RIR with `get_decay_fit_net_params` and average the result.
    Args:
        rir (ArrayLike): room impulse response to analyse
        n_slopes (int): number of slopes handed to the fit
        fs (float): sampling rate handed to the fit
    Returns:
        Tuple: entries 0, 1 and 2 of the fitted parameters, each averaged
               over its first axis (named decay time, amplitude and noise by
               the caller - presumably averaged over sub-bands, confirm
               against get_decay_fit_net_params), followed by the fitted EDC
               averaged over its first dimension and squeezed
    """
    decay_params, _, edc_per_band = get_decay_fit_net_params(
        rir, None, n_slopes, fs)
    t60, amp, noise = (
        np.mean(decay_params[idx], axis=0) for idx in range(3))
    mean_edc = torch.mean(edc_per_band, dim=0)
    return t60, amp, noise, torch.squeeze(mean_edc)

def plot_final_decay_fit_net_edc(num_groups:int, og_amps: List, est_amps: List, og_edc: ArrayLike, synth_edc:ArrayLike):
    """
    Plot the EDC fitted to the reference RIR against the EDC fitted to the
    synthesised RIR (both in dB), show the figure and save it under fig_path.
    The amplitudes of both fits are drawn as markers at time zero.

    Reads the notebook globals fs, fig_path, rec_pos, config_dict and
    trainer_config.
    Args:
        num_groups (int): number of amplitude markers per fit
        og_amps (List): amplitudes fitted to the reference RIR
        est_amps (List): amplitudes fitted to the synthesised RIR
        og_edc (ArrayLike): EDC fitted to the reference RIR
        synth_edc (ArrayLike): EDC fitted to the synthesised RIR
    """
    num_samples = len(og_edc)
    time = np.linspace(0, (num_samples - 1) / fs, num_samples)
    origin = np.zeros(num_groups)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(time, db(og_edc, is_squared=True),
            linestyle='-', label='Original estimated EDC (norm)')
    # the synthesised EDC only covers the first len(synth_edc) time points
    ax.plot(time[:len(synth_edc)], db(synth_edc, is_squared=True),
            linestyle='--', label='Final estimated EDC (norm)')
    ax.plot(origin, db(og_amps, is_squared=True), 'kx', label='Original amps')
    ax.plot(origin, db(est_amps, is_squared=True), 'gd', label='Synth amps')
    ax.legend()
    plt.show()

    coupling = config_dict.feedback_loop_config.coupling_matrix_type.value
    out_file = Path(
        f'{fig_path}/final_synth_edf_{rec_pos}_{coupling}_random_init_lr={trainer_config.io_lr}.png')
    fig.savefig(out_file.resolve())

# Fit the reference RIR and the synthesised RIR (first channel) over the same
# number of samples, then compare the two fits
og_est_t60, og_est_amp, og_noise_floor, og_fitted_edc = get_edc_params(
    true_ir[:end_samp],
    config_dict.num_groups,
    config_dict.sample_rate,
)
est_t60, est_amp, _, fitted_edc = get_edc_params(
    final_approx_ir[:end_samp, 0],
    config_dict.num_groups,
    config_dict.sample_rate,
)
plot_final_decay_fit_net_edc(
    config_dict.num_groups, og_est_amp, est_amp, og_fitted_edc, fitted_edc)

In [None]:
# Wrap the reference RIR and its decay parameters in an RIRData object
rir_data = RIRData(
    rir=true_ir,
    sample_rate=config_dict.sample_rate,
    common_decay_times=room_data.common_decay_times,
    band_centre_hz=room_data.band_centre_hz,
    amplitudes=amplitudes,
    nfft=config_dict.trainer_config.num_freq_bins,
)

# Build the dataset for DiffGFDN: a split ratio of 1.0 and no shuffling
train_dataset = load_dataset(
    rir_data,
    trainer_config.device,
    train_valid_split_ratio=1.0,
    batch_size=trainer_config.batch_size,
    shuffle=False,
)

### Find each loss for a grid of source-receiver gain values

In [None]:
def find_multi_gain_transfer_function(model: DiffGFDN, b: torch.tensor, c: torch.tensor, data: Dict,  
                                      input_scalars: torch.tensor, output_scalars:torch.tensor, num_inp_points: int, num_out_points: int):
    """
    Calculate the Diff GFDN's transfer function for every pair of source and
    receiver gains taken from input_scalars and output_scalars.
    Args:
        model (DiffGFDN): model providing feedback_loop, num_groups,
                          num_delay_lines, num_delay_lines_per_group,
                          use_colorless_loss and sub_fdn_output
        b (Tensor): input gains; a trailing and a leading axis are added and
                    the result is expanded to
                    num_inp_points x num_delay_lines x num_freq_pts, so one
                    value per delay line is expected
        c (Tensor): output gains, handled like b but expanded to
                    num_out_points x num_delay_lines x num_freq_pts
        data (Dict): batch with the keys 'z_values' (frequency sampling
                     points) and 'target_early_response' (added to every
                     transfer function)
        input_scalars (Tensor): g_in; gets a trailing axis appended and is
                    expanded to num_inp_points x num_groups x num_freq_pts,
                    so num_inp_points x num_groups is expected
        output_scalars (Tensor): g_out, handled like input_scalars but with
                    num_out_points
        num_inp_points (int): number of source gain values
        num_out_points (int): number of receiver gain values
    Returns:
        Tensor: transfer function of size
                num_inp_points x num_out_points x num_freq_pts. If
                model.use_colorless_loss is set, the tuple
                (H, model.sub_fdn_output(z)) is returned instead.
    """
    z = data['z_values']
    num_freq_pts = len(z)
    # add a leading (gain grid) axis and a trailing (frequency) axis
    c = c.unsqueeze(0).unsqueeze(-1)
    b = b.unsqueeze(0).unsqueeze(-1)
    # add a trailing (frequency) axis to the per-group scalars
    output_scalars = output_scalars.view(*output_scalars.shape, 1)
    input_scalars = input_scalars.view(*input_scalars.shape, 1)
    
    # broadcast the per-delay-line gains over the gain grid and all frequencies
    # (to_complex is a project helper; presumably a cast to a complex dtype - confirm)
    C_init = to_complex(
        c.expand(num_out_points, model.num_delay_lines, num_freq_pts))
    B_init = to_complex(
        b.expand(num_inp_points, model.num_delay_lines, num_freq_pts))
    
    # repeat every group's scalar for each delay line of that group
    output_scalars_expanded = output_scalars.expand(
                num_out_points, model.num_groups, num_freq_pts)
    output_scalars_expanded = output_scalars_expanded.repeat_interleave(
            model.num_delay_lines_per_group, dim=1)
    # size num_out_points x num_del_lines x num_freq_bins
    C = to_complex(output_scalars_expanded)

    input_scalars_expanded = input_scalars.expand(
                num_inp_points, model.num_groups, num_freq_pts)
    input_scalars_expanded = input_scalars_expanded.repeat_interleave(
            model.num_delay_lines_per_group, dim=1)
    # size num_inp_points x num_del_lines x num_freq_bins
    B = to_complex(input_scalars_expanded)

    # scale the per-delay-line gains by the group scalars (in place),
    # of size num_points x Ndel x num_freq_points
    C *= C_init
    B *= B_init

    # get the output of the feedback loop, this is of size num_freq_points x Ndel x Ndel
    P = model.feedback_loop(z)
    # contract the output gains with the first delay-line axis of P,
    # this is of size num_freq_points x num_del_lines x num_out_points
    Htemp = torch.einsum('kmb, kmn -> knb', C.permute(-1,1,0), P)
    # Htemp = torch.einsum('knm, bmk -> bkn', P, C)
    # contract the remaining delay-line axis with the input gains; after the
    # final permute this is of size num_inp_points x num_out_points x num_freq_points
    H = torch.einsum('knb, knc -> kbc', Htemp, B.permute(-1,1,0)).permute(-1,1,0)
    # H = torch.einsum('bnk, ckn -> bck', B, Htemp)
    # the same early response is added to the transfer function of every gain pair
    direct_filter = data['target_early_response']
    direct_filter = direct_filter.view(1, 1, *direct_filter.shape).expand(num_inp_points, num_out_points, num_freq_pts)
    H += direct_filter

    if model.use_colorless_loss:
        # the sub-FDN response only takes z as input, not the gain scalars
        H_sub_fdn = model.sub_fdn_output(z)
        return H, H_sub_fdn
    else:
        return H


In [None]:
# Evaluate every loss on a grid of source / receiver gain scalars, using the
# weights of the most recent checkpoint that can be loaded.
checkpoint_dir = Path(trainer_config.train_dir + 'checkpoints/').resolve()
max_epochs = trainer_config.max_epochs
num_points = 100
input_scalars = torch.linspace(-2.0, 2.0, num_points)
output_scalars = torch.linspace(-2.0, 2.0, num_points)
# 'ij' is the indexing torch.meshgrid applied implicitly; naming it silences
# the deprecation warning and pins grid[i, j] = (input_scalars[i], output_scalars[j])
input_scalar_grid, output_scalar_grid = torch.meshgrid(
    input_scalars, output_scalars, indexing='ij')


edr_loss_val = torch.zeros((num_points, num_points))
edc_loss_val = torch.zeros_like(edr_loss_val)
criterion = [edr_loss(model.sample_rate, use_erb_grouping=trainer_config.use_erb_edr_loss,
                      use_weight_fn=trainer_config.use_frequency_weighting),
             edc_loss(model.common_decay_times.max() * 1e3,
                      model.sample_rate)]

if trainer_config.use_colorless_loss:
    spectral_loss_val = torch.zeros_like(edr_loss_val)
    sparsity_loss_val = torch.zeros_like(edr_loss_val)
    colorless_criterion = [amse_loss(), sparsity_loss()]


def _colorless_losses(sub_fdn_response):
    """
    Sum the spectral and the sparsity loss over all groups of `model`.
    Args:
        sub_fdn_response: pair of sub-FDN responses whose last axis indexes
                          the group; each is compared against an all-ones
                          target by colorless_criterion[0]
    Returns:
        Tuple: (spectral loss, sparsity loss), each summed over the groups
    """
    spectral, sparsity = 0.0, 0.0
    for k in range(model.num_groups):
        group_idx = torch.arange(
            k * model.num_delay_lines_per_group,
            (k + 1) * model.num_delay_lines_per_group,
            dtype=torch.int32)
        first_response = sub_fdn_response[0][..., k]
        second_response = sub_fdn_response[1][group_idx, :, k]
        spectral += (
            colorless_criterion[0](first_response,
                                   torch.ones_like(first_response))
            + colorless_criterion[0](second_response,
                                     torch.ones_like(second_response)))
        sparsity += colorless_criterion[1](
            model.feedback_loop.ortho_param(model.feedback_loop.M[k]))
    return spectral, sparsity


# Walk back from the last epoch until a checkpoint loads. Any load failure
# moves on to the previous epoch (best effort, as before), but the search now
# stops at epoch 0 instead of looping forever when nothing can be loaded.
checkpoint = None
last_error = None
while max_epochs > 0:
    try:
        checkpoint = torch.load(f'{checkpoint_dir}/model_e{max_epochs-1}.pt', weights_only=True, map_location=torch.device('cpu'))
        break
    except Exception as e:
        last_error = e
        max_epochs -= 1

if checkpoint is None:
    raise FileNotFoundError(
        f'No loadable checkpoint model_e*.pt found in {checkpoint_dir}') from last_error

# Load the trained model state
model.load_state_dict(checkpoint)
# switch the model to evaluation mode
model.eval()

# no gradients are needed for evaluating the loss surfaces
with torch.no_grad():
    param_dict = model.get_param_dict()
    input_gains = deepcopy(param_dict['input_gains'])
    output_gains = deepcopy(param_dict['output_gains'])
    opt_input_scalars = deepcopy(param_dict['input_scalars'])
    opt_output_scalars = deepcopy(param_dict['output_scalars'])
    logger.info(f'Opt source gain {np.round(opt_input_scalars, 3)}')
    logger.info(f'Opt receiver gain {np.round(opt_output_scalars, 3)}')

    if model.num_groups == 1:
        # single group: sweep the source and the receiver scalar directly
        input_scalars = input_scalars.unsqueeze(-1)
        output_scalars = output_scalars.unsqueeze(-1)
    else:
        # two groups: only one scalar per side is swept. The input scalar of
        # the first group and the output scalar of the second group stay at
        # their trained values, while the input scalar of the second group
        # and the output scalar of the first group run over the grid.
        input_scalars = torch.stack((opt_input_scalars[0] * torch.ones(num_points), input_scalars), dim=-1)
        output_scalars = torch.stack((output_scalars, opt_output_scalars[1] * torch.ones(num_points)), dim=-1)

    for data in train_dataset:
        # get the optimum losses, i.e. those of the trained model
        if trainer_config.use_colorless_loss:
            Hopt, H_sub_fdn, approx_ir = get_response(data, model)
        else:
            Hopt, approx_ir = get_response(data, model)

        opt_edr_loss = criterion[0](data['target_rir_response'], Hopt)
        opt_edc_loss = criterion[1](data['target_rir_response'], Hopt)
        logger.info(f'Opt EDR loss {opt_edr_loss:.3f}')
        logger.info(f'Opt EDC loss {opt_edc_loss:.3f}')

        if trainer_config.use_colorless_loss:
            opt_spectral_loss, opt_sparsity_loss = _colorless_losses(H_sub_fdn)

        # get the transfer functions for the grid of source-receiver gains
        if trainer_config.use_colorless_loss:
            H, H_sub_fdn = find_multi_gain_transfer_function(
                model, torch.tensor(input_gains), torch.tensor(output_gains), data,
                input_scalars, output_scalars, num_points, num_points)
            # The colorless losses depend neither on i nor on j, so they are
            # evaluated once per batch and added to the whole grid instead of
            # being recomputed for each of the num_points**2 cells.
            grid_spectral_loss, grid_sparsity_loss = _colorless_losses(H_sub_fdn)
            spectral_loss_val += grid_spectral_loss
            sparsity_loss_val += grid_sparsity_loss
        else:
            H = find_multi_gain_transfer_function(
                model, torch.tensor(input_gains), torch.tensor(output_gains), data,
                input_scalars, output_scalars, num_points, num_points)

        # the EDR and EDC losses have to be evaluated per gain pair
        for i in tqdm(range(num_points)):
            for j in range(num_points):
                edr_loss_val[i, j] = criterion[0](data['target_rir_response'], H[i, j, :])
                edc_loss_val[i, j] = criterion[1](data['target_rir_response'], H[i, j, :])

### Plot each loss surface

In [None]:
# Plot every loss as a 3D surface over the gain grid and mark the optimum
num_losses = 4 if trainer_config.use_colorless_loss else 2
opt_losses = [opt_edr_loss, opt_edc_loss]
loss_names = ['EDR','EDC','Spectral','Sparsity']

if trainer_config.use_colorless_loss:
    losses = torch.stack((edr_loss_val, edc_loss_val, spectral_loss_val, sparsity_loss_val), dim=-1)
    opt_losses.extend([opt_spectral_loss, opt_sparsity_loss])
else:
    losses = torch.stack((edr_loss_val, edc_loss_val), dim=-1)

# The surface is drawn over (input_scalar_grid, output_scalar_grid): x is the
# swept input scalar (last column of input_scalars) and y is the swept output
# scalar (first column of output_scalars). scatter_x / scatter_y hold the
# trained values of exactly those two scalars.
if config_dict.num_groups == 1:
    scatter_x = opt_input_scalars
    scatter_y = opt_output_scalars
    xlabel = 'Source gain'
    ylabel = 'Receiver gain'
else:
    scatter_x = opt_input_scalars[-1]
    scatter_y = opt_output_scalars[0]
    # NOTE(review): the x axis carries the swept *input* scalar of the last
    # group, yet it is labelled as a receiver gain - confirm the intended label
    xlabel = 'Receiver gain 2'
    ylabel = 'Receiver gain 1'

# Grid indices closest to the trained values of the swept scalars. The swept
# input scalar is the last entry and the swept output scalar the first entry,
# which is also valid when there is only a single group.
opt_x_idx = np.argmin(np.abs(input_scalars[:, -1].numpy() - opt_input_scalars[-1]))
opt_y_idx = np.argmin(np.abs(output_scalars[:, 0].numpy() - opt_output_scalars[0]))
logger.info(f'Calc EDR loss at optimum values: {edr_loss_val[opt_x_idx, opt_y_idx]}')
logger.info(f'Calc EDC loss at optimum values: {edc_loss_val[opt_x_idx, opt_y_idx]}')

fig = plt.figure(figsize=(10, 15))
for k in range(num_losses):
    ax = fig.add_subplot(num_losses, 1, k+1, projection='3d')
    surface = ax.plot_surface(input_scalar_grid, output_scalar_grid, losses[..., k].squeeze(), cmap='viridis', edgecolor='none', alpha=0.5)
    # red: loss of the trained model, placed at the trained gain values
    ax.scatter(scatter_x, scatter_y, opt_losses[k], color='r', marker= 'x', s=100, label='Optimum loss', alpha=1.0)
    # black: loss read off the grid at the point closest to the trained values
    ax.scatter(input_scalars[opt_x_idx, -1], output_scalars[opt_y_idx, 0], losses[opt_x_idx, opt_y_idx, k],
               color='k', marker= 'x', s=100, label='Grid loss closest to opt values', alpha=1.0)

    # Add labels and legend
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(f'{loss_names[k]} Loss')
    ax.legend()

    # Add color bar
    fig.colorbar(surface, ax=ax, shrink=0.5, aspect=10)

plt.show()
fig.savefig(f'{fig_path}/{config_name}_loss_surfaces_masked_edc.png', bbox_inches='tight', pad_inches=0)


### Interactive plot with Plotly

In [None]:
# Render every loss surface as its own interactive Plotly figure
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# One figure per loss type, each with the surface and the optimum marker
for k in range(num_losses):
    fig = go.Figure()

    # Add 3D surface for current loss; all three coordinates are handed over
    # as numpy arrays
    fig.add_trace(go.Surface(
        z=losses[..., k].squeeze().numpy(),  # Losses data
        x=np.asarray(input_scalar_grid),  # X coordinates (input_scalar_grid)
        y=np.asarray(output_scalar_grid),  # Y coordinates (output_scalar_grid)
        colorscale='Viridis',  # Colormap
        opacity=0.5,
        showscale=True,
        colorbar=dict(title=f'{loss_names[k]} Loss',
                      thickness=20,  # Adjust thickness to make the colorbar narrower
                      len=0.3,  # Adjust length to make the colorbar shorter
                      x=1.05,  # Adjust x to place colorbar outside the plot
                      y=0.5,   # Adjust y to place colorbar at the middle
                      ticks='outside'  # Position ticks outside the colorbar
        ),
        name=f'{loss_names[k]} Surface'
    ))

    # Add scatter for the optimal point. The coordinates are flattened to 1-D
    # sequences of equal length, whether the optimum values are scalars,
    # length-1 arrays or 0-d tensors.
    fig.add_trace(go.Scatter3d(
        x=np.ravel(scatter_x),
        y=np.ravel(scatter_y),
        z=[float(opt_losses[k])],
        mode='markers',
        marker=dict(color='red', size=2, symbol='x'),
        name='Optimum Loss'
    ))

    # fig.add_trace(go.Scatter3d(
    #     x=[input_scalars[opt_x_idx,-1]], 
    #     y=[output_scalars[opt_y_idx,0]], 
    #     z=[losses[opt_x_idx, opt_y_idx, k]], 
    #     mode='markers',
    #     marker=dict(color='black', size=2, symbol='x'),
    #     name='Grid loss closest to opt values'
    # ))

    # Update layout to add axis labels and titles
    fig.update_layout(
        scene=dict(
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            zaxis_title='Loss Value',
        ),
        # xaxis=dict(automargin=True),
        # yaxis=dict(automargin=True),
        margin=dict(l=0, r=0, b=0, t=40),
        height=500,  # Adjust the height to make it bigger
        width=600,
        # for top view
        scene_camera=dict(
            eye=dict(x=0, y=0, z=2.3)  # camera looks straight down the z axis
        )
    )

    # Show the figure
    fig.show()

    # Save the figure as PNG
    fig.write_image(f'{fig_path}/{config_name}_{loss_names[k]}_surfaces_masked_edc.png', width=600, height=500, scale=2)