In [None]:
# Re-import edited flybrain modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

# Notebook that generates the training plots for the flybrain RNN models trained through their non-linearity (shifts and gains)

In [None]:
import flybrain.utils as utils
import flybrain.model 
import numpy as np 
import matplotlib.pyplot as plt
import torch
import os
import pandas as pd
import seaborn as sns

# Six evenly spaced colours from the sequential 'magma' map, used for the
# raw-spectrum / NLE-sweep curves.
cmap = plt.get_cmap('magma')
colors = list(map(cmap, (k / 5 for k in range(6))))

# Six colours sampled the same way from the qualitative 'tab10' map; `cmap`
# stays bound to tab10 afterwards.
cmap = plt.get_cmap('tab10')
colors_2 = list(map(cmap, (k / 5 for k in range(6))))

In [None]:
def load_experiments(subfolder='flybrain_RNN',
                     activation='tanh_positive',
                     ROI='EB',
                     weights=False,
                     shifts=True,
                     gains=True,
                     nSamples=1,
                     Nle=1,
                     Epochs=400,
                     lr=0.1,
                     loss='MSE',
                     target='0.00',
                     tons=0.2,
                     tsim=200,
                     dt=0.1):
    """Load the JSON training logs of every sample of one experiment configuration.

    The log file name encodes all hyper-parameters, so each argument must match
    the value used when the run was trained. Files are read from
    ``<root>/data/logs/<subfolder>/<prefix>_Sample<i>_logs.json``.

    Returns:
        list: one ``utils.load_logs`` result per sample, for samples
        ``0 .. nSamples - 1`` in order.
    """
    log_dir = os.path.join(utils.get_root(), 'data', 'logs', subfolder)
    run_prefix = (
        f"{activation}_ROI_{ROI}_Weights{weights}_Shifts{shifts}_Gains{gains}"
        f"_lr{lr}_NLE{Nle}_Epochs{Epochs}_{loss}_{target}"
        f"_Tons{tons}_Tsim{tsim}_dt{dt}"
    )
    return [
        utils.load_logs(file_path=os.path.join(log_dir, f"{run_prefix}_Sample{sample}_logs.json"))
        for sample in range(nSamples)
    ]

## 1) Load the runs

In [None]:
# Nle (number of Lyapunov exponents in the training loss) for each ROI. The four
# values are the runs labelled 5%, 10%, 20% and 50% in the figures below
# (presumably that fraction of the ROI's neuron count — TODO confirm).
NLE_SWEEP = {
    'AB': (3, 6, 12, 29),
    'EB': (22, 44, 88, 218),
    'PB': (19, 38, 76, 177),
    'ATL': (27, 53, 106, 264),
    'NO': (25, 50, 100, 259),
}
# Alternative spectrum targets. These runs exist only at the 10% setting and
# are not loaded for AB.
TARGET_VARIANTS = {'exploding': '2.00', 'vanishing': '-2.00', 'custom': '1.25_-1.25'}


def load_nle_sweep(roi):
    """Return the default-target runs of `roi`, one log list per Nle in NLE_SWEEP[roi]."""
    return [load_experiments(nSamples=1, Nle=nle, ROI=roi) for nle in NLE_SWEEP[roi]]


def load_target_run(roi, variant):
    """Return the 10%-Nle run of `roi` trained towards TARGET_VARIANTS[variant]."""
    return load_experiments(nSamples=1, Nle=NLE_SWEEP[roi][1], ROI=roi,
                            target=TARGET_VARIANTS[variant])


AB = load_nle_sweep('AB')
AB_nle5, AB_nle10, AB_nle20, AB_nle50 = AB

EB = load_nle_sweep('EB')
EB_nle5, EB_nle10, EB_nle20, EB_nle50 = EB
EB_nle10_exploding = load_target_run('EB', 'exploding')
EB_nle10_vanishing = load_target_run('EB', 'vanishing')
EB_nle10_custom = load_target_run('EB', 'custom')

PB = load_nle_sweep('PB')
PB_nle5, PB_nle10, PB_nle20, PB_nle50 = PB
PB_nle10_exploding = load_target_run('PB', 'exploding')
PB_nle10_vanishing = load_target_run('PB', 'vanishing')
PB_nle10_custom = load_target_run('PB', 'custom')

ATL = load_nle_sweep('ATL')
ATL_nle5, ATL_nle10, ATL_nle20, ATL_nle50 = ATL
ATL_nle10_exploding = load_target_run('ATL', 'exploding')
ATL_nle10_vanishing = load_target_run('ATL', 'vanishing')
ATL_nle10_custom = load_target_run('ATL', 'custom')

NO = load_nle_sweep('NO')
NO_nle5, NO_nle10, NO_nle20, NO_nle50 = NO
NO_nle10_exploding = load_target_run('NO', 'exploding')
NO_nle10_vanishing = load_target_run('NO', 'vanishing')
NO_nle10_custom = load_target_run('NO', 'custom')

## 2) Plot the spectrum logs 

In [None]:
import flybrain.model as model
import flybrain.functional as functional
import flybrain.lyapunov as lyapu
import flybrain.connectome as connectome

def get_raw_spectrum_noramlized(roi='roi', seed=None, dt=0.1, tsim=200, tons=0.2):
    """Compute the Lyapunov spectrum of the untrained (raw) RNN of one ROI.

    The connectome of `roi` is loaded and normalised. A ``tanh_positive`` RNN
    is then built from it with a N(0, 1) random initial condition, and
    ``N - 1`` exponents of its spectrum are computed (``N = W.shape[0]``).

    Args:
        roi: ROI name passed to ``connectome.load_flybrain`` (e.g. 'EB', 'AB').
        seed: optional seed for the initial condition. ``None`` keeps drawing
            from NumPy's global RNG, so repeated calls can return slightly
            different spectra.
        dt, tsim, tons: forwarded as ``dt``, ``tSim`` and ``tONS`` to
            ``Lyapunov.compute_spectrum``. The defaults equal the ones used by
            ``load_experiments`` for the trained runs.

    Returns:
        list: ``[{'spectrum': ...}]``. This single-entry list has the same
        layout as a list of training logs, so it can be passed directly to the
        plotting helpers.
    """
    roi_data = connectome.load_flybrain(ROI=roi, types="Neurotransmitter")
    W, C = connectome.normalize_connectome(
        W=roi_data["weights"], C=roi_data["connectivity"]
    )
    # Without a seed, fall back to the global RNG (previous behaviour).
    rng = np.random if seed is None else np.random.default_rng(seed)
    c0 = rng.normal(0.0, 1.0, W.shape[0])
    rnn = model.RNN(
        connectivity_matrix=C,
        weights_matrix=W,
        initial_condition=c0,
        activation_function=functional.tanh_positive(),
    )
    spectrum = lyapu.Lyapunov().compute_spectrum(
        model=rnn, dt=dt, tSim=tsim, nLE=W.shape[0] - 1, tONS=tons
    )
    return [{'spectrum': spectrum}]

def plot_logs_spectrums(list_of_logs,
                        entry_used,
                        axs,
                        serie,
                        color,
                        alpha=0.8):
    """Plot the sample-averaged series stored under `entry_used` in each log.

    The index axis is rescaled to [0, 100], so spectra of ROIs with different
    sizes share one x-axis.

    Args:
        list_of_logs: sequence of log dicts. Every ``log[entry_used]`` must have
            the same length.
        entry_used: key of the 1-D series to plot (e.g. 'spectrum').
        axs: matplotlib Axes to draw on. No legend is drawn here.
        serie: label of the curve.
        color: line colour, or ``None`` to use the Axes' colour cycle.
        alpha: line opacity.
    """
    data = np.array([np.asarray(log[entry_used]) for log in list_of_logs], dtype=float)
    # NOTE(review): the last entry of every series is dropped (same as in
    # plot_logs_trainings) — confirm this is intended.
    mean = np.nanmean(data[:, :-1], axis=0)
    index_pct = np.linspace(0, 100, len(mean))
    # Only pass `color` when given, so None falls back to the colour cycle.
    style = {} if color is None else {'color': color}
    axs.plot(index_pct, mean, label=f'{serie}', alpha=alpha, lw=2, **style)

def plot_logs_trainings(list_of_logs,
                        entry_used,
                        axs,
                        serie,
                        color):
    """Plot every sample's curve faintly, then their mean in bold, and draw the legend.

    Args:
        list_of_logs: sequence of log dicts. Every ``log[entry_used]`` must have
            the same length.
        entry_used: key of the 1-D series to plot (e.g. 'training_loss').
        axs: matplotlib Axes to draw on.
        serie: legend label of the mean curve.
        color: colour shared by the individual and mean curves.
    """
    data = np.array([np.asarray(log[entry_used]) for log in list_of_logs], dtype=float)
    for run in data:
        axs.plot(np.arange(len(run)), run, alpha=0.2, color=color, lw=1)

    # NOTE(review): the last entry is dropped from the mean (same as in
    # plot_logs_spectrums) — confirm this is intended.
    mean = np.nanmean(data[:, :-1], axis=0)
    axs.plot(np.arange(len(mean)), mean, label=f'{serie}', alpha=0.9, lw=2, color=color)
    axs.legend()


### 2.1) Plot the spectrum of each individual ROI

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

plot_logs_spectrums(get_raw_spectrum_noramlized('EB'),'spectrum', axs, 'Raw', color=colors[0])
plot_logs_spectrums(EB_nle5,'spectrum', axs, 'NLE:5%', color=colors[1])
plot_logs_spectrums(EB_nle10,'spectrum', axs, 'NLE:10%', color=colors[2])
plot_logs_spectrums(EB_nle20,'spectrum', axs, 'NLE:20%', color=colors[3])
plot_logs_spectrums(EB_nle50,'spectrum', axs, 'NLE:50%', color=colors[4])

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=14)
axs.set_xlabel(r"$i$",fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=12)
axs.legend(loc='lower left',fontsize=10, frameon=False)
axs.set_ylim(-4,1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_EB_spectrum.svg')
plt.show()


In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

plot_logs_spectrums(get_raw_spectrum_noramlized('AB'),'spectrum', axs, 'Raw', color=colors[0])
plot_logs_spectrums(AB_nle5,'spectrum', axs, 'NLE:5%', color=colors[1])
plot_logs_spectrums(AB_nle10,'spectrum', axs, 'NLE:10%', color=colors[2])
plot_logs_spectrums(AB_nle20,'spectrum', axs, 'NLE:20%', color=colors[3])
plot_logs_spectrums(AB_nle50,'spectrum', axs, 'NLE:50%', color=colors[4])

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=14)
axs.set_xlabel(r"$i$",fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=12)
axs.legend(loc='lower left',fontsize=10, frameon=False)
axs.set_ylim(-4,1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_AB_spectrum.svg')
plt.show()

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

plot_logs_spectrums(get_raw_spectrum_noramlized('PB'),'spectrum', axs, 'Raw', color=colors[0])
plot_logs_spectrums(PB_nle5,'spectrum', axs, 'NLE:5%', color=colors[1])
plot_logs_spectrums(PB_nle10,'spectrum', axs, 'NLE:10%', color=colors[2])
plot_logs_spectrums(PB_nle20,'spectrum', axs, 'NLE:20%', color=colors[3])
plot_logs_spectrums(PB_nle50,'spectrum', axs, 'NLE:50%', color=colors[4])

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=14)
axs.set_xlabel(r"$i$",fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=12)
axs.legend(loc='lower left',fontsize=10, frameon=False)
axs.set_ylim(-4,1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_PB_spectrum.svg')
plt.show()

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")


plot_logs_spectrums(get_raw_spectrum_noramlized('NO'),'spectrum', axs, 'Raw', color=colors[0])
plot_logs_spectrums(NO_nle5,'spectrum', axs, 'NLE:5%', color=colors[1])
plot_logs_spectrums(NO_nle10,'spectrum', axs, 'NLE:10%', color=colors[2])
plot_logs_spectrums(NO_nle20,'spectrum', axs, 'NLE:20%', color=colors[3])
plot_logs_spectrums(NO_nle50,'spectrum', axs, 'NLE:50%', color=colors[4])

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=14)
axs.set_xlabel(r"$i$",fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=12)
axs.legend(loc='lower left',fontsize=10, frameon=False)
axs.set_ylim(-4,1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_NO_spectrum.svg')
plt.show()

In [None]:
fig, axs=plt.subplots(1,1, figsize=(7,5),sharex=True)
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

plot_logs_spectrums(get_raw_spectrum_noramlized('ATL'),'spectrum', axs, 'Raw', color=colors[0])
plot_logs_spectrums(ATL_nle5,'spectrum', axs, 'NLE:5%', color=colors[1])
plot_logs_spectrums(ATL_nle10,'spectrum', axs, 'NLE:10%', color=colors[2])
plot_logs_spectrums(ATL_nle20,'spectrum', axs, 'NLE:20%', color=colors[3])
plot_logs_spectrums(ATL_nle50,'spectrum', axs, 'NLE:50%', color=colors[4])

axs.set_xlim([0,100])
axs.hlines(y=0,xmin=0, xmax=100, ls='--', color='black' ,lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$",fontsize=14)
axs.set_xlabel(r"$i$",fontsize=14)
axs.tick_params(axis='both', which='major', labelsize=12)
axs.legend(loc='lower left',fontsize=10, frameon=False)
axs.set_ylim(-4,1)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_ATL_spectrum.svg')
plt.show()

### 2.2) Compare the raw spectrum

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for side in ("right", "top"):
    axs.spines[side].set_color("none")

# Untrained spectrum of every ROI on one set of axes, using the default colour cycle.
for roi in ('EB', 'AB', 'PB', 'NO', 'ATL'):
    plot_logs_spectrums(get_raw_spectrum_noramlized(roi), 'spectrum', axs, roi, color=None)

axs.set_xlim([0, 100])

axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$", fontsize=20)
axs.set_xlabel(r"$i$", fontsize=20)
axs.tick_params(axis='both', which='major', labelsize=18)
axs.legend(loc='upper right', fontsize=20, frameon=False)
axs.set_ylim(-1.5, -.4)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_Raw_spectrum.svg')
plt.show()




In [None]:
def plot_kde_spect(list_of_logs, entry_used, axs, color, serie='', alpha=0.3):
    """Plot a filled KDE of the pooled, shifted series and return its variance.

    The series of all samples are pooled, with the last entry of each dropped
    as in the other plotting helpers. They are then shifted by +1, so a value
    of -1 is drawn at 0.

    Args:
        list_of_logs: sequence of log dicts with equal-length ``log[entry_used]``.
        entry_used: key of the series (e.g. 'spectrum').
        axs: matplotlib Axes to draw on.
        color: fill colour, or ``None`` for seaborn's default.
        serie: legend label. When empty, no legend is drawn.
        alpha: fill opacity.

    Returns:
        The variance of the pooled values. The +1 shift does not change it.
    """
    data = np.array([np.asarray(log[entry_used]) for log in list_of_logs], dtype=float)
    # Shift by +1 so that exponents equal to -1 sit at the origin of the x-axis.
    shifted = (data[:, :-1] + 1).flatten()

    kde_style = {} if color is None else {'color': color}
    sns.kdeplot(shifted, ax=axs, label=serie, alpha=alpha, fill=True, linewidth=0, **kde_style)

    axs.set_xlabel(r"$\lambda_{i} + 1$")
    axs.set_ylabel("Density")
    # With an empty label matplotlib has no labelled artist and warns, so only
    # draw the legend when a label was given.
    if serie:
        axs.legend()
    return np.var(shifted)
    

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5))
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# Distribution of each ROI's raw exponents. The returned variances are reused
# in the convergence-vs-variance figure. Each call recomputes the raw spectrum
# from a fresh random initial condition. `serie` labels each density so the
# ROIs can be distinguished.
var_EB = plot_kde_spect(get_raw_spectrum_noramlized('EB'), 'spectrum', axs, color=None, serie='EB')
var_AB = plot_kde_spect(get_raw_spectrum_noramlized('AB'), 'spectrum', axs, color=None, serie='AB')
var_PB = plot_kde_spect(get_raw_spectrum_noramlized('PB'), 'spectrum', axs, color=None, serie='PB')
var_NO = plot_kde_spect(get_raw_spectrum_noramlized('NO'), 'spectrum', axs, color=None, serie='NO')
var_ATL = plot_kde_spect(get_raw_spectrum_noramlized('ATL'), 'spectrum', axs, color=None, serie='ATL')

### 2.3) Compare runs trained towards the vanishing target (target = -2.00)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5))
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# 10%-Nle runs trained towards target='-2.00' (AB was not loaded for this target).
vanishing_runs = (('EB', EB_nle10_vanishing), ('PB', PB_nle10_vanishing),
                  ('NO', NO_nle10_vanishing), ('ATL', ATL_nle10_vanishing))
for roi, runs in vanishing_runs:
    plot_logs_spectrums(runs, 'spectrum', axs, roi, color=None)

axs.set_ylim([-3, 1])
axs.set_xlabel(r"$i$")
axs.set_ylabel(r"$\lambda_{i}$")
# plot_logs_spectrums only labels the curves, so draw the legend here.
axs.legend(frameon=False)

### 2.4) Compare runs trained towards the exploding target (target = 2.00)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5))
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# 10%-Nle runs trained towards target='2.00' (AB was not loaded for this target).
exploding_runs = (('EB', EB_nle10_exploding), ('PB', PB_nle10_exploding),
                  ('NO', NO_nle10_exploding), ('ATL', ATL_nle10_exploding))
for roi, runs in exploding_runs:
    plot_logs_spectrums(runs, 'spectrum', axs, roi, color=None)

axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_xlabel(r"$i$")
axs.set_ylabel(r"$\lambda_{i}$")
# plot_logs_spectrums only labels the curves, so draw the legend here.
axs.legend(frameon=False)

### 2.5) Compare runs trained towards the custom target (target = 1.25_-1.25)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5))
axs.spines["right"].set_color("none")
axs.spines["top"].set_color("none")

# 10%-Nle runs trained towards target='1.25_-1.25' (AB was not loaded for this target).
custom_runs = (('EB', EB_nle10_custom), ('PB', PB_nle10_custom),
               ('NO', NO_nle10_custom), ('ATL', ATL_nle10_custom))
for roi, runs in custom_runs:
    plot_logs_spectrums(runs, 'spectrum', axs, roi, color=None)

axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_xlabel(r"$i$")
axs.set_ylabel(r"$\lambda_{i}$")
# plot_logs_spectrums only labels the curves, so draw the legend here.
axs.legend(frameon=False)

## 3) Comparison across ROIs at a fixed number of Lyapunov exponents (20%)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 5), sharex=True)
for side in ("right", "top"):
    axs.spines[side].set_color("none")

# Trained spectra of every ROI at the 20% Nle setting, using the default colour cycle.
for roi, runs in (('EB', EB_nle20), ('AB', AB_nle20), ('PB', PB_nle20),
                  ('NO', NO_nle20), ('ATL', ATL_nle20)):
    plot_logs_spectrums(runs, 'spectrum', axs, f'{roi}:20%', color=None, alpha=0.9)

axs.set_xlim([0, 50])
axs.set_ylim(-1, 0.2)
axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
axs.set_ylabel(r"$\lambda_{i}$", fontsize=20)
axs.set_xlabel(r"$i$", fontsize=20)
axs.tick_params(axis='both', which='major', labelsize=16)
plt.tight_layout()
# This figure is saved without a legend; to add one:
# axs.legend(loc='upper right', fontsize=20, frameon=False)
plt.savefig('../data/fig/FINAL/4_Flybrain_Trained_spectrum.svg')


In [None]:
def _tail_mean(list_of_logs, key, window_size):
    """Per-sample mean of the last `window_size` entries of ``log[key]``."""
    n_steps = len(list_of_logs[0][key])
    table = np.zeros((len(list_of_logs), n_steps))
    for row, log in zip(table, list_of_logs):
        row[:] = np.array(log[key])
    return np.mean(table[:, -window_size:], axis=1)


def compute_conv_loss_based(list_of_logs, window_size=50):
    """Average training loss over the final `window_size` entries, one value per sample.

    Returns:
        dict: ``{'loss': array of shape (n_samples,)}``.
    """
    return {'loss': _tail_mean(list_of_logs, 'training_loss', window_size)}


def compute_conv_grad_based(list_of_logs, window_size=50):
    """Average logged gradients over the final `window_size` entries, one value per sample.

    Returns:
        dict: ``{'weights', 'gains', 'shifts'}`` -> arrays of shape (n_samples,),
        taken from the 'grad_weights', 'grad_gains' and 'grad_shifts' logs.
    """
    weights_tail = _tail_mean(list_of_logs, 'grad_weights', window_size)
    shifts_tail = _tail_mean(list_of_logs, 'grad_shifts', window_size)
    gains_tail = _tail_mean(list_of_logs, 'grad_gains', window_size)
    return {'weights': weights_tail,
            'gains': gains_tail,
            'shifts': shifts_tail}

In [None]:
def _plot_conv_summary(records, ax, lab):
    """Plot the per-experiment mean of `records` ({'Experiment', 'Value'} dicts) as one line."""
    df = pd.DataFrame(records)
    # groupby sorts the experiment labels, so the line runs in label order.
    summary = df.groupby('Experiment')['Value'].mean().reset_index()
    ax.plot(summary['Experiment'], summary['Value'], marker='o', lw=2, alpha=0.8, label=lab)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.set_xticks(summary['Experiment'])
    ax.set_xticklabels(summary['Experiment'], fontsize=12)


def plot_conv_grad_based(expriment_lists, experiment_names, ax, key_list=('weights',), lab=''):
    """Plot the end-of-training gradient level of each experiment as one labelled line.

    Args:
        expriment_lists: list of log lists, one per experiment.
        experiment_names: x-axis label of each experiment (same order).
        ax: matplotlib Axes to draw on.
        key_list: keys of ``compute_conv_grad_based`` to sum per sample
            (any of 'weights', 'gains', 'shifts').
        lab: legend label of the line.

    Raises:
        ValueError: if `key_list` is empty.
    """
    if not key_list:
        raise ValueError("key_list must name at least one gradient ('weights', 'gains', 'shifts')")
    records = []
    for exp, label in zip(expriment_lists, experiment_names):
        dict_conv = compute_conv_grad_based(exp)
        # Sum the selected per-sample values (e.g. gains + shifts).
        combined = sum(dict_conv[key] for key in key_list)
        records.extend({'Experiment': label, 'Value': value} for value in combined)
    _plot_conv_summary(records, ax, lab)


def plot_conv_loss_based(expriment_lists, experiment_names, ax, lab=''):
    """Plot the end-of-training loss of each experiment as one labelled line.

    Args:
        expriment_lists: list of log lists, one per experiment.
        experiment_names: x-axis label of each experiment (same order).
        ax: matplotlib Axes to draw on.
        lab: legend label of the line.
    """
    records = []
    for exp, label in zip(expriment_lists, experiment_names):
        losses = compute_conv_loss_based(exp)['loss']
        records.extend({'Experiment': label, 'Value': value} for value in losses)
    _plot_conv_summary(records, ax, lab)


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

# Percentage of Lyapunov exponents used in training. The order matches the
# per-ROI lists [*_nle5, *_nle10, *_nle20, *_nle50] built in section 1.
experiment_labels = [5, 10, 20, 50]
for roi, runs in (('EB', EB), ('AB', AB), ('PB', PB), ('NO', NO), ('ATL', ATL)):
    plot_conv_loss_based(runs, experiment_labels, ax, lab=roi)

# Mean of the last 50 logged training-loss values (compute_conv_loss_based's default window).
ax.set_ylabel(r"$\frac{1}{50} \sum_{t=T-49}^{T} L_{\theta}^{t}$", fontsize=16)
ax.set_xlabel("Fraction of Lyapunov exponents (%)", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_yscale('log')
ax.set_ylim([1E-8, 1E1])
ax.legend(fontsize=20, frameon=False)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_convergence_loss.svg')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

# Percentage of Lyapunov exponents used in training. The order matches the
# per-ROI lists [*_nle5, *_nle10, *_nle20, *_nle50] built in section 1.
experiment_labels = [5, 10, 20, 50]
# Summed gains + shifts gradients, one curve per ROI, each labelled with its own ROI name.
for roi, runs in (('EB', EB), ('AB', AB), ('PB', PB), ('NO', NO), ('ATL', ATL)):
    plot_conv_grad_based(runs, experiment_labels, ax, key_list=['gains', 'shifts'], lab=roi)

ax.set_ylabel(r"$|\nabla L_{\theta}|$", fontsize=14)
ax.set_xlabel("Fraction of Lyapunov exponents (%)", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_yscale('log')
ax.set_ylim([1E-9, 1E0])
ax.legend(fontsize=10, frameon=False)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_convergence_grad.svg')
plt.show()

In [None]:
def plot_conv_vs_var(expriment_lists, variancece_lists, label_lists, ax, lab=''):
    """Scatter each experiment's end-of-training gradient against its raw-spectrum variance.

    The convergence value of an experiment is the sum of the 'gains' and
    'shifts' entries returned by ``compute_conv_grad_based``, averaged over
    samples. Points are coloured by label and joined by a black line in order
    of increasing variance.

    Args:
        expriment_lists: list of log lists, one per experiment.
        variancece_lists: raw-spectrum variance of each experiment.
        label_lists: legend label of each experiment.
        ax: matplotlib Axes to draw on.
        lab: unused; kept for backward compatibility.
    """
    records = []
    for exp, var, label in zip(expriment_lists, variancece_lists, label_lists):
        # A loss-based variant would use compute_conv_loss_based(exp)['loss'] instead.
        dict_conv = compute_conv_grad_based(exp)
        conv = float(np.mean(dict_conv['gains'] + dict_conv['shifts']))
        records.append({'Conv': conv, 'Var': var, 'Label': label})
    df = pd.DataFrame(records)

    sns.lineplot(data=df.sort_values(by='Var'), x='Var', y='Conv', color='black',
                 linewidth=1, linestyle='-', ax=ax)
    sns.scatterplot(data=df, x='Var', y='Conv', hue='Label', ax=ax, s=200, alpha=1)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")

plot_conv_vs_var([EB_nle20, AB_nle20, PB_nle20, NO_nle20, ATL_nle20],
                 [var_EB, var_AB, var_PB, var_NO, var_ATL],
                 ['EB', 'AB', 'PB', 'NO', 'ATL'],
                 ax)

# plot_conv_vs_var plots the summed gains/shifts gradients
# (compute_conv_grad_based), not the training loss, so the y-axis is a gradient.
ax.set_ylabel(r"$|\nabla L_{\theta}|$", fontsize=16)
ax.set_xlabel("Raw spectrum variance", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_xlim([0, 0.055])
ax.legend(fontsize=20, frameon=False)
plt.tight_layout()
plt.savefig('../data/fig/FINAL/4_Flybrain_convergence_vs_var.svg')
plt.show()

## 4) Plot the histogram of the synaptic weight distribution

In [None]:
def _normalized_wc(roi):
    """Load `roi`'s connectome, normalise it and return ``np.dot(W, C)``."""
    roi_data = connectome.load_flybrain(ROI=roi, types="Neurotransmitter")
    W, C = connectome.normalize_connectome(
        W=roi_data["weights"], C=roi_data["connectivity"]
    )
    # NOTE(review): np.dot is a matrix product of W and C. If the intent is to
    # mask the weights by the connectivity, the element-wise W * C is probably
    # what was meant — confirm before interpreting these histograms.
    return np.dot(W, C)


def _style_weight_hist(axs):
    """Apply the axis labels and grid shared by both histogram helpers."""
    axs.set_xlabel('Weight Value', fontsize=12)
    # The histograms use density=True, so the y-axis is a density, not a count.
    axs.set_ylabel('Density', fontsize=12)
    axs.grid(True, alpha=0.3)
    axs.figure.tight_layout()


def plot_syn_hist(roi, axs, bins=200):
    """Histogram (density) of every entry of the normalised ``W·C`` of `roi`, zeros included.

    Prints the mean and variance of the plotted values.
    """
    WC = _normalized_wc(roi)
    print('mean', np.mean(WC), 'var', np.var(WC))
    axs.hist(WC.flatten(), bins=bins, alpha=0.3, edgecolor='black', label=roi, density=True)
    _style_weight_hist(axs)


def plot_syn_hist_1(roi, axs, bins=100):
    """Histogram (density) of the non-zero entries of the normalised ``W·C`` of `roi`.

    Prints the mean and variance of the plotted (non-zero) values.
    """
    WC = _normalized_wc(roi)
    WC_nonzero = WC[WC != 0]
    print('mean:', np.mean(WC_nonzero), 'var:', np.var(WC_nonzero))
    axs.hist(WC_nonzero, bins=bins, alpha=0.3, edgecolor='black', label=roi, density=True)
    _style_weight_hist(axs)

In [None]:
# Non-zero synaptic weights of AB on a logarithmic density axis.
fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
plot_syn_hist_1('AB', axs)
axs.set_yscale('log')
axs.set_xlim(left=-2, right=2)


## 5) Shift and gain parameter distributions

In [None]:
def get_model_path(subfolder: str,
                   activation: str,
                   ROI: str,
                   weights=False,
                   shifts=True,
                   gains=True,
                   nSamples=1,
                   Nle=1,
                   Epochs=400,
                   lr=0.1,
                   loss='MSE',
                   target='0.00',
                   tons=0.2,
                   tsim=200,
                   dt=0.1,
                   sample=0):
    """Return the path of the saved model of one configuration and sample.

    The name follows the same hyper-parameter encoding as the logs read by
    ``load_experiments``. The file lives under ``<root>/data/models/<subfolder>``
    and ends in ``_Sample{sample}``, with no file extension appended.

    Args:
        nSamples: unused; accepted so the signature mirrors ``load_experiments``.
        sample: index of the sample whose model is requested (default 0).
        Other arguments: hyper-parameters encoded in the file name. They must
            match the values used at training time.
    """
    model_dir = os.path.join(utils.get_root(), 'data', 'models', subfolder)
    model_name = (
        f"{activation}_ROI_{ROI}_Weights{weights}_Shifts{shifts}_Gains{gains}"
        f"_lr{lr}_NLE{Nle}_Epochs{Epochs}_{loss}_{target}"
        f"_Tons{tons}_Tsim{tsim}_dt{dt}_Sample{sample}"
    )
    return os.path.join(model_dir, model_name)

In [None]:
EB_nle10=model.RNN(get_model_path(subfolder='flybrain_RNN',activation='tanh_positive', Nle=44,ROI='EB'),'pos' )
PB_nle10=model.RNN(get_model_path(subfolder='flybrain_RNN',activation='tanh_positive', Nle=38,ROI='PB'),'pos' )
ATL_nle10=model.RNN(get_model_path(subfolder='flybrain_RNN',activation='tanh_positive', Nle=53,ROI='ATL'),'pos' )
AB_nle10=model.RNN(get_model_path(subfolder='flybrain_RNN',activation='tanh_positive', Nle=6,ROI='AB'),'pos' )
NO_nle10=model.RNN(get_model_path(subfolder='flybrain_RNN',activation='tanh_positive', Nle=50,ROI='NO'),'pos' )

In [None]:
def plot_gains_shifts_dist(model, axs, serie_name=''):
    bin=int(model.N*0.7)
    array_shifts=model.shifts
    array_gains=model.gains
    axs[0].set_title('Shifts distribution')
    sns.kdeplot(array_shifts, ax=axs[0],label=serie_name)
    axs[0].legend()
    axs[1].set_title('Gains distribution')
    sns.kdeplot(array_gains, ax=axs[1],label=serie_name)
    axs[1].legend()
    plt.tight_layout()
    return

In [None]:
fig, axs=plt.subplots(2,1, figsize=(7,5),)
plot_gains_shifts_dist(EB_nle10, axs,'EB')
plot_gains_shifts_dist(AB_nle10, axs,'AB')
plot_gains_shifts_dist(PB_nle10, axs,'PB')
plot_gains_shifts_dist(NO_nle10, axs,'NO')
plot_gains_shifts_dist(ATL_nle10, axs,'ATL')


axs[0].set_xlim(-1,3)
axs[1].set_xlim(-10,15)