In [None]:
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from numpy.typing import ArrayLike
scale = 2  # global multiplier applied to every font size below
plt.rcParams.update({
    'font.size': scale * 8,  # base font size
    'axes.labelsize': scale * 7,  # x/y label
    'xtick.labelsize': scale * 7,
    'ytick.labelsize': scale * 7,
    'legend.fontsize': scale * 6,
    'axes.titlesize': scale * 8,  # usually unused in journal figures
})
import os

# Move to the repository root (DiffGFDN) so that `diff_gfdn` is importable and the
# relative 'figures/...' output paths resolve there. The guard keeps the cell
# idempotent: re-running it must not keep climbing the directory tree.
# Assumes the repository root contains the `diff_gfdn` package directory
# (it is imported later in this notebook) — TODO confirm for the repo layout.
if not os.path.isdir('diff_gfdn'):
    os.chdir('..')

# Every plotting cell saves into figures/; create it so savefig cannot fail.
os.makedirs('figures', exist_ok=True)

In [None]:
def gfdn_flops(N: ArrayLike, B: int, is_parallel:bool=False):
    """Estimate the FLOP count of the two GFDN variants compared in this notebook.

    Args:
        N: Number of delay lines; a scalar or an array (evaluated element-wise).
        B: Number of bands — presumably one GFDN per band in the parallel case.
        is_parallel: If True, count ``B`` independent GFDNs whose outputs are
            summed; otherwise count a single frequency-dependent GFDN.

    Returns:
        ``B * (2N^2 + 4N + 1)`` when ``is_parallel`` is True, otherwise
        ``2N^2 + N + 27*N*B + 1``.
    """
    quadratic_term = 2 * np.power(N, 2)
    if not is_parallel:
        # NOTE(review): the per-delay-line, per-band factor 27 is not derived
        # here — confirm its origin against the accompanying derivation.
        return quadratic_term + N + 27 * N * B + 1
    return B * (quadratic_term + 4 * N + 1)

def mlp_flops(num_layers: Union[int,ArrayLike], num_neurons: Union[int, ArrayLike], F:int) -> ArrayLike:
    """Estimate the FLOP count of an MLP with equally wide hidden layers.

    Each hidden layer is counted as ``2 * num_neurons**2 + num_neurons``
    operations and the output layer as ``F * (2 * num_neurons + 1)``; no
    separate input-layer cost is added.

    Args:
        num_layers: Number of hidden layers (scalar or array, broadcast).
        num_neurons: Width of each hidden layer (scalar or array, broadcast).
        F: Number of MLP outputs.

    Returns:
        The FLOP estimate, broadcast over ``num_layers`` and ``num_neurons``.
    """
    per_hidden_layer = 2 * np.power(num_neurons, 2) + num_neurons
    output_layer = F * (2 * num_neurons + 1)
    return num_layers * per_hidden_layer + output_layer
    

In [None]:
# FLOPs of one frequency-dependent GFDN vs. B parallel (summed) GFDNs,
# as a function of the number of delay lines.
B = 8  # number of bands / parallel GFDNs
N = np.arange(6, 24, 4)  # delay-line counts: 6, 10, 14, 18, 22
flops_single = gfdn_flops(N, B)
flops_parallel = gfdn_flops(N, B, is_parallel=True)

fig, ax = plt.subplots()
# Labels match the ones used in the MLP FLOPs surface plot.
ax.plot(N, flops_single, '-x', label='Single frequency-dep. GFDN')
ax.plot(N, flops_parallel, '-x', label='Sum of parallel GFDNs')
ax.set_xlabel('Number of delay lines')
ax.set_ylabel('FLOPS')
ax.legend()
fig.tight_layout()
fig.savefig('figures/compare_flops_subband_GFDNs.png')
plt.show()

In [None]:
# MLP FLOPs over (hidden layers x neurons) for one joint MLP vs. B per-band MLPs.
num_layers = np.arange(1, 10, 1)
num_neurons = np.power(2, np.arange(5, 10))
NL, NN = np.meshgrid(num_layers, num_neurons)
num_groups = 2
B = 8
# One MLP with 4 * num_groups * B outputs vs. B MLPs with 2 * num_groups outputs each.
flops_mlp_single = mlp_flops(NL, NN, 4 * num_groups * B)
flops_mlp_parallel = B * mlp_flops(NL, NN, 2*num_groups)

surfaces = [
    (flops_mlp_single, 'C0', 'Single frequency-dep. GFDN'),
    (flops_mlp_parallel, 'C1', 'Sum of parallel GFDNs'),
]

fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
for flops, color, _ in surfaces:
    ax.plot_surface(NL, NN, np.log10(flops), color=color)
ax.set_xlabel('Number of hidden layers')
ax.set_ylabel('Number of neurons')
ax.set_zlabel(r'$\log_{10}$(FLOPS)')
# Surface plots are not reliable legend handles across matplotlib versions;
# use proxy patches with the same colours instead.
ax.legend(handles=[Patch(color=color, label=label) for _, color, label in surfaces])
ax.view_init(elev=10, azim=-45)  # Elevation (vertical), Azimuth (horizontal)
fig.tight_layout()
fig.savefig('figures/compare_flops_mlp.png')
plt.show()

### Plot storage requirements

In [None]:
#### Storage models: number of stored values for the MLP and for the GFDN
from diff_gfdn.utils import ms_to_samps

def mlp_mems(num_layers: Union[int,ArrayLike], num_neurons: Union[int, ArrayLike], F:int) -> ArrayLike:
    """Count the values stored by an MLP with equally wide hidden layers.

    Each hidden layer contributes ``num_neurons**2 + num_neurons`` values
    (consistent with a square weight matrix plus a bias vector), and the
    output layer contributes ``F * (num_neurons + 1)``.

    Args:
        num_layers: Number of hidden layers (scalar or array, broadcast).
        num_neurons: Width of each hidden layer (scalar or array, broadcast).
        F: Number of MLP outputs.

    Returns:
        The stored-value count, broadcast over ``num_layers`` and ``num_neurons``.
    """
    per_hidden_layer = np.power(num_neurons, 2) + num_neurons
    output_layer = F * (num_neurons + 1)
    return num_layers * per_hidden_layer + output_layer

def gfdn_mems(fs: float, num_del_lines: Union[int, ArrayLike], avg_del_line_len_ms: float, num_groups : int) -> ArrayLike:
    """Count the values a GFDN must store: delay-line memory plus parameters.

    Args:
        fs: Sampling rate in Hz, forwarded to ``ms_to_samps``.
        num_del_lines: Total number of delay lines. May be an int or an integer
            array (the storage plot passes an ``np.arange``), in which case the
            result is computed element-wise.
        avg_del_line_len_ms: Average delay-line length in milliseconds.
        num_groups: Number of groups the delay lines are split into.

    Returns:
        The sum of four terms:
        - ``ms_to_samps(avg_del_line_len_ms, fs) * num_del_lines`` delay-buffer values;
        - the entries of one ``(num_del_lines // num_groups)``-square matrix;
        - ``2 * num_del_lines`` per-delay-line values (presumably input/output gains);
        - ``3 * num_groups`` per-group values.
    """
    # Floor division assumes num_del_lines is a multiple of num_groups — TODO
    # confirm; any remainder is silently dropped.
    # NOTE(review): only a single (N/G) x (N/G) matrix is counted, not one per
    # group — confirm this matches the intended GFDN parameterisation.
    matrix_elems = (num_del_lines // num_groups)**2
    # Assumes ms_to_samps returns the delay length in samples at rate fs.
    num_modes = ms_to_samps(avg_del_line_len_ms, fs) * num_del_lines
    return num_modes + matrix_elems + 2*num_del_lines + 3*num_groups

In [None]:
# MLP storage over (hidden layers x neurons) for B per-band MLPs.
num_layers = np.arange(1, 10, 1)
num_neurons = np.power(2, np.arange(5, 10))
NL, NN = np.meshgrid(num_layers, num_neurons)
num_groups = 2
B = 8
# B MLPs, each with 2 * num_groups outputs.
mems_mlp_parallel = B * mlp_mems(NL, NN, 2*num_groups)

surface_color = 'C0'
surface_label = 'Sum of parallel GFDNs'

fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(NL, NN, np.log10(mems_mlp_parallel), color=surface_color)
ax.set_xlabel('Number of hidden layers')
ax.set_ylabel('Number of neurons')
# This surface is a storage count from mlp_mems, not floating-point operations.
ax.set_zlabel(r'$\log_{10}$(stored values)')
# Surface plots are not reliable legend handles across matplotlib versions;
# use a proxy patch with the same colour instead.
ax.legend(handles=[Patch(color=surface_color, label=surface_label)])
ax.view_init(elev=10, azim=-45)  # Elevation (vertical), Azimuth (horizontal)
fig.tight_layout()
fig.savefig('figures/compare_mems_mlp.png')
plt.show()

In [None]:
# Total storage (B GFDNs + B per-band MLPs) vs. number of delay lines,
# one curve per MLP configuration.
fs = 44100  # sampling rate in Hz
num_groups = 2
B = 8  # defined here so this cell does not depend on earlier cells
N = np.arange(6, 24, 4)
avg_del_line_len_ms = 20
BYTES_PER_VALUE = 4  # assumes each stored value is a 32-bit float — confirm
BYTES_PER_KB = 1000

num_layers = np.arange(1, 10, 3)
num_neurons = np.power(2, np.arange(3, 7))
NL, NN = np.meshgrid(num_layers, num_neurons)
mems_mlp_parallel = B * mlp_mems(NL, NN, 2*num_groups)
mems_mlp_flat = np.ravel(mems_mlp_parallel)
NL_flat = np.ravel(NL)
NN_flat = np.ravel(NN)

# The GFDN part does not depend on the MLP configuration: compute it once.
mems_gfdn_only = B * gfdn_mems(fs, N, avg_del_line_len_ms, num_groups)

fig, ax = plt.subplots()
for n_neurons, n_layers, mems_mlp in zip(NN_flat, NL_flat, mems_mlp_flat):
    mems_gfdn = mems_gfdn_only + mems_mlp
    # \mathrm is used because \text is not supported by older matplotlib mathtext.
    ax.plot(N, mems_gfdn * BYTES_PER_VALUE / BYTES_PER_KB, '-x',
            label=fr'$A = {n_neurons},\ N_{{\mathrm{{layers}}}} = {n_layers}$')

ax.set_xlabel('Number of delay lines')
ax.set_ylabel('Storage in kB')
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
fig.tight_layout()
# bbox_inches='tight' keeps the legend (placed outside the axes) in the saved file.
fig.savefig('figures/compare_mems_subband_GFDNs.png', bbox_inches='tight')
plt.show()