In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import flybrain.utils as utils
import numpy as np 
import matplotlib.pyplot as plt
import torch
import os
import pandas as pd
import seaborn as sns
import json

In [None]:
def load_experiments(
    param: "str | list[str]",
    gen_path=None,
    prefix="tanh_positive_N1000_nSample10_tSim200_dt0.1_tOns0.2",
):
    """Load the spectrum-characterization results for one experiment setting.

    Reads the three JSON files ``{prefix}_{tag}_dim.json``,
    ``{prefix}_{tag}_entropies.json`` and ``{prefix}_{tag}_spectrum.json``,
    where ``tag`` is ``param`` joined with underscores.

    Parameters
    ----------
    param : str or sequence of str
        Name(s) of the randomized parameter group(s), e.g. ``['weights']`` or
        ``['weights', 'gains', 'shifts']``. A bare string is treated as a single
        name (``"_".join`` would otherwise split it into characters).
    gen_path : str or os.PathLike, optional
        Directory holding the JSON files. Defaults to
        ``<root>/data/logs/spectrum_characterization``, where ``<root>`` comes
        from ``utils.get_root()``.
    prefix : str, optional
        Run-configuration prefix shared by the file names.

    Returns
    -------
    dict
        ``{'dimension': ..., 'entropy': ..., 'spec': ...}``, holding the parsed
        contents of the ``_dim``, ``_entropies`` and ``_spectrum`` files.

    Raises
    ------
    FileNotFoundError
        If any of the three files is missing.
    """
    if isinstance(param, str):
        param = [param]
    if gen_path is None:
        gen_path = os.path.join(utils.get_root(), 'data', 'logs', 'spectrum_characterization')

    # Single quotes inside the f-string: nesting the same quote type only parses on Python >= 3.12.
    model_path = f"{prefix}_{'_'.join(param)}"

    # Result key -> file suffix (insertion order matches the original load order).
    suffixes = {
        'dimension': '_dim.json',
        'entropy': '_entropies.json',
        'spec': '_spectrum.json',
    }
    results = {}
    for key, suffix in suffixes.items():
        with open(os.path.join(gen_path, model_path + suffix), encoding='utf-8') as f:
            results[key] = json.load(f)
    return results

## Load the runs

In [None]:
# Load one result set per randomized parameter group, in a fixed order:
# weights only, gains only, shifts only, and all three together.
weights, gains, shifts, gains_shifts_weigths = (
    load_experiments(param=group)
    for group in (
        ['weights'],
        ['gains'],
        ['shifts'],
        ['weights', 'gains', 'shifts'],
    )
)
     

In [None]:
# Inspect the keys of the weights-only entropy results; plot_logs_trainings pairs
# these entries with its sigma grid purely by order — presumably one key per sigma, TODO confirm.
weights['entropy'].keys()

## Plot dimension, entropy and spectrum vs. $\sigma$

In [None]:
def plot_logs_trainings(dict_of_logs,
                        axs,
                        serie,
                        color,
                        x_val=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024),
                        n_neurons=1000):
    """Scatter per-sample values of one metric against sigma and overlay their mean.

    Parameters
    ----------
    dict_of_logs : dict
        Mapping whose values are lists of per-sample measurements. Entries are
        paired with ``x_val`` by insertion order; the keys themselves are not
        parsed (presumably one entry per sigma, in increasing order — TODO confirm).
        Surplus entries or surplus ``x_val`` values are ignored.
    axs : matplotlib.axes.Axes
        Axes to draw on.
    serie : str
        Legend label of the mean curve.
    color : matplotlib color
        Colour of both the scatter points and the mean curve.
    x_val : sequence of float, optional
        x positions (sigma values) assigned to the entries, in order.
    n_neurons : float, optional
        Every value is divided by this constant before plotting. The default
        1000 matches the ``N1000`` in the run file names.
    """
    xs, means = [], []
    # zip stops at the shorter input, so xs and means always have equal length
    # (the old version plotted the full x_val against a possibly shorter mean list).
    for x, samples in zip(x_val, dict_of_logs.values()):
        values = np.asarray(samples) / n_neurons
        axs.scatter(np.full(len(values), x), values, alpha=0.2, color=color)
        xs.append(x)
        means.append(np.mean(samples) / n_neurons)
    axs.plot(xs, means, label=serie, color=color, alpha=0.9)
    axs.legend(fontsize=18, frameon=False)
    

In [None]:
# Figure: dimension D/N against sigma.
# Panels: weights | gains and shifts overlaid | weights, gains and shifts together.
fig, axs = plt.subplots(1, 3, figsize=(17, 5), sharex=True, sharey=False)

# Shared styling: hide right/top spines, enlarge tick labels, label every x axis.
for ax in axs:
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.set_xlabel(r"$\sigma$", fontsize=22, weight='bold')

plot_logs_trainings(weights['dimension'], axs[0], r'$W_{ij}\sim N(0,\sigma/\sqrt{N})$', 'orange')
plot_logs_trainings(gains['dimension'], axs[1], r'$g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'blue')
plot_logs_trainings(shifts['dimension'], axs[1], r'$s_{ij}\sim N(0,\sigma/\sqrt{N})$', 'red')
plot_logs_trainings(gains_shifts_weigths['dimension'], axs[2], r'$W_{ij},s_{ij},g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'green')

axs[0].set_ylabel(r"$\frac{D}{N}$", fontsize=22, weight='bold')

# sharex=True makes this log scale apply to all three panels.
axs[0].set_xscale('log')
axs[0].set_ylim([-.01, 0.2])
axs[1].set_ylim([-.01, 1.1])
axs[2].set_ylim([-.01, 0.4])

fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../data/fig/FINAL', exist_ok=True)
fig.savefig('../data/fig/FINAL/0_Spectrum_Char_dim.svg', format='svg')
plt.show()


In [None]:
# Figure: entropy E/N against sigma, laid out like the dimension figure.
# Panels: weights | gains and shifts overlaid | weights, gains and shifts together.
fig, axs = plt.subplots(1, 3, figsize=(17, 5), sharex=True, sharey=False)

# Styling applied to every panel, so all three share the same tick and label sizes.
for ax in axs:
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.set_xlabel(r"$\sigma$", fontsize=22, weight='bold')

plot_logs_trainings(weights['entropy'], axs[0], r'$W_{ij}\sim N(0,\sigma/\sqrt{N})$', 'orange')
plot_logs_trainings(gains['entropy'], axs[1], r'$g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'blue')
plot_logs_trainings(shifts['entropy'], axs[1], r'$s_{ij}\sim N(0,\sigma/\sqrt{N})$', 'red')
plot_logs_trainings(gains_shifts_weigths['entropy'], axs[2], r'$W_{ij},s_{ij},g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'green')

axs[0].set_ylabel(r"$\frac{E}{N}$", fontsize=22, weight='bold')

# sharex=True makes this log scale apply to all three panels.
axs[0].set_xscale('log')
axs[0].set_ylim([-.01, 0.20])
axs[1].set_ylim([-.1, 5])
axs[2].set_ylim([-.01, 0.20])

fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../data/fig/FINAL', exist_ok=True)
fig.savefig('../data/fig/FINAL/0_Spectrum_Char_entropy.svg', format='svg')
plt.show()

In [None]:
def plot_logs_spectrum(dict_of_logs,
                       axs,
                       serie,
                       colorMap,
                       choosen=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024)):
    """Plot every sampled spectrum per sigma, plus the per-sigma mean spectrum.

    Parameters
    ----------
    dict_of_logs : dict
        Maps ``str(sigma)`` to a list of samples. Each sample is a 1-D sequence
        drawn against an evenly spaced x grid from 0 to 100 (presumably the
        ordered spectrum of one network — TODO confirm). All samples for one
        sigma must have the same length so they can be averaged.
    axs : matplotlib.axes.Axes
        Axes to draw on.
    serie : str
        Accepted for signature parity with ``plot_logs_trainings``; not drawn.
    colorMap : str
        Name of a matplotlib colormap; successive sigmas get successively later shades.
    choosen : sequence of int, optional
        Sigma values to plot, in order. Only the first and last get a legend entry.

    Notes
    -----
    A horizontal dashed line is drawn at y = 0 over x in [0, 100].
    """
    cmap = plt.get_cmap(colorMap)
    n_sigmas = len(choosen)

    for i, sigma in enumerate(choosen):
        samples = dict_of_logs[str(sigma)]
        if len(samples) == 0:
            continue  # nothing to draw or average for this sigma

        # Samples use shade (i+1)/n and the mean one step darker. The mean shade is
        # capped at 1.0 instead of relying on the colormap's out-of-range "over" colour.
        sample_color = cmap((i + 1) / n_sigmas)
        mean_color = cmap(min((i + 2) / n_sigmas, 1.0))

        for sample in samples:
            axs.plot(np.linspace(0, 100, len(sample)), sample,
                     alpha=0.1, color=sample_color, lw=1)

        mean_spectrum = np.mean(np.array(samples), axis=0)
        # Only the extreme sigmas are labelled, to keep the legend short.
        label_kw = ({'label': r'$\sigma$:' + f" {sigma}"}
                    if sigma in (choosen[0], choosen[-1]) else {})
        axs.plot(np.linspace(0, 100, len(mean_spectrum)), mean_spectrum,
                 color=mean_color, alpha=0.8, lw=2, **label_kw)

    axs.legend(fontsize=18, frameon=False)
    axs.hlines(y=0, xmin=0, xmax=100, ls='--', color='black', lw=2, alpha=0.5)
    

In [None]:
# Figure: sampled and mean spectra per sigma.
# Panels: weights | gains | weights, gains and shifts together (shifts alone are not shown).
fig, axs = plt.subplots(1, 3, figsize=(17, 5), sharex=True, sharey=True)

# Shared styling: hide right/top spines, enlarge tick labels, label every x axis.
for ax in axs:
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.set_xlabel(r"$i$", fontsize=22, weight='bold')

# Each panel gets the label of the distribution randomized in it
# (plot_logs_spectrum currently does not draw this argument).
plot_logs_spectrum(weights['spec'], axs[0], r'$W_{ij}\sim N(0,\sigma/\sqrt{N})$', 'Oranges')
plot_logs_spectrum(gains['spec'], axs[1], r'$g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'Blues')
plot_logs_spectrum(gains_shifts_weigths['spec'], axs[2], r'$W_{ij},s_{ij},g_{ij}\sim N(0,\sigma/\sqrt{N})$', 'Greens')

axs[0].set_ylabel(r"$\lambda_i$", fontsize=22, weight='bold')

# sharey=True makes this range apply to all three panels.
axs[0].set_ylim([-5, 4])

fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../data/fig/FINAL', exist_ok=True)
fig.savefig('../data/fig/FINAL/0_Spectrum_Char_spectrum.svg', format='svg')
plt.show()