In [None]:
%load_ext autoreload
%autoreload 2

# Notebook that generates the plots for training random RNN models with a fixed number of parameters

In [None]:
import flybrain.utils as utils
import numpy as np 
import matplotlib.pyplot as plt
import torch
import os
import pandas as pd
import seaborn as sns

In [None]:
def build_run_name(activation, weights, shifts, gains, N, lr, Nle, Epochs,
                   loss, target, g, tons, tsim, suffix):
    """Return the file-name stem shared by all sample logs of one training configuration.

    The loss target only appears in the name for non-'Entropy' losses. `suffix` is the
    last field, which differs between experiment families (e.g. "Param0_100_0" for the
    fixed-parameter runs, "dt0.1" for the weight / activation runs).
    """
    loss_part = loss if loss == 'Entropy' else f"{loss}_{target}"
    return (f"{activation}_Weights{weights}_Shifts{shifts}_Gains{gains}_N{N}_lr{lr}"
            f"_NLE{Nle}_Epochs{Epochs}_{loss_part}_g{g}_Tons{tons}_Tsim{tsim}_{suffix}")


def load_sample_logs(log_dir, run_name, nSamples):
    """Load `<run_name>_Sample{i}_logs.json` from `log_dir` for i = 0 .. nSamples-1, in order."""
    return [utils.load_logs(file_path=os.path.join(log_dir, f"{run_name}_Sample{i}_logs.json"))
            for i in range(nSamples)]


def load_experiments(nSamples:int,
                     subfolder:str, 
                     activation:str, 
                     weights=True, 
                     shifts=False, 
                     gains=False,
                     N=100,
                     Nle=1,
                     Epochs=400,
                     lr=0.1 ,
                     loss='MSE',
                     target='0.00',
                     g=1.0,
                     tons=0.2,
                     tsim=200,
                     dt=0.1, 
                     param="0_100_0"):
    """Load the logs of a fixed-parameter-budget run family from `data/logs/<subfolder>`.

    File names end in `_Param{param}` (e.g. "100_0_0"). Returns a list with one loaded
    log per sample, in sample order.
    NOTE(review): `dt` is accepted for signature parity with the other loaders but is not
    part of this family's file names.
    """
    gen_path = os.path.join(utils.get_root(), 'data', 'logs', subfolder)
    run_name = build_run_name(activation, weights, shifts, gains, N, lr, Nle, Epochs,
                              loss, target, g, tons, tsim, f"Param{param}")
    return load_sample_logs(gen_path, run_name, nSamples)


def load_experiments_weights(nSamples:int,
                     subfolder:str, 
                     activation:str, 
                     weights=True, 
                     shifts=False, 
                     gains=False,
                     N=100,
                     Nle=1,
                     Epochs=400,
                     lr=0.01 ,
                     loss='MSE',
                     target='0.00',
                     g=1.0,
                     tons=0.2,
                     tsim=200,
                     dt=0.1):
    """Load weight-training run logs from `data/logs/<subfolder>/weigth` (one entry per sample)."""
    # NOTE(review): 'weigth' looks like a typo but is presumably the real directory name on
    # disk; renaming it here would break loading, so it is kept verbatim.
    gen_path = os.path.join(utils.get_root(), 'data', 'logs', subfolder, 'weigth')
    run_name = build_run_name(activation, weights, shifts, gains, N, lr, Nle, Epochs,
                              loss, target, g, tons, tsim, f"dt{dt}")
    return load_sample_logs(gen_path, run_name, nSamples)


def load_experiments_activation(nSamples:int,
                     subfolder:str, 
                     activation:str, 
                     weights=False, 
                     shifts=True, 
                     gains=True,
                     N=100,
                     Nle=1,
                     Epochs=400,
                     lr=0.1 ,
                     loss='MSE',
                     target='0.00',
                     g=1.0,
                     tons=0.2,
                     tsim=200,
                     dt=0.1):
    """Load gain/shift-training run logs from `data/logs/<subfolder>/activation` (one entry per sample)."""
    gen_path = os.path.join(utils.get_root(), 'data', 'logs', subfolder, 'activation')
    run_name = build_run_name(activation, weights, shifts, gains, N, lr, Nle, Epochs,
                              loss, target, g, tons, tsim, f"dt{dt}")
    return load_sample_logs(gen_path, run_name, nSamples)

## 1) Load the runs

In [None]:
# Numbers of Lyapunov exponents (Nle) every experiment family was trained with.
NLE_VALUES = [1, 10, 25, 50, 75]


def _load_sweep(loader, **kwargs):
    """Return `[loader(Nle=nle, **kwargs) for nle in NLE_VALUES]` (one sample list per Nle)."""
    return [loader(Nle=nle, **kwargs) for nle in NLE_VALUES]


# Weights only, "N2" budget: logs under rd_rnn/weigth, 5 samples per Nle.
tanhpos_n100_weights_N2 = _load_sweep(load_experiments_weights, nSamples=5,
                                      subfolder='rd_rnn', activation='tanh_positive')

# Weights only, N parameters (Param 100_0_0), 3 samples per Nle.
tanhpos_n100_weights_N = _load_sweep(load_experiments, nSamples=3, subfolder='rd_RNN_fixed_param',
                                     activation='tanh_positive', weights=True, gains=False,
                                     shifts=False, param="100_0_0")

# Gains and shifts, 2N parameters: logs under rd_rnn/activation, 5 samples per Nle.
tanhpos_n100_gains_shifts_2N = _load_sweep(load_experiments_activation, nSamples=5,
                                           subfolder='rd_rnn', activation='tanh_positive')

# Gains only, N parameters (Param 0_100_0).
tanhpos_n100_gains_N = _load_sweep(load_experiments, nSamples=3, subfolder='rd_RNN_fixed_param',
                                   activation='tanh_positive', weights=False, gains=True,
                                   shifts=False, param="0_100_0")

# Shifts only, N parameters (Param 0_0_100).
tanhpos_n100_shifts_N = _load_sweep(load_experiments, nSamples=3, subfolder='rd_RNN_fixed_param',
                                    activation='tanh_positive', weights=False, gains=False,
                                    shifts=True, param="0_0_100")

# Weights only, 2N parameters (Param 200_0_0).
tanhpos_n100_weights_2N = _load_sweep(load_experiments, nSamples=3, subfolder='rd_RNN_fixed_param',
                                      activation='tanh_positive', weights=True, gains=False,
                                      shifts=False, param="200_0_0")

# Per-Nle aliases (same objects as the list entries), kept so references to individual runs still work.
(tanhpos_n100_nle1_weights_N2, tanhpos_n100_nle10_weights_N2, tanhpos_n100_nle25_weights_N2,
 tanhpos_n100_nle50_weights_N2, tanhpos_n100_nle75_weights_N2) = tanhpos_n100_weights_N2
(tanhpos_n100_nle1_weights_N, tanhpos_n100_nle10_weights_N, tanhpos_n100_nle25_weights_N,
 tanhpos_n100_nle50_weights_N, tanhpos_n100_nle75_weights_N) = tanhpos_n100_weights_N
(tanhpos_n100_nle1_gains_shifts_2N, tanhpos_n100_nle10_gains_shifts_2N, tanhpos_n100_nle25_gains_shifts_2N,
 tanhpos_n100_nle50_gains_shifts_2N, tanhpos_n100_nle75_gains_shifts_2N) = tanhpos_n100_gains_shifts_2N
(tanhpos_n100_nle1_gains_N, tanhpos_n100_nle10_gains_N, tanhpos_n100_nle25_gains_N,
 tanhpos_n100_nle50_gains_N, tanhpos_n100_nle75_gains_N) = tanhpos_n100_gains_N
(tanhpos_n100_nle1_shifts_N, tanhpos_n100_nle10_shifts_N, tanhpos_n100_nle25_shifts_N,
 tanhpos_n100_nle50_shifts_N, tanhpos_n100_nle75_shifts_N) = tanhpos_n100_shifts_N
(tanhpos_n100_nle1_weights_2N, tanhpos_n100_nle10_weights_2N, tanhpos_n100_nle25_weights_2N,
 tanhpos_n100_nle50_weights_2N, tanhpos_n100_nle75_weights_2N) = tanhpos_n100_weights_2N



## 2) Plot the training logs 

In [None]:
def _plot_runs_with_mean(list_of_logs, entry_used, axs, serie, color, x_axis, show_std):
    """Draw each run's `entry_used` curve faintly, then the across-run mean in bold.

    `x_axis(n)` returns the n x-coordinates for a curve of length n. All runs must have
    the same length as the first one.
    """
    runs = np.zeros((len(list_of_logs), len(list_of_logs[0][entry_used])))
    for i, sample in enumerate(list_of_logs):
        run = np.array(sample[entry_used])
        axs.plot(x_axis(len(run)), run, alpha=0.2, color=color, lw=1)
        # Raises if this run's length differs from the first run's.
        runs[i, :] = run
    # The final entry of every run is excluded from the summary curve.
    # NOTE(review): the reason for dropping it is not visible here — confirm.
    kept = runs[:, :-1]
    mean = np.nanmean(kept, axis=0)
    x = x_axis(len(mean))
    axs.plot(x, mean, label=f'{serie}', alpha=0.9, lw=2, color=color)
    if show_std:
        std = np.nanstd(kept, axis=0)
        axs.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
    axs.legend()


def plot_logs_trainings(list_of_logs, 
                        entry_used, 
                        axs,
                        serie, 
                        color,
                        show_std=False):
    """Plot `entry_used` of every run against the epoch index, plus the NaN-aware mean.

    Parameters:
        list_of_logs: list of log dicts, each mapping `entry_used` to a sequence of values.
        entry_used: key of the logged series to plot.
        axs: matplotlib Axes to draw on.
        serie: legend label of the mean curve.
        color: colour used for every curve.
        show_std: also shade mean ± std (off by default).
    """
    _plot_runs_with_mean(list_of_logs, entry_used, axs, serie, color, np.arange, show_std)


def plot_logs_spectrums(list_of_logs, 
                        entry_used, 
                        axs,
                        serie, 
                        color,
                        show_std=False):
    """Same as `plot_logs_trainings`, but each curve is spread evenly over x in [0, 100]."""
    _plot_runs_with_mean(list_of_logs, entry_used, axs, serie, color,
                         lambda n: np.linspace(0, 100, n), show_std)

## 3) Plot the convergence as a function of the number of Lyapunov exponents used

In [None]:
def _tail_mean(list_of_logs, key, window_size):
    """Return, per run, the mean of the last `window_size` values of `run[key]`.

    Returns an empty array for an empty `list_of_logs`. Raises ValueError if
    `window_size < 1` or if the runs' histories have different lengths.
    """
    if window_size < 1:
        # The slice `[:, -0:]` would select the whole history instead of an empty window.
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if not list_of_logs:
        return np.empty(0)
    history = np.stack([np.asarray(run[key], dtype=float) for run in list_of_logs])
    return np.mean(history[:, -window_size:], axis=1)


def compute_conv_grad_based(list_of_logs, window_size=50):
    """Average the last `window_size` logged gradient values of every run.

    Reads the 'grad_weights', 'grad_gains' and 'grad_shifts' series of each log
    (presumably gradient norms — confirm against the training logger).

    Returns:
        dict with keys 'weights', 'gains', 'shifts', each a 1-D array with one value per run.
    """
    return {group: _tail_mean(list_of_logs, f'grad_{group}', window_size)
            for group in ('weights', 'gains', 'shifts')}


def compute_conv_loss_based(list_of_logs, window_size=50):
    """Average the last `window_size` 'training_loss' values of every run.

    Returns:
        {'loss': 1-D array with one value per run}.
    """
    return {'loss': _tail_mean(list_of_logs, 'training_loss', window_size)}

In [None]:
def _plot_experiment_means(records, ax, lab, **line_style):
    """Plot the per-label mean of `records` (dicts with 'Experiment' and 'Value') and style `ax`.

    `groupby` sorts the labels, so the x axis is ordered by experiment label.
    """
    df = pd.DataFrame(records, columns=['Experiment', 'Value'])
    means = df.groupby('Experiment')['Value'].mean()
    ax.plot(means.index, means.values, label=lab, **line_style)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xticks(means.index)
    ax.set_xticklabels(means.index, fontsize=12)


def plot_conv_grad_based(expriment_lists, experiment_names, ax, key_list=('weights',), lab='', window_size=50):
    """Plot the mean end-of-training gradient value per experiment.

    Parameters:
        expriment_lists: one list of run logs per experiment (name kept for keyword callers).
        experiment_names: x-axis label of each experiment; zip() pairs it with `expriment_lists`.
        ax: matplotlib Axes to draw on.
        key_list: parameter groups ('weights', 'gains', 'shifts') whose per-run values are summed.
        lab: legend label.
        window_size: number of final epochs averaged per run.

    Raises:
        ValueError: if `key_list` is empty.
    """
    if not key_list:
        raise ValueError("key_list must name at least one of 'weights', 'gains', 'shifts'")
    records = []
    for runs, label in zip(expriment_lists, experiment_names):
        conv = compute_conv_grad_based(runs, window_size=window_size)
        # Element-wise sum over the requested groups: still one value per run.
        combined = sum(conv[key] for key in key_list)
        records.extend({'Experiment': label, 'Value': value} for value in combined)
    _plot_experiment_means(records, ax, lab, marker='o', lw=2, alpha=0.8)


def plot_conv_loss_based(expriment_lists, experiment_names, ax, lab='', color=None, ls='-', marker='o', window_size=50):
    """Plot the mean end-of-training loss per experiment.

    Parameters:
        expriment_lists: one list of run logs per experiment (name kept for keyword callers).
        experiment_names: x-axis label of each experiment; zip() pairs it with `expriment_lists`.
        ax: matplotlib Axes to draw on.
        lab: legend label.
        color: line colour; None (default) uses the axis colour cycle.
        ls, marker: line style and marker.
        window_size: number of final epochs averaged per run.
    """
    records = [
        {'Experiment': label, 'Value': value}
        for runs, label in zip(expriment_lists, experiment_names)
        for value in compute_conv_loss_based(runs, window_size=window_size)['loss']
    ]
    style = dict(marker=marker, markersize=10, lw=2, ls=ls, alpha=0.8)
    if color is not None:
        style['color'] = color
    _plot_experiment_means(records, ax, lab, **style)


In [None]:
# Width/height ratio 20/7 (~2.857), the factor used for the loss-figure width below.
a, b = 7, 20
b / a

In [None]:
a=5
fig, ax=plt.subplots(1,1, figsize=(int(a*2.857),a),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

# One x label per loaded Nle value; each experiment list has exactly one entry per label.
experiment_labels = [1, 10, 25, 50, 75]
loss_series = [
    (tanhpos_n100_weights_N, r'$N$ parameters, Weights', 'D'),
    (tanhpos_n100_gains_N, r'$N$ parameters,Gains', 'D'),
    (tanhpos_n100_weights_2N, r'$2N$ parameters, Weights', 'P'),
    (tanhpos_n100_gains_shifts_2N, r'$2N$ parameters, Gains and shifts', 'P'),
    (tanhpos_n100_weights_N2, "$N(N-1)$ parameters,Weights ", '^'),
]
# tanhpos_n100_shifts_N (shifts only, N parameters) is loaded but not plotted here.
for runs, series_label, series_marker in loss_series:
    # zip() in the plot helper would silently drop any unmatched label or run list.
    assert len(runs) == len(experiment_labels), series_label
    plot_conv_loss_based(runs, experiment_labels, ax, lab=series_label, marker=series_marker)

ax.set_ylabel(r"$\frac{1}{N} \sum_{T-50}^{T}L_{\theta}^t$", fontsize=24)
ax.set_xlabel("Number of Lyapunov exponent", fontsize=24)
ax.tick_params(axis='both', which='major', labelsize=22)

ax.legend(fontsize=22, frameon=False)
plt.tight_layout()
fig_path = '../data/fig/FINAL/3_RD_Fixed_convergence_loss_based.svg'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path)
plt.show()

In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

# One x label per loaded Nle value; each experiment list has exactly one entry per label.
experiment_labels = [1, 10, 25, 50, 75]
grad_series = [
    (tanhpos_n100_weights_N2, ['weights'], "Weights, $N^2$ parameters"),
    (tanhpos_n100_gains_shifts_2N, ['gains', 'shifts'], r'Gains and shifts, $2N$ parameters'),
    (tanhpos_n100_weights_2N, ['weights'], r'Weights, $2N$ parameters'),
    (tanhpos_n100_weights_N, ['weights'], r'Weights, $N$ parameters'),
]
# The gains-only and shifts-only N-parameter families are loaded but not plotted here.
for runs, keys, series_label in grad_series:
    # zip() in the plot helper would silently drop any unmatched label or run list.
    assert len(runs) == len(experiment_labels), series_label
    plot_conv_grad_based(runs, experiment_labels, ax, key_list=keys, lab=series_label)

ax.set_ylabel(r"$|\nabla L_{\theta}|$", fontsize=20)
ax.set_xlabel("Number of lyapunov exponent", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.legend(fontsize=10, frameon=False)
plt.tight_layout()
fig_path = '../data/fig/FINAL/3_RD_Fixed_convergence_grad_based.svg'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path)
plt.show()

## 4) Plot the computation time per epoch

In [None]:
def plot_time(expriment_lists, experiment_names, ax, lab='', n_epochs=400):
    """Plot the mean training time per epoch for each experiment.

    Each run's `time_training` (total time, converted with float()) is divided by
    `n_epochs`. Runs sharing a label are averaged, and labels are sorted on the x axis.

    Parameters:
        expriment_lists: one list of run logs per experiment (name kept for keyword callers).
        experiment_names: x-axis label of each experiment; zip() pairs it with `expriment_lists`.
        ax: matplotlib Axes to draw on.
        lab: legend label.
        n_epochs: number of training epochs per run (previously hard-coded as 400).

    Raises:
        ValueError: if `n_epochs` is not positive.
    """
    if n_epochs <= 0:
        raise ValueError(f"n_epochs must be positive, got {n_epochs}")
    records = [
        {'Experiment': label, 'Value': float(run['time_training']) / n_epochs}
        for runs, label in zip(expriment_lists, experiment_names)
        for run in runs
    ]
    df = pd.DataFrame(records, columns=['Experiment', 'Value'])
    means = df.groupby('Experiment')['Value'].mean()

    ax.plot(means.index, means.values, marker='o', alpha=0.7, label=lab)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xticks(means.index)
    ax.set_xticklabels(means.index, fontsize=12)
    

In [None]:
fig, ax=plt.subplots(1,1, figsize=(7,5),sharex=True)
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

ax.set_ylabel(r"Computation time $[\frac{s}{epoch}]$", fontsize=12)
ax.set_xlabel("Number of lyapunov exponent", fontsize=12)

# One x label per loaded Nle value; each experiment list has exactly one entry per label.
experiment_labels = [1, 10, 25, 50, 75]
time_series = [
    (tanhpos_n100_weights_N2, r'Weight, $N^2$'),
    (tanhpos_n100_gains_shifts_2N, r'Gains-Shifts, $2N$'),
    (tanhpos_n100_gains_N, r'Gains, $N$'),
]
for runs, series_label in time_series:
    # zip() in the plot helper would silently drop any unmatched label or run list.
    assert len(runs) == len(experiment_labels), series_label
    plot_time(runs, experiment_labels, ax, lab=series_label)
ax.legend(fontsize=10, frameon=False)
fig_path = '../data/fig/FINAL/3_RD_Fixed_Effi.svg'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path)
plt.show()