In [None]:
# IPython autoreload: mode 2 re-imports modules before each cell is executed,
# so edits to imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
from copy import deepcopy
from itertools import product

import bct
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from echoes import ESNGenerator
from echoes.datasets import load_mackeyglasst17
from msapy import msa, plottings as pl
from netneurotools import cluster, plotting as netplot # comment out the import line 15 in plotting if you got an error here
from scipy.stats import ks_2samp, linregress
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.model_selection import ParameterGrid
from tqdm import tqdm

# Reproducibility: one seed and one shared NumPy generator for the notebook.
SEED = 2022
RNG = np.random.default_rng(seed=SEED)

# Apply the plotting style shipped with msapy's `plottings` module.
pl.set_style()

# Diverging palette (blue -> white -> red) and a complementary red/purple palette.
my_colors = [
    '#006685',
    '#3FA5C4',
    '#FFFFFF',
    '#E84653',
    '#BF003F',
]
my_complementary_colors = [
    '#E84653',
    '#BF003F',
    '#A6587C',
    '#591154',
    '#260126',
]

# Continuous colormap blended from `my_colors`; used for the heatmaps below.
colormap = sns.blend_palette(my_colors, as_cmap=True)


In [None]:
def generate_wave_data(amp_freq_pairs, timestamps, sampling_rate=None):
    """Generate one sine wave per (amplitude, frequency) pair.

    We use this function to generate some toy examples for illustrating what
    causal modes are. Row ``i`` of the output is
    ``amplitude_i * sin(2 * pi * frequency_i * timestamps)``.

    Args:
        amp_freq_pairs (ndarray):
            array of shape (n, 2); column 0 holds amplitudes and column 1
            holds frequencies.

        timestamps (ndarray):
            array of shape (t,) with the time points at which to evaluate
            the waves.

        sampling_rate (float, optional):
            kept only for backward compatibility and no longer used. The
            number of output columns is taken from ``timestamps`` itself, so
            recordings whose length differs from the sampling rate (any
            duration other than one second) and non-integer rates now work.

    Returns:
        ndarray: array of shape (n, t), one row per amplitude/frequency pair.
    """
    amp_freq_pairs = np.asarray(amp_freq_pairs)
    timestamps = np.asarray(timestamps)

    # Column vectors of shape (n, 1) broadcast against the (t,) time axis.
    frequencies = amp_freq_pairs[:, 1, None]
    amplitudes = amp_freq_pairs[:, 0, None]

    # Same operation order as before so existing results stay bit-identical.
    data = np.sin(2 * np.pi * timestamps * frequencies) * amplitudes
    return data

def linear_case(complements):
    """Sum all node time series except the lesioned ones.

    This is the most basic case: a lesioned node simply does not contribute
    to the sum, as if it did not exist. Reads the module-level ``data`` array
    (one row per node).

    Args:
        complements (tuples): which nodes to lesion.

    Returns:
        (ndarray): the summed up timeseries.
    """
    total_signal = data.sum(0)
    lesioned_signal = data[complements, :].sum(0)
    return total_signal - lesioned_signal

def doubled_case(x):
    """Return the summed-up signal of ``linear_case(x)`` multiplied by 2.

    Args:
        x (tuples): which nodes to lesion; forwarded to ``linear_case``.

    Returns:
        (ndarray): twice the summed up timeseries.
    """
    return linear_case(x) * 2


def non_linear_case(x):
    """Return the summed-up signal of ``linear_case(x)`` passed through tanh.

    Args:
        x (tuples): which nodes to lesion; forwarded to ``linear_case``.

    Returns:
        (ndarray): ``tanh`` of the summed up timeseries.
    """
    return np.tanh(linear_case(x))



In [None]:
# One second of signal sampled at `sampling_rate` points per second.
sampling_rate = 500
sampling_interval = 1 / sampling_rate
timestamps = np.arange(0, 1, sampling_interval)

# Grids of frequencies and amplitudes for the toy nodes.
frequencies = np.arange(1, 10, 1.5)
amplitudes = np.arange(0.2, 2, 0.4)

# Every (amplitude, frequency) combination as one row; amplitude varies slowest.
amp_freq_pairs = np.array(
    [[amplitude, frequency] for amplitude, frequency in product(amplitudes, frequencies)]
)

In [None]:
# One sine wave per (amplitude, frequency) pair; each row is treated as one node.
data = generate_wave_data(
    amp_freq_pairs=amp_freq_pairs,
    timestamps=timestamps,
    sampling_rate=sampling_rate,
)
# Node indices 0..n-1, used as the players in the Shapley analysis below.
elements = list(range(len(data)))

In [None]:
# Plot the raw activity of three example nodes.
example_node_colors = {5: "#EAAC8B", 10: "#E56B6F", 25: "#B56576"}

plt.figure(dpi=150)
for node, color in example_node_colors.items():
    plt.plot(data[node], label=f"node {node}", c=color, lw=2, alpha=0.9)
plt.title("Example Activities")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")
# The lines carry labels, so draw the legend that displays them.
plt.legend();

In [None]:
# Sum of all node activities: nothing is lesioned (empty complement list).
plt.figure(dpi=150)
plt.plot(
    linear_case([]),
    color=my_colors[0],
    linewidth=2,
    alpha=0.9,
)
plt.title("Combined Activity")
plt.xlabel("Time steps")
plt.ylabel("Amplitude");

In [None]:
# Multi-perturbation Shapley analysis of the plain sum. Only the first of the
# three values returned by `msa.interface` (the Shapley table) is used.
shapley_table_linear, _, _ = msa.interface(
    objective_function=linear_case,
    elements=elements,
    n_permutations=10_000,
    n_parallel_games=-1,
    rng=RNG,
)
# Average the table over index level 1 to get one contribution trace per node
# (assumes level 1 is the time axis - TODO confirm against msapy's output format).
linear_modes = shapley_table_linear.groupby(level=1).mean()

In [None]:
# Overlay each example node's raw activity (thin line) with its estimated
# contribution to the summed signal (thick, translucent line, same colour).
example_node_colors = {5: "#EAAC8B", 10: "#E56B6F", 25: "#B56576"}

plt.figure(dpi=150)
for node, color in example_node_colors.items():
    plt.plot(data[node], label=f"node {node}", c=color, lw=2, alpha=0.9)

for node, color in example_node_colors.items():
    plt.plot(linear_modes[node], label=f"Contribution #{node}", c=color, lw=4.5, alpha=0.4)
plt.title("Activity vs Contribution")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")
# The lines carry labels, so draw the legend that displays them.
plt.legend();

In [None]:
# Compare the intact summed signal (thin) with the sum of all per-node
# contributions (thick, translucent).
plt.figure(dpi=150)
plt.plot(linear_case([]), color=my_colors[0], linewidth=2, alpha=0.9)
plt.plot(linear_modes.sum(axis=1), color=my_colors[0], linewidth=4.5, alpha=0.4)
plt.title("Reconstructing the Combined Activities")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")

In [None]:
# Intact output of the doubled case: the summed signal multiplied by 2.
plt.figure(dpi=150)
plt.plot(
    doubled_case([]),
    color=my_colors[0],
    linewidth=2,
    alpha=0.9,
)
plt.title("(Combined Activity)x2")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")

In [None]:
# Shapley analysis of the doubled signal. The objective must be `doubled_case`:
# no `squared_case` exists in this notebook, so the old reference raised NameError.
# The "squared" variable names are kept because later cells read `squared_modes`.
shapley_table_squared, _, _ = msa.interface(
    elements=elements,
    n_permutations=10_000,
    objective_function=doubled_case,
    n_parallel_games=-1,
    rng=RNG)
# Average the table over index level 1 to get one contribution trace per node
# (assumes level 1 is the time axis - TODO confirm against msapy's output format).
squared_modes = shapley_table_squared.groupby(level=1).mean()

In [None]:
# Node 25: raw activity against its contribution in the doubled case.
plt.figure(dpi=150)

plt.plot(data[25], label="Activity", color="#B56576", linewidth=2, alpha=0.9)
plt.plot(squared_modes[25], label="Contribution", color="k", linewidth=2, alpha=0.9)

plt.title("Activity vs Contribution")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")
plt.legend(loc="lower left")

In [None]:
# Regress node 25's contribution on its activity. x/y are passed by keyword:
# seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(x=data[25], y=squared_modes[25])

In [None]:
# Intact output of the nonlinear case: tanh of the summed signal.
plt.figure(dpi=150)
plt.plot(
    non_linear_case([]),
    color=my_colors[0],
    linewidth=2,
    alpha=0.9,
)
plt.title("tanh(Combined Activity)")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")

In [None]:
# Shapley analysis of the tanh-transformed sum. Only the first of the three
# values returned by `msa.interface` (the Shapley table) is used.
shapley_table_nonlinear, _, _ = msa.interface(
    objective_function=non_linear_case,
    elements=elements,
    n_permutations=10_000,
    n_parallel_games=-1,
    rng=RNG,
)
# Average the table over index level 1 to get one contribution trace per node
# (assumes level 1 is the time axis - TODO confirm against msapy's output format).
nonlinear_modes = shapley_table_nonlinear.groupby(level=1).mean()

In [None]:
# Node 25: raw activity against its contribution in the tanh case.
plt.figure(dpi=150)

plt.plot(data[25], label="Activity", color="#B56576", linewidth=2, alpha=0.9)
plt.plot(nonlinear_modes[25], label="Contribution", color="k", linewidth=2, alpha=0.9)

plt.title("Activity vs Contribution")
plt.xlabel("Time steps")
plt.ylabel("Amplitude")
plt.legend(loc="lower left")

In [None]:
# Regress node 25's contribution on its activity in the tanh case. x/y are passed
# by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(x=data[25], y=nonlinear_modes[25])
# Rank correlation between activity and contribution, shown as the cell output.
spearmanr(data[25], nonlinear_modes[25])

In [None]:
# Load the Mackey-Glass series and make two chronological splits (shuffle=False):
# 2000/500 samples from the first 2500 points for hyperparameter tuning, and
# 3000/500 samples from the first 3500 points for the final fit and evaluation.
mackey_ts = load_mackeyglasst17()
y_train, y_test = train_test_split(mackey_ts[:2500],
                                   train_size=2000,
                                   test_size=500,
                                   shuffle=False)

y_train_validate, y_test_validate = train_test_split(mackey_ts[:3500],
                                                     train_size=3000,
                                                     test_size=500,
                                                     shuffle=False)
# Constructing connectomes and defining the parameters
# A local generator seeded with SEED is used here, separate from the global RNG.
# The order of the draws below determines the weights, so do not reorder them.
rng = np.random.default_rng(seed=SEED)
# Watts-Strogatz small-world graph: 36 nodes, 6 neighbours each, rewiring probability 0.4.
smallworld = nx.generators.watts_strogatz_graph(36, 6, 0.4, seed=SEED)
connectivity_matrix = nx.to_numpy_array(smallworld)

# Binary adjacency scaled element-wise by uniform weights in [0.1, 1).
# NOTE(review): `ser_connectome` is not used anywhere else in the visible notebook.
ser_connectome = connectivity_matrix * rng.uniform(low=0.1,
                                          high=1,
                                          size=(len(connectivity_matrix), len(connectivity_matrix)))

# randomly assigning weights 50 times to have 50 ESN connectomes. This'll make the network more robust.
esn_weights = [connectivity_matrix * rng.uniform(low=-0.5,
                                        high=0.5,
                                        size=(len(connectivity_matrix), len(connectivity_matrix))) for _ in range(50)]

# Search grid: 20 spectral radii x 50 weight matrices = 1000 candidate settings.
parameter_space = list(ParameterGrid({'radius': np.linspace(0.5, 1, 20),
                                      'W': esn_weights}))
# NOTE(review): Generator.integers excludes `high` by default, so W_fb below only
# contains -1 and 0 - confirm that +1 feedback weights were not intended.
esn = ESNGenerator(n_steps=500,
                   spectral_radius=None,  # to be filled during the hyperparameter tuning
                   leak_rate=0.1,
                   random_state=SEED,
                   W=None,  # to be filled during the hyperparameter tuning
                   W_fb=rng.integers(-1, 1, len(connectivity_matrix)).reshape(-1, 1),
                   W_in=rng.uniform(-1, 1, len(connectivity_matrix)).reshape(-1, 1),
                   bias=0.001)
# Finding the best weights and spectral radius (hyperparameter tuning):
# the same `esn` object is re-fitted for every candidate, and the test MSE is
# written back into that candidate's dict inside `parameter_space`.
for parameters in tqdm(parameter_space,
                       total=len(parameter_space),
                       desc='Optimizing hyperparams: '):
    esn.spectral_radius = parameters['radius']
    esn.W = parameters['W']
    esn.fit(X=None, y=y_train)
    y_pred = esn.predict()
    parameters.update({'error': mean_squared_error(y_test, y_pred)})

optimum_params = min(parameter_space, key=lambda x: x['error'])  # the combination with the minimum MSE

In [None]:
# testing the network.
# Re-fit the same `esn` object with the best radius and weights on the longer
# training split, then score its generated continuation on the held-out 500 samples.
esn.spectral_radius = optimum_params['radius']
esn.W = optimum_params['W']
# Later cells read `esn.states_pred_`; this flag is set before fitting for that purpose
# (presumably it makes echoes keep the reservoir states during prediction - confirm).
esn.store_states_pred=True
esn.fit(X=None, y=y_train_validate)

# Overwrites the `y_pred` left over from the tuning loop; later cells use this one.
y_pred = esn.predict()
mse = mean_squared_error(y_test_validate, y_pred)

print(f'MSE: {mse:.2}, Optimal Spectral Radius: {optimum_params["radius"]:.2}')

In [None]:
# Shapley analysis of the trained ESN: every reservoir unit is one element, and
# the fitted network plus its training data are passed to the objective function.
# NOTE(review): `lesion_esn` is neither defined nor imported anywhere in the visible
# notebook, so this cell raises NameError on a fresh kernel unless it is provided
# elsewhere. The otherwise unused `deepcopy` import suggests its definition was
# removed - restore it before relying on this cell.
# NOTE(review): this call seeds with `random_seed=SEED`, whereas the toy examples
# above pass the shared generator via `rng=RNG`.
lesion_esn_params = {'network': esn, 'training_data': y_train_validate}
shapley_table_esn, _, _ = msa.interface(multiprocessing_method='joblib',
                                    elements=list(range(esn.n_reservoir_)),
                                    n_permutations=1_000,
                                    objective_function=lesion_esn,
                                    objective_function_params=lesion_esn_params,
                                    n_parallel_games=-1,
                                    random_seed=SEED)
# Average the table over index level 1 to get one contribution trace per unit
# (assumes level 1 is the time axis - TODO confirm against msapy's output format).
esn_modes = shapley_table_esn.groupby(level=1).mean()

In [None]:
# Pick reservoir units from_to[0] .. from_to[1]-1 and show their recorded states
# (panel A) next to their causal modes (panel B) with a shared palette.
from_to = [10, 16]
node_window = slice(from_to[0], from_to[1])
data_range = esn_modes.T[node_window]
states_range = esn.states_pred_.T[node_window]

complementary_colormap = sns.blend_palette(
    my_complementary_colors,
    n_colors=len(data_range),
    as_cmap=False,
)

fig, axes = plt.subplot_mosaic([['A', 'B']], figsize=(10, 3), dpi=150)
sns.lineplot(
    data=data_range.T,
    palette=complementary_colormap,
    dashes=False,
    alpha=0.8,
    legend='brief',
    ax=axes['B'],
)
sns.lineplot(
    data=states_range.T,
    palette=complementary_colormap,
    dashes=False,
    alpha=0.8,
    legend=False,
    ax=axes['A'],
)

fig.tight_layout(pad=0.4)

In [None]:
# Scatter state against causal mode for three units chosen by the magnitude of
# their readout weight: the largest, the median and the smallest.
# NOTE(review): the smallest is searched in `W_out_[0, :-1]` (last column dropped),
# while the largest and median use every entry of `W_out_` - confirm whether the
# last column is a non-reservoir term that should be excluded everywhere.
strongest_index = np.argmax(np.abs(esn.W_out_))
median_index = np.argsort(np.abs(esn.W_out_).squeeze())[len(esn.W_out_.squeeze())//2]
weakest_index = np.argmin(np.abs(esn.W_out_[0,:-1]))

# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
for unit_index, unit_color in (
        (strongest_index, my_complementary_colors[0]),
        (median_index, my_complementary_colors[-2]),
        (weakest_index, my_complementary_colors[-1])):
    sns.scatterplot(x=esn.states_pred_.T[unit_index],
                    y=esn_modes[unit_index].squeeze(),
                    color=unit_color)



In [None]:
# Network prediction against the sum of all units' causal modes.
plt.figure(dpi=150)
plt.plot(np.squeeze(y_pred))
plt.plot(esn_modes.sum(axis="columns"))


In [None]:
# Mean squared error between the prediction and the summed causal modes.
mean_squared_error(
    y_true=y_pred.squeeze(),
    y_pred=esn_modes.sum(axis=1),
)

In [None]:
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr, pearsonr

In [None]:
# Pairwise correlation between the causal modes (columns of `esn_modes`).
modes_similarities = esn_modes.corr().to_numpy(copy=True)
# Pairwise correlation between the recorded states (rows of the transposed matrix).
fc = np.corrcoef(np.transpose(esn.states_pred_))

In [None]:
# Side-by-side heatmaps: causal-mode correlations (A) and state correlations (B),
# both on the shared diverging colormap centred at zero.
fig, axes = plt.subplot_mosaic([['A', 'B']], figsize=(10, 5), dpi=150)

heatmap_panels = (
    ('A', modes_similarities, 'Fluctuation of the Causal Modes'),
    ('B', fc, 'Functional Connectivity'),
)
for panel_key, matrix, panel_title in heatmap_panels:
    sns.heatmap(
        matrix,
        square=True,
        center=0,
        cmap=colormap,
        cbar_kws={"shrink": .5},
        ax=axes[panel_key],
        linewidths=0.,
        linecolor='k',
    )
    axes[panel_key].title.set_text(panel_title)

In [None]:
# Run signed Louvain 1000 times on each similarity matrix, keep the community
# assignment (first return value) of each run, and reduce the runs to one consensus.
# The Louvain calls receive no seed; only the consensus step is seeded.
gamma = 1.

mode_communities = []
for _ in range(1000):
    mode_partition = bct.clustering.modularity_louvain_und_sign(modes_similarities, gamma=gamma)[0]
    mode_communities.append(mode_partition)
mode_consensus = cluster.find_consensus(np.column_stack(mode_communities), seed=SEED)

fc_communities = []
for _ in range(1000):
    fc_partition = bct.clustering.modularity_louvain_und_sign(fc, gamma=gamma)[0]
    fc_communities.append(fc_partition)
fc_consensus = cluster.find_consensus(np.column_stack(fc_communities), seed=SEED)


In [None]:
# Heatmaps with the consensus communities passed along: causal modes (A) and
# functional connectivity (B).
fig, axes = plt.subplot_mosaic([['A', 'B']], figsize=(10, 5), dpi=150)
netplot.plot_mod_heatmap(
    modes_similarities,
    mode_consensus,
    cmap=colormap,
    ax=axes['A'],
    cbar=False,
)
netplot.plot_mod_heatmap(
    fc,
    fc_consensus,
    cmap=colormap,
    ax=axes['B'],
    cbar=False,
)

In [None]:
# Copies of both similarity matrices with every negative entry replaced by zero;
# the originals are left untouched.
nonnegative_modes = np.where(modes_similarities < 0, 0, modes_similarities)

nonnegative_fc = np.where(fc < 0, 0, fc)

In [None]:
# Two-sample Kolmogorov-Smirnov statistic between each unit's recorded state and
# its causal mode; only the statistic (first element of the result) is kept.
ks_index = [
    ks_2samp(np.array(esn.states_pred_.T[i]), np.array(esn_modes[i]))[0]
    for i in range(len(esn_modes.T))
]


In [None]:
# Participation coefficient of each unit from the non-negative mode similarities
# and the consensus communities, regressed against the KS index.
pc = bct.participation_coef(nonnegative_modes, mode_consensus)
# Use a dedicated name for the frame: the global `data` array is read by
# `linear_case`, so rebinding `data` here broke any re-run of the toy-example cells.
ks_pc_df = pd.DataFrame({'ks': ks_index, 'pc': pc})
sns.regplot(x='ks', y='pc', data=ks_pc_df, truncate=False, color='k')
linregress(ks_index, pc)

In [None]:
# Run Louvain 1000 times on the binary structural graph, keep each run's
# community assignment, reduce them to one consensus, and draw the ESN weight
# matrix with those communities passed along.
structural_communities = [
    bct.community_louvain(connectivity_matrix, gamma=gamma)[0]
    for _ in range(1000)
]

structural_consensus = cluster.find_consensus(
    np.column_stack(structural_communities),
    seed=SEED,
)
netplot.plot_mod_heatmap(esn.W, structural_consensus, cmap=colormap)

In [None]:
from network_control.utils import matrix_normalization
from network_control.metrics import ave_control, modal_control
# Normalise the absolute ESN weights, then compute both control metrics from the
# same matrix. (Going by the function names these are modal and average
# controllability - see network_control's documentation for exact definitions.)
A = matrix_normalization(
    np.abs(esn.W),
    version='discrete',
    c=1,
)
mc = modal_control(A)
ac = ave_control(A)

In [None]:
# Relate the KS index to a control metric; `test` selects which one (`mc` here).
test = mc
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(x=ks_index, y=test)
linregress(ks_index, test)