In [None]:
# Reload edited modules (e.g. the flybrain package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Notebook that generates the plots for training a random RNN model by learning its synaptic weights

In [None]:
import flybrain.utils as utils
import flybrain.model as model
import numpy as np 
import matplotlib.pyplot as plt
import torch
import os
import pandas as pd
import seaborn as sns

In [None]:
def load_experiments(nSamples:int,
                     subfolder:str,
                     activation:str,
                     weights=True,
                     shifts=False,
                     gains=False,
                     N=100,
                     Nle=1,
                     Epochs=400,
                     lr=0.01 ,
                     loss='MSE',
                     target='0.00',
                     g=1.0,
                     tons=0.2,
                     tsim=200,
                     dt=0.1):
    """Load the JSON training logs of every sample of one training configuration.

    The run name is rebuilt from the hyper-parameters in the same format used on
    disk, and the files ``<run>_Sample{i}_logs.json`` for i = 0 .. nSamples-1 are
    read from ``<root>/data/logs/<subfolder>/weigth`` with ``utils.load_logs``.

    Returns:
        list: one element per sample, in sample order, each being the value
        returned by ``utils.load_logs`` for that file.
    """
    log_dir = os.path.join(utils.get_root(), 'data', 'logs', subfolder, 'weigth')

    # Runs trained with the 'Entropy' loss carry no target token in their name.
    loss_token = f"{loss}" if loss == 'Entropy' else f"{loss}_{target}"
    run_name = (f"{activation}_Weights{weights}_Shifts{shifts}_Gains{gains}_N{N}_lr{lr}"
                f"_NLE{Nle}_Epochs{Epochs}_{loss_token}_g{g}_Tons{tons}_Tsim{tsim}_dt{dt}")

    return [utils.load_logs(file_path=os.path.join(log_dir, run_name + f"_Sample{idx}_logs.json"))
            for idx in range(nSamples)]

## 1) Load the runs

In [None]:

# Every run analysed in this notebook is a random RNN ('rd_rnn' log folder)
# trained with the positive-tanh activation.
RD_RNN_RUNS = dict(subfolder='rd_rnn', activation='tanh_positive')

# Sweep over the number of Lyapunov exponents (NLE) used in the loss, at N=100.
tanhpos_n100_nle1=load_experiments(nSamples=5, Nle=1, **RD_RNN_RUNS)
tanhpos_n100_nle10=load_experiments(nSamples=5, Nle=10, **RD_RNN_RUNS)
tanhpos_n100_nle25=load_experiments(nSamples=5, Nle=25, **RD_RNN_RUNS)
tanhpos_n100_nle50=load_experiments(nSamples=5, Nle=50, **RD_RNN_RUNS)
tanhpos_n100_nle75=load_experiments(nSamples=5, Nle=75, **RD_RNN_RUNS)
tanhpos_n100_nle100=load_experiments(nSamples=5, Nle=100, **RD_RNN_RUNS)
exp_nle=[tanhpos_n100_nle1,
         tanhpos_n100_nle10,
         tanhpos_n100_nle25,
         tanhpos_n100_nle50,
         tanhpos_n100_nle75,
         tanhpos_n100_nle100]

# Sweep over the network size N with NLE ~ N/10.
# The N=100 point (N=100, NLE=10, 5 samples) is exactly the configuration already
# loaded above as tanhpos_n100_nle10, so it is reused instead of being read twice.
tanhpos_n25_nle2=load_experiments(N=25, nSamples=5, Nle=2, **RD_RNN_RUNS)
tanhpos_n50_nle5=load_experiments(N=50, nSamples=5, Nle=5, **RD_RNN_RUNS)
tanhpos_n200_nle20=load_experiments(N=200, nSamples=5, Nle=20, **RD_RNN_RUNS)
tanhpos_n400_nle40=load_experiments(N=400, nSamples=5, Nle=40, **RD_RNN_RUNS)
tanhpos_n500_nle50=load_experiments(N=500, nSamples=5, Nle=50, **RD_RNN_RUNS)
exp_n=[tanhpos_n25_nle2,
       tanhpos_n50_nle5,
       tanhpos_n100_nle10,
       tanhpos_n200_nle20,
       tanhpos_n400_nle40,
       tanhpos_n500_nle50]

# Runs trained towards non-default targets (the `target` string is part of the run name).
tanhpos_n50_nle5_100=load_experiments(N=50, nSamples=5, Nle=5, target="100.00", **RD_RNN_RUNS)
tanhpos_n100_nle10_Exploding=load_experiments(N=100, nSamples=3, Nle=10, target="2.00", **RD_RNN_RUNS)
tanhpos_n100_nle10_Vanishing=load_experiments(N=100, nSamples=3, Nle=10, target="-2.00", **RD_RNN_RUNS)
tanhpos_n100_nle10_custom=load_experiments(N=100, nSamples=3, Nle=10, target="1.25_-1.25", **RD_RNN_RUNS)

tanhpos_n100_nle10_custom_01=load_experiments(N=100, nSamples=3, Nle=10, target="1.25_-1.25", lr=0.1, **RD_RNN_RUNS)

tanhpos_n100_nle10_10=load_experiments(N=100, nSamples=3, Nle=10, target="10.00", lr=0.01, Epochs=300, **RD_RNN_RUNS)

grad_exp=[tanhpos_n50_nle5_100,tanhpos_n100_nle10_Exploding,tanhpos_n100_nle10_Vanishing]

## 2) Plot the training logs 

In [None]:
def _stack_runs(list_of_logs, entry_used):
    """Collect one logged series from every run.

    Returns the individual runs as float arrays together with a
    (n_runs, longest_run) matrix holding them row by row. Shorter runs are
    right-padded with NaN, so runs of unequal length are aggregated by the
    NaN-aware statistics below instead of raising a shape error.

    Raises:
        ValueError: if ``list_of_logs`` is empty.
    """
    if len(list_of_logs) == 0:
        raise ValueError("list_of_logs must contain at least one run")
    runs = [np.asarray(sample[entry_used], dtype=float) for sample in list_of_logs]
    data = np.full((len(runs), max(len(run) for run in runs)), np.nan)
    for i, run in enumerate(runs):
        data[i, :len(run)] = run
    return runs, data


def plot_logs_trainings(list_of_logs,
                        entry_used,
                        axs,
                        serie,
                        color):
    """Plot one logged series of several runs against its index (e.g. epochs).

    Each run is drawn faintly; the NaN-aware mean across runs is drawn bold and
    labelled ``serie``. No legend is drawn here.

    Args:
        list_of_logs: list of per-run log dicts.
        entry_used: key of the series to plot in each log dict.
        axs: Matplotlib Axes to draw on.
        serie: legend label of the mean curve.
        color: colour shared by the individual runs and the mean.
    """
    runs, data = _stack_runs(list_of_logs, entry_used)
    for run in runs:
        axs.plot(np.arange(len(run)), run, alpha=0.2, color=color, lw=1)

    # The last column is excluded from the mean, as in the original implementation.
    # NOTE(review): confirm this is intentional.
    mean = np.nanmean(data[:, :-1], axis=0)
    epochs = np.arange(len(mean))
    axs.plot(epochs, mean, label=f'{serie}', alpha=0.9, lw=2, color=color)
    return


def plot_logs_spectrums(list_of_logs,
                        entry_used,
                        axs,
                        serie,
                        color):
    """Plot one logged series of several runs on a normalised 0-100 x-axis.

    Each run is stretched onto [0, 100] so that spectra of different lengths
    (different numbers of exponents) can be overlaid. Runs are drawn faintly;
    the NaN-aware mean is drawn bold with markers, and a legend is drawn.

    Args:
        list_of_logs: list of per-run log dicts.
        entry_used: key of the series to plot in each log dict.
        axs: Matplotlib Axes to draw on.
        serie: legend label of the mean curve.
        color: colour shared by the individual runs and the mean.
    """
    runs, data = _stack_runs(list_of_logs, entry_used)
    for run in runs:
        axs.plot(np.linspace(0, 100, len(run)), run, alpha=0.2, color=color, lw=1)

    # The last column is excluded from the mean, as in the original implementation.
    # NOTE(review): for spectra this drops the last exponent; confirm it is intentional.
    mean = np.nanmean(data[:, :-1], axis=0)
    positions = np.linspace(0, 100, len(mean))
    axs.plot(positions, mean, label=f'{serie}', alpha=0.9, lw=2, color=color, marker='o', markersize=2)
    axs.legend()
    return

### 2.1 Experiment: number of Lyapunov exponents used

#### Loss vs Epochs

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for side in ("right", "top"):
    axs.spines[side].set_color("none")

# Training loss per epoch for each number of Lyapunov exponents (NLE):
# faint lines are individual samples, bold lines their mean.
nle_loss_series = [
    (tanhpos_n100_nle1, 'NLE:1', 'blue'),
    (tanhpos_n100_nle10, 'NLE:10', 'orange'),
    (tanhpos_n100_nle25, 'NLE:25', 'red'),
    (tanhpos_n100_nle50, 'NLE:50', 'green'),
    (tanhpos_n100_nle75, 'NLE:75', 'gray'),
]
for runs, series_label, line_color in nle_loss_series:
    plot_logs_trainings(runs, 'training_loss', axs, series_label, color=line_color)

axs.set_xlim([0, 80])
axs.set_ylabel(r"$|L(\hat \lambda_{\theta})|$", fontsize=20)
axs.set_xlabel(r"$epochs$", fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=16)
axs.legend(fontsize=20, frameon=False)

plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_NLE_trainingloss.svg')

#### Lambda max vs Epochs

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# Evolution of the largest Lyapunov exponent during training, one curve per NLE.
# Every series plots 'training_lambda_max' so that only one quantity shares this axis.
plot_logs_trainings(tanhpos_n100_nle1,'training_lambda_max', axs, 'NLE:1', color='blue')
plot_logs_trainings(tanhpos_n100_nle10,'training_lambda_max', axs, 'NLE:10', color='orange')
plot_logs_trainings(tanhpos_n100_nle25,'training_lambda_max', axs, 'NLE:25', color='red')
plot_logs_trainings(tanhpos_n100_nle50,'training_lambda_max', axs, 'NLE:50', color='green')
plot_logs_trainings(tanhpos_n100_nle75,'training_lambda_max', axs, 'NLE:75', color='gray')

axs.set_xlim([0,80])
# Dashed reference line at zero.
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2)
axs.set_ylabel(r"$\lambda_{max}$",fontsize=20)
axs.set_xlabel(r"$epochs$",fontsize=20)
axs.tick_params(axis='both', which='major', labelsize=16)
axs.legend(fontsize=20, frameon=False)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_NLE_lambdaMax_evolution.svg')
plt.show()


#### Spectrum

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for side in ("right", "top"):
    axs.spines[side].set_color("none")

# Logged Lyapunov spectrum (exponent index on x) for each NLE value.
nle_spectrum_series = [
    (tanhpos_n100_nle1, 'NLE:1', 'blue'),
    (tanhpos_n100_nle10, 'NLE:10', 'orange'),
    (tanhpos_n100_nle25, 'NLE:25', 'red'),
    (tanhpos_n100_nle50, 'NLE:50', 'green'),
    (tanhpos_n100_nle75, 'NLE:75', 'gray'),
]
for runs, series_label, line_color in nle_spectrum_series:
    plot_logs_trainings(runs, 'spectrum', axs, series_label, color=line_color)

axs.set_xlim([0, 100])
axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$", fontsize=20)
axs.set_xlabel(r"$i$", fontsize=20)
axs.tick_params(axis='both', which='major', labelsize=16)
axs.legend(loc='lower left', fontsize=19, frameon=False)
axs.set_ylim(-4, 1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_NLE_spectrum.svg')
plt.show()


### 2.2 Experiment: network size

####  Spectrum

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# Spectra of the network-size sweep (NLE ~ N/10), stretched onto a common 0-100 axis.
# The N:100 series is the N=100 / NLE=10 run that belongs to this sweep (exp_n),
# not the NLE=100 run of the NLE experiment.
plot_logs_spectrums(tanhpos_n25_nle2,'spectrum', axs, 'N:25', color='blue')
plot_logs_spectrums(tanhpos_n50_nle5,'spectrum', axs, 'N:50', color='orange')
plot_logs_spectrums(tanhpos_n100_nle10,'spectrum', axs, 'N:100', color='red')
plot_logs_spectrums(tanhpos_n200_nle20,'spectrum', axs, 'N:200', color='green')
plot_logs_spectrums(tanhpos_n400_nle40,'spectrum', axs, 'N:400', color='pink')
plot_logs_spectrums(tanhpos_n500_nle50,'spectrum', axs, 'N:500', color='gray')

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=22)
axs.set_xlabel(r"$i$",fontsize=22)
axs.tick_params(axis='both', which='major', labelsize=16)
# No explicit legend call: plot_logs_spectrums draws one on every call.
axs.set_ylim(-2,1)

plt.tight_layout()
#plt.savefig('../data/fig/FINAL/1_RD_Weights_NetworkSize_Effi.svg')
plt.show()

### 2.3 Experiment MinMax spectrum

#### Weight-gradient norm vs epochs

In [None]:
# Gradient analysis: magnitude of the logged weight gradient ('grad_weights')
# during training for the runs with targets 2, 0 and -2.
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

plot_logs_trainings(tanhpos_n100_nle10_Exploding,'grad_weights', axs, r'$\lambda_{target}:2$', color='green')
plot_logs_trainings(tanhpos_n100_nle10,'grad_weights', axs, r'$\lambda_{target}:0$', color='red')
plot_logs_trainings(tanhpos_n100_nle10_Vanishing,'grad_weights', axs, r'$\lambda_{target}:-2$', color='blue')

axs.set_xlim([0,250])
axs.set_yscale('log')
axs.set_ylim([1E-6,1E2])
axs.set_ylabel(r"$|\nabla_{\theta}L|$",fontsize=22)
axs.set_xlabel(r"$epochs$",fontsize=22)
axs.tick_params(axis='both', which='major', labelsize=16)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD Weights_Limit_grad.svg')

#### Lambda max vs epochs

In [None]:
# Evolution of lambda_max during training for the runs with targets 2, 0 and -2.
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# The "Exploding" runs were loaded with target="2.00", so their label reads 2.
plot_logs_trainings(tanhpos_n100_nle10_Exploding,'training_lambda_max', axs, r'$\lambda_{target}:2$', color='green')
plot_logs_trainings(tanhpos_n100_nle10,'training_lambda_max', axs, r'$\lambda_{target}:0$', color='red')
plot_logs_trainings(tanhpos_n100_nle10_Vanishing,'training_lambda_max', axs, r'$\lambda_{target}:-2$', color='blue')

axs.set_xlim([0,100])
axs.set_ylabel(r"$\lambda_{max}$",fontsize=20)
axs.set_xlabel(r"$epochs$",fontsize=20)
axs.tick_params(axis='both', which='major', labelsize=16)

plt.tight_layout()

#### Spectrum

In [None]:
# Logged Lyapunov spectrum of the runs trained towards targets 2, 0 and -2.
fig, axs = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for side in ("right", "top"):
    axs.spines[side].set_color("none")

target_spectrum_series = (
    (tanhpos_n100_nle10_Exploding, r'$\lambda_{target}:\vec{2}$', 'green'),
    (tanhpos_n100_nle10, r'$\lambda_{target}:\vec{0}$', 'red'),
    (tanhpos_n100_nle10_Vanishing, r'$\lambda_{target}:\vec{-2}$', 'blue'),
)
for runs, series_label, line_color in target_spectrum_series:
    plot_logs_trainings(runs, 'spectrum', axs, series_label, color=line_color)

axs.set_xlim([0, 100])
axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$", fontsize=22)
axs.set_xlabel(r"$i$", fontsize=22)
axs.tick_params(axis='both', which='major', labelsize=16)
axs.legend(fontsize=20, loc='upper right', frameon=False)
axs.set_ylim(-3, 2.5)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD Weights_Limit_spect.svg')


## 3) Plot the convergence as a function of the number of Lyapunov exponents

In [None]:
def _tail_means(list_of_logs, key, window_size):
    """Per-run mean of the last ``window_size`` values logged under ``key``.

    Runs are averaged independently, so runs of different lengths are supported.
    A run whose entry is None or empty yields NaN instead of raising.
    """
    means = []
    for sample in list_of_logs:
        series = sample[key]
        if series is None or len(series) == 0:
            means.append(np.nan)
        else:
            means.append(np.mean(np.asarray(series, dtype=float)[-window_size:]))
    return np.asarray(means, dtype=float)


def compute_conv_grad_based(list_of_logs, window_size=100):
    """Gradient-based convergence measure of each run.

    Args:
        list_of_logs: list of per-run log dicts containing 'grad_weights',
            'grad_shifts' and 'grad_gains' series.
        window_size: number of final logged values averaged per run.

    Returns:
        dict: {'weights', 'gains', 'shifts'} -> float array with one mean per run.
    """
    return {'weights': _tail_means(list_of_logs, 'grad_weights', window_size),
            'gains': _tail_means(list_of_logs, 'grad_gains', window_size),
            'shifts': _tail_means(list_of_logs, 'grad_shifts', window_size)}


def compute_conv_loss_based(list_of_logs, window_size=100):
    """Loss-based convergence measure of each run.

    Args:
        list_of_logs: list of per-run log dicts containing a 'training_loss' series.
        window_size: number of final logged values averaged per run.

    Returns:
        dict: {'loss': float array with one mean per run}.
    """
    return {'loss': _tail_means(list_of_logs, 'training_loss', window_size)}
    

In [None]:
def _plot_conv_values(per_experiment_values, experiment_names, ax):
    """Scatter each run's convergence value and overlay the per-experiment mean.

    Args:
        per_experiment_values: one iterable of per-run values per experiment.
        experiment_names: x position / tick label of each experiment, same order.
        ax: Matplotlib Axes to draw on.
    """
    records = [{'Experiment': label, 'Value': value}
               for values, label in zip(per_experiment_values, experiment_names)
               for value in values]
    df = pd.DataFrame(records)
    # groupby sorts the labels, so the mean line and the ticks run in ascending order.
    summary = df.groupby('Experiment')['Value'].mean().reset_index(name='mean')

    ax.scatter(df['Experiment'], df['Value'], color='#3759FF', alpha=0.4, label='Runs')
    ax.plot(summary['Experiment'], summary['mean'], color='black', marker='d', markersize=8, alpha=0.6, label='Mean')
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xticks(summary['Experiment'])
    ax.set_xticklabels(summary['Experiment'], fontsize=12)


def plot_conv_grad_based(expriment_lists, experiment_names, ax):
    """Plot the late-training weight-gradient measure of every run, per experiment.

    Only the 'weights' entry of ``compute_conv_grad_based`` is plotted.

    Args:
        expriment_lists: one list of run logs per experiment.
        experiment_names: x position / tick label of each experiment.
        ax: Matplotlib Axes to draw on.
    """
    values = [compute_conv_grad_based(exp)['weights']
              for exp, _ in zip(expriment_lists, experiment_names)]
    _plot_conv_values(values, experiment_names, ax)


def plot_conv_loss_based(expriment_lists, experiment_names, ax):
    """Plot the late-training mean loss of every run, per experiment.

    Args:
        expriment_lists: one list of run logs per experiment.
        experiment_names: x position / tick label of each experiment.
        ax: Matplotlib Axes to draw on.
    """
    values = [compute_conv_loss_based(exp)['loss']
              for exp, _ in zip(expriment_lists, experiment_names)]
    _plot_conv_values(values, experiment_names, ax)



In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

experiment_labels = [1, 10, 25, 50, 75, 100]
plot_conv_loss_based(exp_nle,experiment_labels,ax)

# plot_conv_loss_based averages the loss over the last 100 logged values
# (compute_conv_loss_based's default window_size=100), so the label says T-100.
ax.set_ylabel(r"$\frac{1}{N} \sum_{T-100}^{T}L_{\theta}^t$", fontsize=14)
ax.set_xlabel("Number of Lyapunov exponent", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_yscale('log')
ax.legend(loc='lower right',fontsize=16, frameon=False)
ax.set_ylim([1E-15, 1E0])
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_NLE_convergence_loss_based.svg')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for edge in ("right", "top"):
    ax.spines[edge].set_color("none")

# Late-training weight-gradient measure of every run, per number of exponents.
experiment_labels = [1, 10, 25, 50, 75, 100]
plot_conv_grad_based(exp_nle, experiment_labels, ax)

ax.set_ylabel(r"$|\nabla L_{\theta}|$", fontsize=20)
ax.set_xlabel("Number of Lyapunov exponent", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_yscale('log')
ax.legend(loc='lower right', fontsize=12, frameon=False)
ax.set_ylim([1E-9, 1E0])

plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_NLE_convergence_grad_based.svg')
plt.show()

## 4) Plot the convergence as a function of the network size

In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

experiment_labels = [25, 50, 100, 200, 400,500]
plot_conv_loss_based(exp_n,experiment_labels,ax)

# plot_conv_loss_based averages the loss over the last 100 logged values
# (compute_conv_loss_based's default window_size=100), so the label says T-100.
ax.set_ylabel(r"$\frac{1}{N} \sum_{T-100}^{T}L_{\theta}^t$", fontsize=20)
ax.set_xlabel("Network size", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_yscale('log')
ax.set_ylim([1E-9, 1E0])
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_N_convergence_loss_based.svg')
plt.show()


In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

experiment_labels = [25, 50, 100, 200, 400,500]
plot_conv_grad_based(exp_n,experiment_labels,ax)

ax.set_ylabel(r"$|\nabla L_{\theta}|$", fontsize=20)
ax.set_xlabel("Network size", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_yscale('log')
# A lower limit of 0 is not valid on a log axis (matplotlib warns and keeps the
# previous bottom of 1e-9), so the effective limits are set directly.
ax.set_ylim([1E-9, 0.25])
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_N_convergence_grad_based.svg')
plt.show()


## 5) Plot the computation time

In [None]:
def plot_time(expriment_lists, experiment_names, ax, n_epochs=400):
    """Plot the mean (± std) training wall-clock time per epoch for each experiment.

    The per-experiment summary table (mean, std) is also printed.

    Args:
        expriment_lists: one list of run logs per experiment; each run provides
            'time_training', the total training time (convertible to float;
            presumably seconds, given the s/epoch axis label — confirm).
        experiment_names: x position / tick label of each experiment, same order.
        ax: Matplotlib Axes to draw on.
        n_epochs: number of epochs each run was trained for; the total time is
            divided by it. Defaults to 400, the Epochs default of load_experiments.
    """
    records = [{'Experiment': label, 'Value': float(run['time_training']) / n_epochs}
               for exp, label in zip(expriment_lists, experiment_names)
               for run in exp]
    df = pd.DataFrame(records)

    # Mean and standard deviation of the per-epoch time over the runs of each experiment.
    summary = df.groupby('Experiment')['Value'].agg(['mean', 'std']).reset_index()
    print(summary)
    ax.plot(summary['Experiment'], summary['mean'], color='blue', alpha=0.4)
    ax.errorbar(summary['Experiment'], summary['mean'], yerr=summary['std'], fmt='o', markersize=10, lw=1, capsize=4, label='Error (std)', alpha=0.4, color='blue')

    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xticks(summary['Experiment'])
    ax.set_xticklabels(summary['Experiment'], fontsize=16)
    

In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

ax.set_ylabel(r"Computation time $[\frac{s}{epoch}]$", fontsize=20)
ax.set_xlabel("Number of Lyapunov exponent", fontsize=20)

experiment_labels = [1, 10, 25, 50, 75, 100]
plot_time(exp_nle,experiment_labels,ax)
# Fit the large axis labels inside the figure before saving, as the other figure cells do.
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_Nle_Effi.svg')
plt.show()

In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

ax.set_ylabel(r"Computation time $[\frac{s}{epoch}]$", fontsize=20)
ax.set_xlabel("Network size", fontsize=20)

experiment_labels = [25, 50, 100, 200, 400,500]
plot_time(exp_n,experiment_labels,ax)
# Fit the large axis labels inside the figure before saving, as the other figure cells do.
plt.tight_layout()
plt.savefig('../data/fig/FINAL/1_RD_Weights_N_Effi.svg')
plt.show()

## 6) Weight distribution

In [None]:
def get_model_path( subfolder:str,
                    activation:str,
                     weights=True,
                     shifts=False,
                     gains=False,
                     nSamples=0,
                     Nle=10,
                     Epochs=400,
                     lr=0.01 ,
                     loss='MSE',
                     target='0.00',
                     tons=0.2,
                     tsim=200,
                     dt=0.1,
                     N=100,
                     g=1.0):
    """Build the path of one saved model under ``<root>/data/models/<subfolder>/weigth``.

    The file name follows the same hyper-parameter naming scheme as the training
    logs and ends with ``_Sample{nSamples}``.

    Args:
        nSamples: index of the sample (not a count), used as the ``_Sample`` suffix.
        N: network size (previously fixed to 100; the default keeps that behaviour).
        g: value written into the ``_g`` token (previously fixed to 1.0).
        Remaining arguments are the hyper-parameters encoded in the file name.

    Returns:
        str: the joined path.
    """
    gen_path=os.path.join(utils.get_root(), 'data', 'models',subfolder,'weigth')
    model_path=f"{activation}_Weights{weights}_Shifts{shifts}_Gains{gains}_N{N}_lr{lr}_NLE{Nle}_Epochs{Epochs}_{loss}_{target}_g{g}_Tons{tons}_Tsim{tsim}_dt{dt}_Sample{nSamples}"
    return os.path.join(gen_path, model_path)

In [None]:

def plot_weight_dist(model, ax, serie_name='', color='red', alpha=0.3):
    """
    Plot a KDE of all entries of a model's ``W`` weight tensor on the given axes.

    The tensor is detached, flattened and moved to a CPU NumPy array first, so
    grad-tracking or GPU tensors can be handed to seaborn.

    Parameters:
        model: object exposing a torch tensor ``W`` (e.g. a trained RNN).
        ax: The Matplotlib Axes object to plot on.
        serie_name: Label for the legend.
        color: Color for the distribution plot.
        alpha: Transparency level for the plot.
    """
    array_weight = model.W.detach().reshape(-1).cpu().numpy()
    sns.kdeplot(array_weight, ax=ax, label=serie_name, color=color, alpha=alpha)
    ax.set_title('Synaptic coupling distribution')


In [None]:
# Load four trained samples (Sample0..Sample3) for each NLE value.
# The second argument 'pos' is passed through to model.RNN unchanged.
# NOTE(review): models use subfolder 'rd_RNN' while the logs use 'rd_rnn' — confirm the on-disk casing.
nle10_1, nle10_2, nle10_3, nle10_4 = [
    model.RNN(get_model_path(subfolder='rd_RNN', activation='tanh_positive', Nle=10, nSamples=sample_idx), 'pos')
    for sample_idx in range(4)
]

nle50_1, nle50_2, nle50_3, nle50_4 = [
    model.RNN(get_model_path(subfolder='rd_RNN', activation='tanh_positive', Nle=50, nSamples=sample_idx), 'pos')
    for sample_idx in range(4)
]

nle75_1, nle75_2, nle75_3, nle75_4 = [
    model.RNN(get_model_path(subfolder='rd_RNN', activation='tanh_positive', Nle=75, nSamples=sample_idx), 'pos')
    for sample_idx in range(4)
]

# Group the models by NLE with one colour per group.
model_groups = {
    'NLE=10': {'models': [nle10_1, nle10_2, nle10_3, nle10_4], 'color': 'red'},
    'NLE=50': {'models': [nle50_1, nle50_2, nle50_3, nle50_4], 'color': 'blue'},
    'NLE=75': {'models': [nle75_1, nle75_2, nle75_3, nle75_4], 'color': 'orange'},
}



In [None]:
# Create the figure
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
# Overlay the weight distribution of every trained sample, coloured by NLE group.
# The loop variable is `rnn`, not `model`, so it does not rebind the imported
# flybrain.model module that the loading cell needs when re-run.
for label, group in model_groups.items():
    for idx, rnn in enumerate(group['models']):
        # Label only the last model of each group so the legend has one entry per group.
        serie_name = label if idx == len(group['models']) - 1 else ''
        plot_weight_dist(rnn, ax, serie_name=serie_name, color=group['color'])

# Display the legend
ax.legend()
plt.tight_layout()
plt.show()