### This notebook shows how to analyze and interpret a trained S-Edge model
[Paper-Link](https://link.springer.com/article/10.1007/s10994-025-06807-z)

In [None]:
import importlib
import importlib.util
import os
import sys
from copy import deepcopy

# matplotlib graph of the eigenvalues
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from tqdm import tqdm

from src.utils.experimentManager import ExperimentManagerLoadFunction, ExperimentManagerReadExistingEntry

# Fix the random state of numpy and pytorch so the notebook is reproducible.
seed = 1234
np.random.seed(seed)
torch.manual_seed(seed)


# Single source of truth for the plot palette (8-bit RGB values).
_PALETTE_RGB = [
    (31, 120, 180),    # Deep Blue
    (51, 160, 44),     # Forest Green
    (227, 26, 28),     # Berry Red
    (255, 127, 0),     # Bright Orange
    (106, 61, 154),    # Deep Purple
    (177, 89, 40),     # Sienna Brown
    (166, 206, 227),   # Sky Blue
    (178, 223, 138),   # Light Green
    (251, 154, 153),   # Soft Red
    (253, 191, 111),   # Apricot Orange
]

# Opaque colours as matplotlib RGB tuples with components in [0, 1].
color_list_plt = [tuple(channel / 255 for channel in rgb) for rgb in _PALETTE_RGB]

# Same colours as RGBA tuples with an alpha of 0.3.
transparent_color_list_plt = [(*rgb, 0.3) for rgb in color_list_plt]

# Font sizes (in points) for the matplotlib figures; all three tiers currently share one value.
SMALL_SIZE = 10
MEDIUM_SIZE = 10
BIGGER_SIZE = 10

# (rc group, settings) pairs applied to matplotlib's global defaults.
_RC_FONT_SETTINGS = (
    ('font', {'size': SMALL_SIZE}),          # default text size
    ('axes', {'titlesize': SMALL_SIZE}),     # axes title
    ('axes', {'labelsize': MEDIUM_SIZE}),    # x and y labels
    ('xtick', {'labelsize': SMALL_SIZE}),    # x tick labels
    ('ytick', {'labelsize': SMALL_SIZE}),    # y tick labels
    ('legend', {'fontsize': SMALL_SIZE}),    # legend entries
    ('figure', {'titlesize': BIGGER_SIZE}),  # figure title
)
for _rc_group, _rc_values in _RC_FONT_SETTINGS:
    plt.rc(_rc_group, **_rc_values)

### Model Loading

In [None]:
# Specify results path and run number
# Both values are passed to ExperimentManagerLoadFunction in the next cell:
# `results_path` as the results folder and `run_number` as its `run` argument.
results_path = './results_journal/'
run_number = 0

In [None]:
# Load the model configuration and paths
model_save_path = ExperimentManagerLoadFunction(results_path, run=run_number)
model_config = ExperimentManagerReadExistingEntry(model_save_path)

# Dynamic import of the model from the backup folder inside model_save_path, so the
# analysis uses the classifier source stored with the run instead of the current sources.
backup_module_name = "src.model.classifier"
backup_classifier_path = os.path.join(model_save_path, 'backup', 'src', 'model', 'classifier.py')
_previous_classifier_module = sys.modules.get(backup_module_name)
try:
    spec = importlib.util.spec_from_file_location(backup_module_name, backup_classifier_path)
    ssm_model = importlib.util.module_from_spec(spec)
    # Register the module under its own name before executing it (importlib recipe);
    # the previous key "module.name" was the placeholder from the documentation example.
    sys.modules[backup_module_name] = ssm_model
    spec.loader.exec_module(ssm_model)

    SC_Model_classifier = ssm_model.SC_Model_classifier
except Exception as e:
    # Undo the registration so the fallback import below resolves the regular module
    # and not a partially initialised backup module.
    if _previous_classifier_module is None:
        sys.modules.pop(backup_module_name, None)
    else:
        sys.modules[backup_module_name] = _previous_classifier_module
    print("Error importing model from backup folder")
    print(e)
    print("Switing to default model")
    from src.model.classifier import SC_Model_classifier


# Freshly initialised model built from the stored configuration; kept as the
# "init" reference for all comparisons below.
model_init = SC_Model_classifier(input_size=model_config['input_size'],
                            classes=model_config['classes'],
                            hidden_sizes=model_config['hidden_sizes'],
                            output_sizes=model_config['output_sizes'],
                            ZeroOrderHoldRegularization=model_config['zeroOrderHoldRegularization'],
                            input_bias=model_config['input_bias'],
                            bias_init=model_config['bias_init'],
                            output_bias=model_config['output_bias'],
                            norm=model_config['norm'],
                            complex_output=model_config['complex_output'],
                            norm_type=model_config['norm_type'],
                            B_C_init=model_config['B_C_init'],
                            stability=model_config['stability'],
                            trainable_SkipLayer=model_config['trainable_skip_connections'],
                            act=model_config['act'],
                            )
# The trained model starts as a copy of the initialised one and then receives the stored weights.
model_trained = deepcopy(model_init)

# strict=False tolerates key mismatches; the returned key report is shown as the cell output.
model_trained.load_state_dict(torch.load(os.path.join(model_save_path, 'best_valid_loss_model.pt'), map_location=torch.device('cpu')),strict=False)


### Eigenvalue Analysis

##### Continuous Eigenvalues
True (real) continuous-time dynamics, obtained by scaling the eigenvalues with the time scales $\Delta$ (see Section 3.3 of the paper):
$$\widetilde{\lambda}_{r} = \Delta \circ \widetilde{\lambda}$$

In [None]:
# Comparison between initialized lambdas and trained lambdas.
# Each list holds one entry per layer; Lambda is read as an (n_states, 2) array whose
# column 0 is the real part and column 1 the imaginary part.
lambdas_init  = []
lambdas_trained = []
time_scales_init = []
time_scales_trained = []

real_dynamics_init = []
real_dynamics_trained = []

for i in range(len(model_init.seq)):
    ssm_init = model_init.seq[i].s5.seq
    ssm_trained = model_trained.seq[i].s5.seq

    time_scales_init.append(np.exp(ssm_init.log_step.detach().cpu().numpy()))
    # .copy(): Tensor.numpy() shares memory with the parameter, so without a copy later
    # in-place edits of the parameter (or of this array) would leak into each other.
    lambdas_init.append(ssm_init.Lambda.detach().cpu().numpy().copy())

    time_scales_trained.append(np.exp(ssm_trained.log_step.detach().cpu().numpy()))
    lambdas = ssm_trained.Lambda.detach().cpu().numpy().copy()
    # Ensure the real part is negative for stability (also done with abs during training).
    # This now only changes the copy; previously it rewrote the weights of model_trained.
    lambdas[:, 0] = -np.abs(lambdas[:, 0])
    lambdas_trained.append(lambdas)

    # Real dynamics: eigenvalues scaled element-wise by the per-state time scale.
    real_dynamics_init.append(np.expand_dims(time_scales_init[i], 1) * lambdas_init[i])
    real_dynamics_trained.append(np.expand_dims(time_scales_trained[i], 1) * lambdas_trained[i])


In [None]:
# Figure 1: raw eigenvalues of every layer, initialised vs. trained.
fig = go.Figure()
for i, (lam_init, lam_trained) in enumerate(zip(lambdas_init, lambdas_trained)):
    fig.add_trace(go.Scatter(x=lam_init[:, 0], y=lam_init[:, 1], mode='markers', name=f'layer{i}_init'))
    fig.add_trace(go.Scatter(x=lam_trained[:, 0], y=lam_trained[:, 1], mode='markers', name=f'layer{i}_trained'))

fig.update_xaxes(title='Real part')
fig.update_yaxes(title='Imaginary part')
fig.update_layout(title='Eigenvalues')
fig.show()

# Figure 2: eigenvalues scaled with the time scales -- true/real continuous-time dynamics
fig = go.Figure()
step_scales = 2**np.arange(0,3)
mi = 0  # most negative real part seen so far, used for the x-axis range
for i in range(len(lambdas_init)):
    dyn_init = real_dynamics_init[i]
    dyn_trained = real_dynamics_trained[i]

    fig.add_trace(go.Scatter(x=dyn_init[:, 0], y=dyn_init[:, 1], mode='markers', name=f'layer{i}_init'))
    fig.add_trace(go.Scatter(x=dyn_trained[:, 0], y=dyn_trained[:, 1], mode='markers', name=f'layer{i}_trained'))

    # Grey lines connect each initial eigenvalue with its trained counterpart.
    connection_style = dict(line=dict(color="grey"), opacity=0.3, mode='lines',
                            legendgroup=f'connection_{i}', name=f'connection_{i}')
    # The first connection carries the legend entry of the whole group ...
    fig.add_trace(go.Scatter(x=[dyn_init[0, 0], dyn_trained[0, 0]],
                             y=[dyn_init[0, 1], dyn_trained[0, 1]],
                             **connection_style))
    # ... all connections (index 0 included again) are then drawn without legend entries.
    for q in range(len(dyn_init[:, 0])):
        fig.add_trace(go.Scatter(x=[dyn_init[q, 0], dyn_trained[q, 0]],
                                 y=[dyn_init[q, 1], dyn_trained[q, 1]],
                                 showlegend=False, **connection_style))

    mi = np.min([mi, np.min(dyn_init[:, 0])])
    mi = np.min([mi, np.min(dyn_trained[:, 0])])

# Dashed guide lines and labels for the step scales 1, 2 and 4.
for step_scale in step_scales:
    guide_style = dict(line_dash="dash", line_color="black", line_width=1, opacity=0.3)
    label_style = dict(showarrow=False, yshift=10, xshift=10, font=dict(size=10))

    fig.add_vline(x=-3/step_scale, **guide_style)
    fig.add_annotation(x=-3/step_scale, y=0, text=f'1/{step_scale}', **label_style)
    fig.add_hline(y=(np.pi)/step_scale, **guide_style)
    fig.add_hline(y=-(np.pi)/step_scale, **guide_style)
    fig.add_annotation(x=0, y=(np.pi)/step_scale, text=f'pi/{step_scale}', **label_style)
    fig.add_annotation(x=0, y=-(np.pi)/step_scale, text=f'-pi/{step_scale}', **label_style)

# Axis range, labels and title
fig.update_xaxes(range=[mi*1.2, 0], title='Real part')
fig.update_yaxes(title='Imaginary part')
fig.update_layout(title='Eigenvalues scaled by time scales -- Real dynamics')
fig.show()


#### Discrete eigenvalues
$$\widetilde{\lambda}_\mathrm{d} = e^{\Delta\circ\widetilde{\lambda}\mathrm{T_s}}$$

In [None]:
# Discrete eigenvalues exp(real_dynamics * step_scale) of every layer in the complex plane.
step_scale = 1
fig = go.Figure()
for i in range(len(lambdas_init)):
    for stage, dynamics in (('init', real_dynamics_init[i]), ('trained', real_dynamics_trained[i])):
        # Columns 0 / 1 hold the real / imaginary part of the continuous eigenvalues.
        discrete_lambdas = np.exp((dynamics[:, 0] + 1j*dynamics[:, 1]) * step_scale)
        fig.add_trace(go.Scatter(x=np.real(discrete_lambdas), y=np.imag(discrete_lambdas),
                                 mode='markers', name=f'layer{i}_{stage}'))

# Unit circle as a visual reference
theta = np.linspace(0, 2*np.pi, 100)
fig.add_trace(go.Scatter(x=np.cos(theta), y=np.sin(theta), mode='lines',
                         line=dict(color="grey"), opacity=0.3, showlegend=False))

# Square plot: identical axis ranges and identical width and height
fig.update_xaxes(range=[-1.2, 1.2], title='Real part')
fig.update_yaxes(range=[-1.2, 1.2], title='Imaginary part')
fig.update_layout(title='Discrete Eigenvalues scaled by time scales -- Real dynamics',
                  width=600, height=600)
fig.show()

### Transfer function and discretization/aliasing analysis

In [None]:
# UTILS
def as_complex(t: torch.Tensor, dtype=torch.complex64):
    """Combine a real tensor of shape (..., 2) into a complex tensor of shape (...).

    The last axis is interpreted as (real part, imaginary part). The result is
    converted to `dtype` if `torch.complex` produced a different complex type.
    """
    assert t.shape[-1] == 2, "as_complex can only be done on tensors with shape=(...,2)"
    real_part = t[..., 0]
    imag_part = t[..., 1]
    combined = torch.complex(real_part, imag_part)
    return combined if combined.dtype == dtype else combined.type(dtype)

def export_layer_parameters(model_trained, step_scale_list=None):
    """Export the state-space parameters of every layer as numpy arrays.

    Args:
        model_trained: model whose `seq[i].s5.seq` modules provide `Lambda`, `B`,
            `B_bias`, `C`, `C_bias`, `log_step`, `input_bias` and `discretize`,
            and which has a `decoder` with `weight` and `bias`.
        step_scale_list: one step-scale factor per layer; the layer's step
            `exp(log_step)` is multiplied by it before discretization.
            Defaults to 1 for every layer.

    Returns:
        (layer_dict, layer_dict_c)
        layer_dict[i]   -- discretized parameters 'A', 'B', 'B_bias', 'C', 'C_bias',
                           'SkipLayer' and the 'step_scale' used; the extra entry
                           layer_dict[n_layers] holds the decoder 'W' and 'b'.
        layer_dict_c[i] -- continuous parameters scaled by the layer's own step
                           (independent of `step_scale_list`): 'A', 'B', 'B_bias',
                           'C', 'C_bias' (a torch tensor, unlike the other entries)
                           and 'SkipLayer'.

    The model itself is left unmodified.
    """
    n_layers = len(model_trained.seq)
    if step_scale_list is None:
        # Same values as the former default ([1]*10), but valid for any number of layers.
        step_scale_list = [1] * n_layers

    layer_dict  = {}
    layer_dict_c = {}
    for i in range(n_layers):
        ssm = model_trained.seq[i].s5.seq

        # clone(): detach() returns a view on the parameter's storage, so the in-place
        # sign fix below would otherwise overwrite the model's own Lambda.
        A_continious = ssm.Lambda.detach().clone()
        A_continious[:,0] = -abs(A_continious[:,0])  # force a negative real part
        B_continious = ssm.B.detach()
        B_bias_continious = ssm.B_bias.detach()
        C = ssm.C.detach()
        C_bias = ssm.C_bias.detach()
        try:
            SkipLayer = model_trained.seq[i].skipLayer.weight.detach().numpy()
        except AttributeError:
            # No skip layer with a weight on this block: use an all-ones feed-through.
            SkipLayer = np.ones((C.shape[0], B_continious.shape[1]))
        SkipLayer = SkipLayer + 1j * np.zeros_like(SkipLayer) # complex representation with zero imaginary part

        A_continious = as_complex(A_continious)
        B_continious = as_complex(B_continious)
        B_bias_continious = as_complex(B_bias_continious)
        C_continious = as_complex(C).numpy()
        C = as_complex(C)
        C_bias = as_complex(C_bias)

        step_scale = step_scale_list[i]
        delta = torch.exp(ssm.log_step).detach()
        step = step_scale*delta
        A_d, Bb_d = ssm.discretize(A_continious, B_continious, B_bias_continious, step, ssm.input_bias)
        # The last column of the discretized input matrix is treated as the bias column.
        B_d = Bb_d[:,0:-1]
        B_bias_d = Bb_d[:,-1]

        # NOTE(review): delta.unsqueeze(1) * B_bias_continious relies on broadcasting against
        # whatever shape B_bias has after as_complex -- confirm it is (n_states, 1).
        layer_dict_c[i] = {f'A': (delta*A_continious).numpy(), f'B': (delta.unsqueeze(1)*B_continious).numpy(), f'B_bias': (delta.unsqueeze(1)*B_bias_continious).numpy(), f'C': C_continious, f'C_bias': C_bias, 'SkipLayer': SkipLayer}
        layer_dict[i] = {f'A': A_d.detach().numpy(), f'B': B_d.detach().numpy(), f'B_bias': B_bias_d.detach().numpy(), f'C': C.detach().numpy(), f'C_bias': C_bias.detach().numpy(), 'SkipLayer': SkipLayer, 'step_scale': step_scale}

    # Add final Linear Layer (key follows the last state-space layer)
    layer_dict[n_layers] = {'W': model_trained.decoder.weight.detach().numpy(), 'b': model_trained.decoder.bias.detach().numpy()}

    return layer_dict, layer_dict_c

# transfer function of continuous model G(s) = C (sI - A)^-1 B + D(SkipLayer)
def transfer_function_Gs(A, B, C, SkipLayer=None, min_freq=-4, Ta_max=1, max_points = 250):
    """Evaluate the continuous-time transfer function on the imaginary axis.

    For every frequency s = j*w the response of (A, B, C, SkipLayer) and the response
    of the element-wise complex-conjugated system are computed and averaged.

    Args:
        A, B, C: array-likes convertible with torch.tensor; A must be square.
        SkipLayer: optional feed-through term added to every response.
        min_freq: lower exponent handed to torch.logspace (grid starts at 10**min_freq).
        Ta_max: the grid is cut off at w <= pi / Ta_max.
        max_points: number of grid points generated before the cut-off.

    Returns:
        (G, s): numpy arrays; G has one (outputs, inputs) matrix per kept frequency and
        s holds the kept complex frequencies j*w.

    NOTE(review): pi / Ta_max is used both as the upper *exponent* of the logspace grid
    and as the cut-off value, so usually fewer than `max_points` frequencies are
    returned -- confirm this is intended.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Grid parameters as tensors on the compute device
    max_freq = torch.tensor(np.pi/Ta_max, device=device)
    min_freq = torch.tensor(min_freq, device=device)
    max_points = torch.tensor(max_points, device=device)

    # System matrices as tensors on the compute device
    A = torch.tensor(A, device=device)
    B = torch.tensor(B, device=device)
    C = torch.tensor(C, device=device)
    if SkipLayer is not None:
        SkipLayer = torch.tensor(SkipLayer, device=device)
    identity = torch.eye(A.size(0), device=device)

    # Log-spaced frequencies on the imaginary axis, limited to w <= pi / Ta_max
    s = 1j * torch.logspace(min_freq, max_freq, steps=max_points, device=device)
    s = s[s.imag <= max_freq]

    def _frequency_response(state, inp, out, feedthrough):
        # One (outputs x inputs) response matrix per frequency in s.
        per_frequency = []
        for si in s:
            response = out @ torch.linalg.inv(si * identity - state) @ inp
            if feedthrough is not None:
                response = response + feedthrough
            per_frequency.append(response)
        return torch.stack(per_frequency)

    G_s = _frequency_response(A, B, C, SkipLayer)
    G_s_conj = _frequency_response(
        torch.conj(A), torch.conj(B), torch.conj(C),
        None if SkipLayer is None else torch.conj(SkipLayer),
    )

    # Average of the system and its complex-conjugated counterpart
    G_s_real = (G_s + G_s_conj) / 2

    return G_s_real.detach().cpu().numpy(), s.detach().cpu().numpy()

# transfer function of discrete model G_z = C (zI - A)^-1 B + D(SkipLayer)
def transfer_function_Gz(A, B, C, s, SkipLayer=None, Ta= 1):
    """Evaluate the discrete-time transfer function at z = exp(s * Ta).

    For every entry of `s` the response of (A, B, C, SkipLayer) and the response of the
    element-wise complex-conjugated system are computed and averaged.

    Args:
        A, B, C: array-likes convertible with torch.tensor; A must be square.
        s: complex frequencies, e.g. the second return value of transfer_function_Gs.
        SkipLayer: optional feed-through term added to every response.
        Ta: sampling-period factor used in z = exp(s * Ta).

    Returns:
        (G, s): numpy arrays; G has one (outputs, inputs) matrix per frequency. The
        second value is the input frequency grid `s` (not z).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Map the continuous frequencies onto the z-plane
    s = torch.tensor(s, device=device)
    z = torch.exp(s*Ta)

    # System matrices as tensors on the compute device
    A = torch.tensor(A, device=device)
    B = torch.tensor(B, device=device)
    C = torch.tensor(C, device=device)
    if SkipLayer is not None:
        SkipLayer = torch.tensor(SkipLayer, device=device)
    identity = torch.eye(A.size(0), device=device)

    def _frequency_response(state, inp, out, feedthrough):
        # One (outputs x inputs) response matrix per point in z.
        per_frequency = []
        for zi in z:
            response = out @ torch.linalg.inv(zi * identity - state) @ inp
            if feedthrough is not None:
                response = response + feedthrough
            per_frequency.append(response)
        return torch.stack(per_frequency)

    G_z = _frequency_response(A, B, C, SkipLayer)
    G_z_conj = _frequency_response(
        torch.conj(A), torch.conj(B), torch.conj(C),
        None if SkipLayer is None else torch.conj(SkipLayer),
    )

    # Average of the system and its complex-conjugated counterpart
    G_z_real = (G_z + G_z_conj) / 2

    return G_z_real.detach().cpu().numpy(), s.detach().cpu().numpy()

In [None]:
# Another visualization of the eigenvalues, time constants and frequencies.
# We link it with the transfer function and discretization/aliasing analysis

# Get model parameter dicts
layer_dict_d_trained , layer_dict_c_trained = export_layer_parameters(model_trained)
layer_dict_d_init , layer_dict_c_init= export_layer_parameters(model_init)

scale = 16000 # Hz, sampling rate of the original audio data

# One panel per layer (previously fixed to three panels, which failed for deeper models).
n_layers = len(layer_dict_c_trained)
middle_panel = n_layers // 2  # panel that carries the shared x label

# --- Figure 1: step-scaled continuous eigenvalues in the complex plane ---
fig, ax = plt.subplots(1, n_layers, figsize=(8.25, 8.25/3), squeeze=False)
ax = ax[0]  # squeeze=False keeps indexing valid for a single layer
for i in range(n_layers):
    A_init = layer_dict_c_init[i]['A']
    A_trained = layer_dict_c_trained[i]['A']
    trained_color = color_list_plt[(i + 1) % len(color_list_plt)]

    ax[i].scatter(A_init.real, A_init.imag, label='Layer '+str(i+1) + ' init', alpha=0.4, color = color_list_plt[0])
    ax[i].scatter(A_trained.real, A_trained.imag, label='Layer '+str(i+1) + ' trained', alpha=0.4, color = trained_color)
    ax[i].set_xlim([-0.15, 0.01])
    ax[i].set_ylim([-1, 1])

    ax[i].set_xticks(np.arange(-0.15, 0.01, 0.05))
    ax[i].set_yticks(np.arange(-np.pi, np.pi + np.pi/4, np.pi/4))

    ax[i].legend()

    ax[i].grid(axis='y', linestyle='--')

    if i == middle_panel:
        ax[i].set_xlabel('Real')
    if i == 0:
        ax[i].set_ylabel('Imaginary')

plt.tight_layout()
# save figure as pdf
plt.savefig(os.path.join(model_save_path, 'eigenvalues.pdf'))
plt.show()

# --- Figure 2: decay time constants vs. oscillation frequencies ---
fig, ax = plt.subplots(1, n_layers, figsize=(8.25, 8.25/3), squeeze=False)
ax = ax[0]
for i in range(n_layers):
    A_init = layer_dict_c_init[i]['A']
    A_trained = layer_dict_c_trained[i]['A']
    trained_color = color_list_plt[(i + 1) % len(color_list_plt)]

    # x: 1 / (|Re| * sampling rate) -> decay time constant in seconds
    # y: |Im| * sampling rate / (2 pi) -> frequency in Hz
    ax[i].scatter(1/(abs(A_init.real)*scale), abs(A_init.imag)*scale/(2*np.pi), alpha=0.5, color = color_list_plt[0], label='Layer '+str(i+1) + ' init')
    ax[i].scatter(1/(abs(A_trained.real)*scale), abs(A_trained.imag)*scale/(2*np.pi),  alpha=0.5, color = trained_color, label='Layer '+str(i+1) + ' trained')
    ax[i].legend()
    ax[i].grid()
    ax[i].set_xscale('log')
    ax[i].set_yscale('log')

ax[middle_panel].set_xlabel('exp. decay time constants [s]')
ax[0].set_ylabel('frequencies [Hz]')
plt.tight_layout()
# save figure as pdf
plt.savefig(os.path.join(model_save_path, 'frequency_time_scales.pdf'))

plt.show()
# export 


#### Transfer function extraction

In [None]:
# Show the transfer functions of all layers.
# Warning: this can take a while to compute and plot,
# because there is one transfer function per input/output pair of every layer.
# We plot G(s) here and only its magnitude.
Ta = 1

# The continuous parameters are identical for every layer iteration, so export them once.
n_layers = len(model_trained.seq)
_, dictionary_c = export_layer_parameters(model_trained, step_scale_list=[1]*n_layers)

# One panel per layer (previously fixed to three panels).
fig, ax = plt.subplots(n_layers, 1, figsize=(8.25, 8.25/3*n_layers), squeeze=False)
ax = ax[:, 0]  # squeeze=False keeps indexing valid for a single layer
for layer in range(n_layers):
    A = np.diag(dictionary_c[layer]['A'])
    B = dictionary_c[layer]['B']
    C = dictionary_c[layer]['C']
    SkipLayer = dictionary_c[layer]['SkipLayer']
    # Including the skip connection or not is a philosophical question; here we include it,
    # because we analyze the effect of being in the linear phase of the ReLU.
    # Setting it to None slightly changes the shapes, but not drastically.

    G_s_real, s = transfer_function_Gs(A, B, C, SkipLayer=SkipLayer, Ta_max=Ta)

    magnitude = 20*np.log10(np.abs(G_s_real))
    print('shape mag:', magnitude.shape)
    n_out = C.shape[0]
    n_in = B.shape[1]
    ax[layer].set_xscale('log')
    for i_out in range(n_out):
        for i_in in range(n_in):
            ax[layer].plot(s.imag, magnitude[:,i_out,i_in], label='G(s) Out'+str(i_out) + 'In'+str(i_in), alpha=0.2)
    ax[layer].set_ylabel('Magnitude [dB]')
    ax[layer].set_title('Layer '+str(layer+1))

# s.imag is an angular frequency per sample (the grid ends at pi / Ta).
ax[-1].set_xlabel('Frequency [rad/sample]')
plt.tight_layout()
plt.show()

##### First Visualization of Discretization Effects
Setting the time scales allows us to sub-sample the system, which introduces aliasing effects.
The next graph shows the differences between G(s) of the continuous-time system and G(z) of
three discretized systems (sampling factors 1, 2 and 4) for a single input/output transfer function.


In [None]:
# Visualization of the discretization error for one input/output pair of one layer
model = model_trained
layer = 0
max_points = 500
Ta = 1
scale = 16000 # Hz, sampling rate of the original audio data

i_out=2 # Here you can select the output index you want to plot
i_in=0 # Here you can select the input index you want to plot

# (down-sampling factor T_z, line label, error-area label, colour,
#  style of the curve part above that factor's Nyquist frequency)
sampling_setups = [
    (1, 'Discrete (T=16kHz)', 'T=16kHz', color_list_plt[1], dict(linestyle='dashdot', alpha=1)),
    (2, 'Discrete (T=8kHz)', 'T=8kHz', color_list_plt[2], dict(alpha=0.4)),
    (4, 'Discrete (T=4kHz)', 'T=4kHz', color_list_plt[3], dict(alpha=0.4)),
]

n_layers = len(model.seq)

# The continuous parameters do not depend on the step scale, so the continuous
# reference response is computed once instead of once per sampling factor.
_, dictionary_c = export_layer_parameters(model, step_scale_list=[1]*n_layers)
A = np.diag(dictionary_c[layer]['A'])
B = dictionary_c[layer]['B']
C = dictionary_c[layer]['C']
SkipLayer = dictionary_c[layer]['SkipLayer']
G_s_real, s = transfer_function_Gs(A, B, C, SkipLayer=SkipLayer, Ta_max=Ta, max_points=max_points)
magnitude = 20*np.log10(np.abs(G_s_real))
print('shape mag:', magnitude.shape)

freq_hz = s.imag/(2*np.pi)*scale  # frequency grid in Hz

fig, ax = plt.subplots(2, 1, figsize=(8.25, 8.25/2))
ax[0].plot(freq_hz, magnitude[:,i_out,i_in], label='Continuous', alpha=1 , color=color_list_plt[0])

discretization_errors = []
for T_z, line_label, _, curve_color, aliased_style in sampling_setups:
    dictionary_d, _ = export_layer_parameters(model, step_scale_list=[T_z]*n_layers)
    A_d = np.diag(dictionary_d[layer]['A'])
    B_d = dictionary_d[layer]['B']
    C_d = dictionary_d[layer]['C']

    # transfer_function_Gz returns the frequency grid s (not z) as its second value.
    G_z_real, s_z = transfer_function_Gz(A_d, B_d, C_d, s, SkipLayer=SkipLayer, Ta=T_z)
    magnitude_z = 20*np.log10(np.abs(G_z_real))

    # Solid up to this factor's Nyquist frequency, de-emphasised above it.
    below_nyquist = s_z.imag <= np.pi/T_z
    ax[0].plot(freq_hz[below_nyquist], magnitude_z[below_nyquist,i_out,i_in], label=line_label, alpha=1, color=curve_color)
    ax[0].plot(freq_hz[~below_nyquist], magnitude_z[~below_nyquist,i_out,i_in], color=curve_color, **aliased_style)

    discretization_errors.append(np.abs(magnitude[:,i_out,i_in] - magnitude_z[:,i_out,i_in]))

ax[0].vlines(scale/2, -60, 25, linestyle='dashdot', alpha=0.7, label='Nyquist (8kHz)',color=color_list_plt[1])
ax[0].vlines(scale/4, -60, 25, linestyle='dashed', alpha=0.7, label='Nyquist (4kHz)', color=color_list_plt[2])
ax[0].vlines(scale/8, -60, 25, linestyle='dotted', alpha=0.7, label='Nyquist (2kHz)', color=color_list_plt[3])
ax[0].set_ylabel('Magnitude [dB]')
ax[0].set_ylim([-60, 25])
ax[0].set_xlim([0.5, 8500])
ax[0].grid()
ax[0].legend(ncol=2, loc = 'upper left')
ax[0].set_xscale('log')
# omit the x labels
ax[0].set_xticklabels([])

# Area plot between zero and each error curve. The coarsest sampling is drawn first so the
# finer ones stay visible on top; each curve is drawn once, so the legend has no duplicates.
for (_, _, area_label, area_color, _), error in reversed(list(zip(sampling_setups, discretization_errors))):
    ax[1].fill_between(freq_hz, 0, error, alpha=1, label=area_label, color=area_color)

ax[1].set_xscale('log')
ax[1].set_ylabel('Discretization Error')
ax[1].set_xlabel('Frequency [Hz]')
ax[1].set_xlim([0.5, 8500])
ax[1].grid()
ax[1].legend(ncol=3, loc = 'upper left')

plt.tight_layout()
# save figure as pdf
plt.savefig(os.path.join(model_save_path, 'transfer_function_element.pdf'))
plt.show()

### Discretization errors per layer


In [None]:
# Here we extract the discretization error as described in the paper (see Section A.3)
mae_error_list = []
rmse_error_list = []
mean_relative_error_list = []
step_scale_list = []

layer_list = []

# One entry per step scale; each entry is a list with one magnitude array per layer.
magnitude_z_list = []
magnitude_list = []

step_scale_sweep = list(np.linspace(1, 64, 64))
step_scale_sweep.sort()

# The frequency grid is limited by the largest step scale so all sweeps share one grid.
Ta_max = step_scale_sweep[-1]

n_layers = len(model_trained.seq)

# The continuous parameters and Ta_max are the same for every step scale, so the
# continuous reference response is computed once per layer instead of once per
# (step scale, layer) pair.
_, layer_dict_c_trained = export_layer_parameters(model_trained, step_scale_list=[1]*n_layers)
continuous_reference = []
for layer in range(n_layers):
    A = np.diag(layer_dict_c_trained[layer]['A'])
    B = layer_dict_c_trained[layer]['B']
    C = layer_dict_c_trained[layer]['C']
    # Skip connection deliberately left out (None) for the error analysis.
    G_s_real, s = transfer_function_Gs(A, B, C, None, Ta_max=Ta_max)
    continuous_reference.append((20*np.log10(np.abs(G_s_real)), s))

for step_scale in tqdm(step_scale_sweep):

    Ta = step_scale
    step_scales = [step_scale]*n_layers  # one entry per layer (previously a fixed 6)
    magnitude_z_step_list = []
    magnitude_step_list = []
    layer_dict_d_trained , layer_dict_c_trained = export_layer_parameters(model_trained, step_scale_list = step_scales)

    for i in range(len(layer_dict_c_trained)):
        layer = i
        layer_list.append(str(layer))

        dictionary_d = layer_dict_d_trained

        # Shared continuous reference; the same array object is stored for every step scale.
        magnitude, s = continuous_reference[layer]

        A_d = np.diag(dictionary_d[layer]['A'])
        B_d = dictionary_d[layer]['B']
        C_d = dictionary_d[layer]['C']

        G_z_real, z = transfer_function_Gz(A_d , B_d, C_d, s, None, Ta=Ta)
        magnitude_z = 20*np.log10(np.abs(G_z_real))

        magnitude_step_list.append(magnitude)
        magnitude_z_step_list.append(magnitude_z)

        # Error measures between the continuous and the discretized magnitude (in dB)
        mae_error = np.mean(np.abs(magnitude - magnitude_z))
        rmse_error = np.sqrt(np.mean((magnitude - magnitude_z)**2))
        mean_relative_error = np.mean(np.abs(magnitude - magnitude_z) / np.abs(magnitude))

        mae_error_list.append(mae_error)
        rmse_error_list.append(rmse_error)
        mean_relative_error_list.append(mean_relative_error)

        step_scale_list.append(step_scale)

    magnitude_z_list.append(magnitude_z_step_list)
    magnitude_list.append(magnitude_step_list)


In [None]:
def get_errors(mag_list, step_scales=None):
    """Mean absolute error of each sweep entry relative to the first sweep entry.

    Args:
        mag_list: mag_list[i][j] is the magnitude in dB for step scale i and layer j,
            indexed as (frequency, output, input).
        step_scales: step scale belonging to each entry of `mag_list`. Defaults to the
            notebook-level `step_scale_sweep`.

    Returns:
        DataFrame with the columns 'layer', 'step_scale' and 'mae_error', one row per
        (step scale, layer) pair in sweep order.

    The magnitudes are converted to linear scale. Each layer's inputs are weighted with
    the maximum gains of the previous layer's reference response, and both responses
    are normalised by the maximum gain of the current layer's reference response.
    """
    if step_scales is None:
        step_scales = step_scale_sweep  # sweep defined at notebook level

    num_layer = len(mag_list[0])
    layer_list = []
    step_scale_list = []
    mae_error_list = []
    for i, step_scale in enumerate(step_scales):
        # Input weighting carried from layer to layer; the first layer is unweighted.
        max_mag0_lin = np.array(1)
        for j in range(num_layer):
            length = mag_list[i][j].shape[0]

            # Previous layer's per-output maxima become this layer's per-input weights.
            max_mag0_lin = max_mag0_lin.reshape(1,1,-1)
            mag0_lin = 10**(mag_list[0][j][0:length]/20)  # reference (first sweep entry)
            mag_lin = 10**(mag_list[i][j][0:length]/20)

            mag0_lin = mag0_lin*max_mag0_lin
            mag_lin = mag_lin*max_mag0_lin

            # Sum over inputs, then maximum over frequency -> shape (1, outputs, 1)
            max_mag0_lin = np.max(np.sum(mag0_lin, axis=2, keepdims=True),axis=0,keepdims=True)

            mag0_lin = mag0_lin/max_mag0_lin
            mag_lin = mag_lin/max_mag0_lin

            error = np.abs(mag_lin - mag0_lin)
            mae_error = np.mean(error)

            layer_list.append(j)
            step_scale_list.append(step_scale)
            mae_error_list.append(mae_error)

    df = pd.DataFrame({'layer': layer_list, 'step_scale': step_scale_list, 'mae_error': mae_error_list})

    return df

# Per-layer error of the discretized magnitudes of every step scale, measured
# against the first entry of the sweep (get_errors compares against mag_list[0]).
df_error = get_errors(magnitude_z_list)

# Preview of the first rows
df_error.head()


In [None]:
# Discretization error over the down-sampling factor, one panel per layer.
fig, ax = plt.subplots(1, 3, figsize=(8.25, 8.25/3))

n_layers_in_table = df_error['layer'].nunique()
for layer in range(n_layers_in_table):
    panel = ax[layer]
    df_error_layer = df_error.loc[df_error['layer'] == layer]

    panel.plot(
        df_error_layer['step_scale'],
        df_error_layer['mae_error'],
        label=f'Layer {layer+1}',
        color=color_list_plt[layer+1],
    )
    panel.grid()
    panel.set_ylim([1e-7, 1.2])

    # Shared axis labels: x label on the second panel, y label on the first panel only.
    if layer == 1:
        panel.set_xlabel('Downsampling Factor [1]')
    if layer == 0:
        panel.set_ylabel('discretization error')

    panel.legend(loc='upper left')
    panel.set_yscale('log')

# save plot as pdf
plt.tight_layout()
plt.savefig(os.path.join(model_save_path, 'discretization_error.pdf'))
plt.show()