# ms packages and functions

Repository of the packages and helper functions used for network analysis.

## Packages

In [1]:
#########################
# IMPORTING PACKAGES
#########################


# FOR VISUALIZING
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from astropy.convolution import Gaussian1DKernel
from astropy.convolution import convolve
import seaborn as sns
from matplotlib import animation, rc
from IPython.display import HTML
import matplotlib as mpl
mpl.rcParams['animation.embed_limit'] = 2**128
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import display, clear_output
# import plotly.express as px
from matplotlib.ticker import FormatStrFormatter
from matplotlib.lines import Line2D
%matplotlib inline
import plotly.graph_objects as go
import plotly.express as px
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
mpl.rc('font',family='Roboto')
from statannotations.Annotator import Annotator
from scipy.interpolate import UnivariateSpline
from matplotlib.patches import Rectangle
import string

# FOR COMPUTATION
import numpy as np
import pandas as pd
import scipy.io as sio
from scipy.spatial import distance as dist
import math
from collections import Counter
from sklearn.metrics.cluster import normalized_mutual_info_score
import sklearn.datasets
import random
from scipy import stats
import itertools
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
import umap.plot
import umap.utils as utils
import umap.aligned_umap
from scipy import interpolate
import hdbscan
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.linear_model import LinearRegression, Lasso
import neo
import elephant as el
import quantities as pq
import re
from multiprocess import Pool
from statsmodels.stats.multitest import multipletests

# FOR SAVING/DATA HANDLING
import sys
from os.path import dirname, join as pjoin
import os
import json
import pickle
import time
import copy
import datetime
import h5py
import mat73
from datetime import date

# FOR NETWORK ANALYSIS
import networkx as nx
from nodevectors import Node2Vec
from gensim.models import KeyedVectors
# from node2vec import Node2Vec #original node2vec implementation
from teneto import TemporalNetwork
import teneto
import infomap
import community
#import igraph as ig
import bct



  warn("Tensorflow not installed; ParametricUMAP will be unavailable")


## Functions

### FOR DATA WRANGLING

#### bin_spikes(raster_data, startTime, endTime, binwin, binary=True, plot = True, cmap = 'binary_r', gauss = False, std = 0.1):

|Input | Process| Output
|--- | ---|---|
|a trial from <br> `trials.json`| bins spike times | `binned_spks`<br>`bin_edges`<br>optional:plots a raster plot|

In [None]:
def bin_spikes(raster_data, startTime, endTime, binwin, binary=True, plot = True, cmap = 'binary_r', gauss = False, std = 0.1):
    """Bin every neuron's spike times into a common set of time bins.

    Parameters
    ----------
    raster_data : mapping
        Must contain 'spiketimes': an iterable with one sequence of spike
        times per neuron.
    startTime, endTime, binwin : float
        Bin edges are ``np.arange(startTime, endTime, binwin)``.
        NOTE(review): arange excludes endTime, so the last edge is below
        endTime and spikes after it are not counted. Confirm this is intended.
    binary : bool
        Clip counts to 0/1 (a bin is 1 if it holds at least one spike).
    plot : bool
        Show the neuron-by-bin matrix with ``plt.imshow`` (colorbar added
        when ``gauss``).
    cmap : str
        Colormap used for the plot.
    gauss : bool
        Convolve each binned train with an astropy ``Gaussian1DKernel(std)``.
    std : float
        Kernel standard deviation, in bins.

    Returns
    -------
    binned_spks : list of 1-D arrays, one per neuron
    bin_edges : array of bin edges (the requested edges if there are no neurons)
    """
    # required packages: numpy, astropy (if convolving), matplotlib (if plotting)
    bins = np.arange(startTime, endTime, binwin)  # identical for every neuron, so build once
    bin_edges = bins  # returned as-is when raster_data has no neurons
    kernel = Gaussian1DKernel(std) if gauss else None

    binned_spks = []
    for neuron in raster_data['spiketimes']:
        binned, bin_edges = np.histogram(neuron, bins)

        if binary:
            binned[binned > 0] = 1

        if gauss:
            binned = convolve(binned, kernel)

        binned_spks.append(binned)

    if plot:
        plt.imshow(binned_spks, cmap=cmap)
        if gauss:
            plt.colorbar()

    return binned_spks, bin_edges

### FOR FUNCTIONAL NETWORK CONSTRUCTION

In [None]:
def MI(prob_matrix):
    """Return the mutual information, in bits, of a joint probability table.

    ``prob_matrix[r, s]`` is the joint probability of row state ``r`` and
    column state ``s``. The row marginal is the sum over each row and the
    column marginal the sum over each column. make_FN fills rows with the
    presynaptic state and columns with the postsynaptic state.

    Cells whose joint or marginal probability is zero contribute nothing,
    because 0 * log(0) is taken as 0.

    Parameters
    ----------
    prob_matrix : array-like of shape (n_rows, n_cols)

    Returns
    -------
    float : sum over valid cells of ``p(r, s) * log2(p(r, s) / (p(r) * p(s)))``
    """
    P = np.asarray(prob_matrix, dtype=float)
    # marginals are computed once instead of once per loop iteration
    row_marginal = P.sum(axis=1)
    col_marginal = P.sum(axis=0)

    valid = (
        (P != 0)
        & (row_marginal[:, np.newaxis] != 0)
        & (col_marginal[np.newaxis, :] != 0)
    )
    independent = np.outer(row_marginal, col_marginal)  # p(r) * p(s) for every cell

    return np.sum(P[valid] * (np.log2(P[valid]) - np.log2(independent[valid])))

In [None]:
def _joint_activity_prob(presyn, postsyn):
    """Return the 2x2 joint probability table of (presyn active?, postsyn active?).

    Rows index the presynaptic state and columns the postsynaptic state.
    State 0 means inactive (value == 0) and state 1 means active (value > 0).
    Each cell is the fraction of time bins in that combination.
    """
    n_bins = len(presyn)
    pre_on, pre_off = presyn > 0, presyn == 0
    post_on, post_off = postsyn > 0, postsyn == 0

    prob_matrix = np.zeros((2, 2))
    prob_matrix[1, 1] = np.count_nonzero(pre_on & post_on) / n_bins    # both active
    prob_matrix[0, 0] = np.count_nonzero(pre_off & post_off) / n_bins  # both inactive
    prob_matrix[1, 0] = np.count_nonzero(pre_on & post_off) / n_bins   # only presyn active
    prob_matrix[0, 1] = np.count_nonzero(pre_off & post_on) / n_bins   # only postsyn active
    return prob_matrix


def make_FN(FN_data,metric='fMI',plot=False,self_edge=False, norm=False):
    """Build a functional network (FN) from binned activity using mutual information.

    ``FN[i, j]`` scores unit i (row, "presyn") against unit j (column, "postsyn").

    Parameters
    ----------
    FN_data : sequence of equal-length 1-D activity vectors, one per unit
        Assumed to be numpy arrays: the 'fMI' branch adds shifted copies with
        ``+``, which would concatenate plain lists.
    metric : {'MI', 'cMI', 'fMI'}
        'MI'  : presyn and postsyn compared in the same time bin.
        'cMI' : presyn at bin t vs postsyn at bin t+1.
        'fMI' : presyn at bin t vs postsyn active at bin t or t+1.
        Any other value prints a message and leaves FN all zeros.
    plot : bool
        Draw FN with seaborn's heatmap and show it.
    self_edge : bool
        Keep the diagonal; by default it is zeroed.
    norm : bool
        Use sklearn's ``normalized_mutual_info_score`` on the (shifted) vectors
        instead of the binary-activity MI computed by ``MI``.

    Returns
    -------
    FN : (n_units, n_units) float ndarray
    """
    n_units = len(FN_data)
    FN = np.zeros((n_units, n_units))

    if metric not in ('MI', 'cMI', 'fMI'):
        print('That metric does not exist!')
    else:
        for i, presyn in enumerate(FN_data):
            if metric != 'MI':
                # drop the last bin so presyn[t] lines up with postsyn[t + 1]
                presyn = presyn[:-1]
            for j, postsyn in enumerate(FN_data):
                if metric == 'cMI':
                    postsyn = postsyn[1:]  # postsyn one bin later
                elif metric == 'fMI':
                    # active in the same OR the next bin, then binarized
                    postsyn = postsyn[1:] + postsyn[:-1]
                    postsyn[postsyn > 0] = 1

                if norm:
                    FN[i, j] = normalized_mutual_info_score(presyn, postsyn)
                else:
                    FN[i, j] = MI(_joint_activity_prob(presyn, postsyn))

    if not self_edge:  # zero out diagonal
        np.fill_diagonal(FN, 0)

    if plot:
        sns.heatmap(FN, cmap='magma', square=True)
        plt.show()

    return FN

### FOR FUNCTIONAL NETWORK ANALYSIS

In [None]:
def mat2edgelist(matrix,weighted=False,weight_threshold=0):
    """Convert an adjacency matrix into a list of ``[source, target, weight]`` edges.

    Rows are sources and columns are targets (``matrix[i][j]`` -> edge i -> j).

    Parameters
    ----------
    matrix : 2-D array-like
    weighted : bool
        False: keep (i, j) iff ``matrix[i][j] > weight_threshold`` and give it
        weight 1. (Previously the 0/1 indicator itself was compared with the
        threshold, so threshold 0 kept every entry and thresholds >= 1 kept
        none; the raw value was stored as the weight.)
        True: keep (i, j) iff ``matrix[i][j] >= weight_threshold`` (inclusive,
        as before, so threshold 0 also keeps zero entries) and store the raw value.
    weight_threshold : float

    Returns
    -------
    list of [i, j, weight]
    """
    edgelist = []
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            if weighted:
                if value >= weight_threshold:
                    edgelist.append([i, j, value])
            elif value > weight_threshold:
                edgelist.append([i, j, 1])
    return edgelist

In [None]:
def findInfomapCommunities(G, N=10):
    """Partition G with directed, two-level Infomap (N trials).

    Writes each node's module id to the node attribute 'community' and
    returns the number of top-level modules.

    Edge weights are read from the 'weight' attribute (default 1.0). Graphs
    from getNetwork carry weights; passing only ``(u, v)`` pairs, as before,
    ran an unweighted partition.
    NOTE(review): the meaning of the '-0' flag is unclear; confirm it against
    the installed Infomap's option list.
    """
    infomapWrapper = infomap.Infomap("--directed --two-level -0 -N{}".format(N))

    for source, target, weight in G.edges(data='weight', default=1.0):
        infomapWrapper.addLink(source, target, weight)

    infomapWrapper.run()

    communities = infomapWrapper.getModules()

    nx.set_node_attributes(G, communities, 'community')

    return infomapWrapper.numTopModules()

In [None]:
def findLouvainCommunities(G, N=10):
    """
    Partition a networkx graph G with the Louvain algorithm.

    Runs N randomized Louvain partitions and keeps the one with the highest
    modularity. (Previously ``best_modularity`` was never updated and the
    comparison was inverted, so the last run was always kept.)
    Annotates nodes with 'community' id and returns the number of communities found.

    NOTE(review): python-louvain's ``best_partition`` is documented for
    undirected graphs, while getNetwork builds a DiGraph. Confirm callers
    convert first.

    Raises
    ------
    ValueError
        If N < 1, since no partition is produced.
    """
    best_partition = None
    best_modularity = float("-inf")
    for _ in range(N):
        partition = community.best_partition(G, randomize=True)
        modularity = community.modularity(partition, G)
        if modularity > best_modularity:
            best_partition, best_modularity = partition, modularity

    if best_partition is None:
        raise ValueError("N must be at least 1 to produce a partition")

    nx.set_node_attributes(G, name='community', values=best_partition)
    return len(Counter(best_partition.values()))

In [None]:
def getNetwork(graph_data,threshold):
    """Build a weighted ``nx.DiGraph`` from an adjacency matrix.

    Every entry kept by ``mat2edgelist(..., weighted=True,
    weight_threshold=threshold)`` becomes an edge row -> column whose
    'weight' attribute is the matrix value.
    """
    weighted_edges = mat2edgelist(graph_data, weighted=True, weight_threshold=threshold)
    network = nx.DiGraph()
    # each entry is [source, target, weight]; added in the same order as the list
    network.add_weighted_edges_from(weighted_edges)
    return network

In [None]:
def poissSpkTrain_r(T,dt,tau,r0):
    """Simulate a spike train whose rate resets after each spike and relaxes back to r0.

    At every step of width dt the unit fires with probability ``r * dt`` (one
    ``random.random()`` draw per step). After a spike the rate r is set to 0.
    Each step, r then moves toward r0 by ``(r0 - r) * dt / tau``.

    Parameters (units assumed: seconds and spikes/s)
    ----------
    T : total duration
    dt : step size
    tau : recovery time constant of the rate
    r0 : baseline rate

    Returns
    -------
    spkTimes : list of spike times
    isi : list of inter-spike intervals. The first is measured from t=0;
        starting from dt, as before, gave a negative ISI (-dt) for a spike at t=0.
    spkTrain : list of 0/1 per step
    t_steps : ndarray, ``np.arange(0, T, dt)``
    """
    t_steps = np.arange(0, T, dt)

    spkTimes = []
    isi = []
    spkTrain = []

    t_old = 0
    r = r0
    for t in t_steps:
        if random.random() <= r * dt:
            spkTimes.append(t)
            isi.append(t - t_old)
            t_old = t
            r = 0  # reset the rate right after a spike
            spkTrain.append(1)
        else:
            spkTrain.append(0)

        r = r + (r0 - r) * (dt / tau)

    return spkTimes, isi, spkTrain, t_steps

In [None]:
def poissSpkTrain(r_est,T,dt):
    '''
    Simulate a homogeneous spike train by independent per-step coin flips.

    INPUT:
    r_est = estimated firing rate (spks/s)
    dt    = time steps (s)
    T     = total time (s)

    Each step fires with probability r_est * dt (one random.random() draw per step).

    OUTPUT:
    spkTimes = list of spike times
    isi      = inter-spike intervals. The first is measured from t=0; starting
               from dt, as before, gave a negative ISI (-dt) for a spike at t=0.
    spkTrain = list of 0/1 per step
    t_steps  = np.arange(0, T, dt)
    '''
    pf = r_est * dt  # probability of firing in one step

    t_steps = np.arange(0, T, dt)

    spkTimes = []
    isi = []
    spkTrain = []

    t_old = 0
    for t in t_steps:
        if random.random() <= pf:
            spkTimes.append(t)
            isi.append(t - t_old)
            t_old = t
            spkTrain.append(1)
        else:
            spkTrain.append(0)

    return spkTimes, isi, spkTrain, t_steps

In [None]:
# from : https://stackoverflow.com/questions/2566412/find-nearest-value-in-numpy-array 
def find_nearest(array, value):
    """Return ``(element, index)`` of the entry of *array* closest to *value*.

    Ties go to the first such entry (``np.argmin`` semantics).
    """
    values = np.asarray(array)
    distances = np.abs(values - value)
    nearest_idx = np.argmin(distances)
    return values[nearest_idx], nearest_idx

In [1]:
def status(string):
    """Replace the notebook cell's current output with *string*.

    Calls IPython's ``clear_output(wait=True)`` and then ``display(string)``,
    returning whatever ``display`` returns.
    """
    clear_output(wait=True)
    shown = display(string)
    return shown

**Graph Alignment**

Adapted from Maayan Levy's MATLAB graph-alignment function:
```matlab
function align_score = lp_alignment(adj1, adj2)
    mp = cat(3, adj1, adj2);
    min_mat = min(mp,[],3);
    align_score = (2*sum(sum(min_mat)))/(sum(sum(sum(mp))));
end
```

In [None]:
def align_score(G1,G2):
    """Alignment score between two weighted networks of identical shape.

    Port of the MATLAB ``lp_alignment``: twice the summed element-wise minimum
    divided by the total weight of both networks. For non-negative weights the
    result is 1 for identical networks and 0 for networks with no overlap.

    Raises
    ------
    ValueError
        If G1 and G2 differ in shape. Without this check, numpy broadcasting
        would silently compare mismatched matrices.
    """
    G1 = np.asarray(G1)
    G2 = np.asarray(G2)
    if G1.shape != G2.shape:
        raise ValueError(
            "G1 and G2 must have the same shape, got {} and {}".format(G1.shape, G2.shape)
        )
    overlap = np.sum(np.minimum(G1, G2))
    return (2 * overlap) / (np.sum(G1) + np.sum(G2))

In [None]:
def temporal_align_score_mp(G):
    """Multiprocessing worker: score one pair of networks.

    G is indexable as ``(network_1, network_2, circ_distance)``. Returns
    ``[circ_distance, align_score(network_1, network_2)]`` so results can be
    grouped by circular distance afterwards.
    """
    first_net, second_net, circ_dist = G[0], G[1], G[2]
    score = align_score(first_net, second_net)
    return [circ_dist, score]
    # null_alignment_scores = []

    # min_trial_length = min([len(z[0]) for y in null_FNs[0] for x in y for z in x])

    # all_trials = [[dir,trial] for dir in range(0,len(null_FNs[0]))
    #                         for trial in range(0,len(null_FNs[0][dir][0]))]

    # FN_pairs = [y for y in itertools.combinations(all_trials,2)]

    # cdist = [circdistance(8,y[0][0],y[1][0]) for y in FN_pairs]

    # for t in range(0,min_trial_length):
    #     status("getting alignment scores for time: {}s".format(bin2time(t)))       
    #     null_alignment_scores.append([])
    #     [null_alignment_scores[t].append([]) for x in np.unique(cdist)]

    #     for c, pair in zip(cdist,FN_pairs):
    #         score = align_score(null_FNs[0][pair[0][0]][0][pair[0][1]][0][time],null_FNs[0][pair[1][0]][0][pair[1][1]][0][time])
    #         null_alignment_scores[t][c].append(score)

    # return null_alignment_scores

In [None]:
def meanFRoverSims(mp_data):
    """Sliding-window mean firing rate of one inhomogeneous-Poisson simulation.

    Intended as a multiprocessing worker.

    Parameters
    ----------
    mp_data : indexable
        ``mp_data[0]`` is the rate passed to elephant's
        ``inhomogeneous_poisson_process``. Presumably a neo AnalogSignal;
        confirm against callers.
        ``mp_data[1]`` is the trial length in 10-ms bins.

    Returns
    -------
    list of mean firing rates, one per 200-ms window stepped by 10 ms.
    A window where elephant raises yields 0.
    """
    data_fr = mp_data[0]
    trial_shape = mp_data[1]

    binwin = 0.01
    mov_window = 0.2  # 200 ms
    mov_10ms = int(0.01 / binwin)  # move the mov_window this many bins
    windowBin = int(mov_window / binwin)

    poiss_spiketrain = el.spike_train_generation.inhomogeneous_poisson_process(data_fr)

    time_fr = []

    # previously referenced an undefined name `trialShape` (NameError)
    for j in range(0, trial_shape - windowBin, mov_10ms):
        try:
            fr = el.statistics.mean_firing_rate(
                poiss_spiketrain, j * binwin * pq.s, (j + windowBin) * binwin * pq.s
            ).magnitude
        except Exception:
            # best-effort: a window elephant cannot evaluate counts as 0 Hz.
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.
            fr = 0

        time_fr.append(fr)

    return time_fr

In [1]:
def circdistance(len_my_list, idx_1, idx_2):
    """Shortest distance between two indices on a ring of ``len_my_list`` positions."""
    forward = (idx_1 - idx_2) % len_my_list
    backward = (idx_2 - idx_1) % len_my_list
    return forward if forward <= backward else backward

In [None]:
def binnedPos(xPos,xTime,yPos,yTime,binwin):
    """Resample x and y position traces onto one evenly spaced time grid.

    Each trace's timestamps are shifted so that its earliest timestamp is 0.
    The grid is ``np.arange(0, later_end, binwin)``, where ``later_end`` is the
    larger of the two shifted end times. Positions are linearly interpolated.
    Grid points past the end of the shorter trace are NaN; previously
    ``interp1d`` raised ValueError whenever the traces had different durations.

    Returns
    -------
    xnew, ynew, newTime : ndarrays of equal length
    """
    xtime = np.asarray(xTime) - np.min(xTime)
    fx = interpolate.interp1d(xtime, xPos, bounds_error=False, fill_value=np.nan)

    ytime = np.asarray(yTime) - np.min(yTime)
    fy = interpolate.interp1d(ytime, yPos, bounds_error=False, fill_value=np.nan)

    newTime = np.arange(0, max(np.max(xtime), np.max(ytime)), binwin)
    xnew = fx(newTime)
    ynew = fy(newTime)

    return xnew, ynew, newTime

In [9]:
def time2bin(time,binwin=0.01,buffer=0.45,lastBin=False, window=0.2):
    """Convert a time (s, relative to the event) to a bin index.

    Parameters
    ----------
    time : float
    binwin : bin width (s)
    buffer : time before the event covered by bin 0 (s)
    lastBin : if True, return the first bin of the integration window that ends
        at ``time``, i.e. shift back by ``window``
    window : integration window length (s)

    The quotient is truncated toward zero with ``int``. Quotients within 1e-9
    of an integer are snapped first: float division such as
    ``0.29 / 0.01 == 28.999999999999996`` would otherwise put an exact bin
    boundary into the previous bin.
    """
    # the default is to get the first bin of integration, but sometimes we want the last bin!
    if lastBin:
        position = (time + buffer - window) / binwin
    else:
        position = (time + buffer) / binwin

    nearest = round(position)
    if abs(position - nearest) < 1e-9:
        position = nearest

    return int(position)

In [23]:
def bin2time(bin_idx,binwin=0.01,buffer=0.45,lastBin=False, window=0.2):
    """Convert a bin index back to a time (s, relative to the event).

    Bin 0 starts ``buffer`` seconds before the event. With ``lastBin`` the
    result is shifted forward by ``window``, giving the end of the integration
    window that starts at this bin (the inverse of time2bin's lastBin shift).
    """
    t = bin_idx * binwin - buffer
    if lastBin:
        t = t + window
    return t

In [None]:
def measure_percent_hits(neighbors_by_labels):
    """Fraction of each row's neighbors that share the row's label.

    Parameters
    ----------
    neighbors_by_labels : iterable of sequences
        Each row is ``[label, neighbor_label_1, neighbor_label_2, ...]``.

    Returns
    -------
    list of float, one per row. A row with no neighbors yields NaN, because
    the hit rate is undefined (previously ZeroDivisionError).
    """
    percent_hits = []
    for row in neighbors_by_labels:
        label, neighbors = row[0], row[1:]
        if len(neighbors) == 0:
            percent_hits.append(float('nan'))
            continue
        hit_count = sum(1 for neighbor in neighbors if neighbor == label)
        percent_hits.append(1.0 * hit_count / len(neighbors))

    return percent_hits

In [None]:
def flattenNetwork(FN):
    """Flatten a network matrix in column-major ('F') order.

    For an (n_rows, n_cols) matrix, element k of the result is
    ``FN[k % n_rows, k // n_rows]``.
    """
    return FN.flatten(order='F')

In [None]:
# def thresholdNetwork(FN,percentile,type ='network', returnDim = True, binary=False):
#     if FN.ndim == 2 :
#         FN_flattened = flattenNetwork(FN)
#     elif FN.ndim == 1:
#         FN_flattened = FN
#     else:
#         "Dimensions of FN not accepted"
#         return None 
#     if type=='network':
#         if binary:
#             thresholded_FN = np.where(FN_flattened > np.percentile(FN_flattened,percentile), 1, 0)
#         else:
#             thresholded_FN = np.where(FN_flattened > np.percentile(FN_flattened,percentile), FN_flattened, 0)
    
    

#     if returnDim:
#         thresholded_FN = np.reshape(thresholded_FN,FN.shape)

#     return thresholded_FN

#threshold Network beta update: has type argument
def thresholdNetwork(FN,percentile,type ='network', direction = None, returnDim = True, binary=False):
    """Keep only the strongest weights of a network.

    Parameters
    ----------
    FN : array-like, square 2-D matrix or its flattened 1-D form
    percentile : float in [0, 100]; entries strictly above this percentile are kept
    type : 'network' or 'node'
        'network' uses one cutoff over all weights. 'node' uses one cutoff per node.
    direction : 'out', 'in' or 'undirected' (used when type == 'node')
        'out' and 'undirected' threshold each row (source node).
        'in' thresholds each column (target node).
    returnDim : bool
        If True, return an array with FN's shape, where each output entry is
        the thresholded version of the corresponding input entry.
        If False: 'network' returns the column-major flattened result;
        'node' returns the list of per-row ('out') or per-column ('in') arrays.
    binary : bool
        Kept entries become 1 instead of their weight.

    Returns None (after printing a message) for invalid dimensions, type or direction.

    Fixes relative to the previous version:
    - 2-D 'network' results and 'in' node results are no longer transposed.
    - Error messages are printed instead of being no-op strings.
    - An invalid ``type`` no longer crashes with UnboundLocalError.
    """
    FN = np.asarray(FN)

    def _keep_above(values, cutoff):
        # keep entries strictly above the cutoff, as 1s or as their own value
        if binary:
            return np.where(values > cutoff, 1, 0)
        return np.where(values > cutoff, values, 0)

    if type == 'network':
        # NETWORK-wise thresholding: top p percentile of all weights in the network
        if FN.ndim == 2:
            FN_flattened = FN.flatten('F')  # same column-major order as flattenNetwork
        elif FN.ndim == 1:
            FN_flattened = FN
        else:
            print("Dimensions of FN not accepted")
            return None

        thresholded_FN = _keep_above(FN_flattened, np.percentile(FN_flattened, percentile))
        if returnDim:
            # undo the column-major flattening; a C-order reshape would transpose the matrix
            thresholded_FN = np.reshape(thresholded_FN, FN.shape, order='F')
        return thresholded_FN

    if type == 'node':
        # NODE-wise thresholding over each node's out- (rows) or in- (columns) connections
        if FN.ndim == 2:
            square_FN = FN
        elif FN.ndim == 1:
            side = int(np.sqrt(FN.shape[0]))
            square_FN = np.reshape(FN, (side, side))  # same as flat2squareNetwork
        else:
            print("Dimensions of FN not accepted")
            return None

        if direction == 'out' or direction == 'undirected':
            # undirected is treated as 'out'; for a symmetric matrix 'in' gives the same result
            per_node = [_keep_above(row, np.percentile(row, percentile)) for row in square_FN]
            square_result = np.array(per_node)
        elif direction == 'in':
            per_node = [
                _keep_above(square_FN[:, c], np.percentile(square_FN[:, c], percentile))
                for c in range(len(square_FN))
            ]
            # per_node[c] is column c, so transpose back to (source, target) layout
            square_result = np.array(per_node).T
        else:
            print("The 'direction' argument is invalid. Please specify the direction you want to threshold ('in'/'out') or if undirected ('undirected')")
            return None

        if returnDim:
            return np.reshape(square_result, FN.shape)
        return per_node

    print("The 'type' argument is invalid. Please use 'network' or 'node'.")
    return None

In [None]:
def loadPickle(filepath):
    """Load and return the object pickled in *filepath*.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Only use this
    on files you created or otherwise trust.
    """
    with open(filepath, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
def savePickle(filepath,save_obj):
    """Pickle *save_obj* to *filepath*, overwriting any existing file.

    Uses a context manager so the file is closed even if ``pickle.dump``
    raises (e.g. on an unpicklable object); previously the handle leaked.
    """
    with open(filepath, "wb") as f:
        pickle.dump(save_obj, f)

In [1]:
def getDirection(x,y):
    """Direction of a reach from its first to its last point, in degrees in [0, 360).

    x and y are sequences of positions; only the end points are used.
    0 degrees points along +x, and angles increase counter-clockwise toward +y.
    """
    dx = x[-1] - x[0]
    dy = y[-1] - y[0]
    angle = np.arctan2(dy, dx) * 180 / np.pi
    return (angle + 360) % 360


In [None]:
def tripletCC(W,weighted=True):
    """Directed clustering coefficients split by triangle motif (Fagiolo-style).

    W is a square weight/adjacency matrix read with rows as sources and
    columns as targets (``W[i, j]`` = edge i -> j).

    For each motif the per-node coefficient is (actual count) / (possible count):
        cycle     : diag(A A A)    / (d_out*d_in - d_bi)
        middleman : diag(A A^T A)  / (d_out*d_in - d_bi)
        fan-in    : diag(A^T A A)  / (d_in*(d_in - 1))
        fan-out   : diag(A A A^T)  / (d_out*(d_out - 1))
    d_out and d_in count nonzero entries per row and per column. d_bi counts
    reciprocated pairs (i -> j and j -> i). With ``weighted`` the numerators
    use A = cbrt(W); weights are not rescaled, so pass pre-normalized weights
    if you want them in [0, 1].
    Undefined ratios (0/0, x/0) are reported as 0.

    Fixes relative to the previous version:
    - ``weighted`` was overwritten with True.
    - d_bi counted nodes sharing an out-neighbour, not reciprocated edges.
    - The degree axes were swapped relative to the fan-in/fan-out numerators.

    Returns
    -------
    dict motif -> {'meanCC': float, 'byNode': ndarray}
    """
    W = np.asarray(W)
    motifs = ['cycle', 'middleman', 'fan-in', 'fan-out']

    B = (W != 0).astype(int)            # binary adjacency
    d_o = np.count_nonzero(W, axis=1)   # out-degree: nonzero entries in row i
    d_i = np.count_nonzero(W, axis=0)   # in-degree: nonzero entries in column i
    d_bi = np.diagonal(B @ B)           # number of j with i -> j and j -> i

    A = np.cbrt(W) if weighted else W
    At = A.transpose()

    numerator_factors = {
        'cycle': [A, A, A],
        'middleman': [A, At, A],
        'fan-in': [At, A, A],
        'fan-out': [A, A, At],
    }
    possible_counts = {
        'cycle': d_o * d_i - d_bi,
        'middleman': d_o * d_i - d_bi,
        'fan-in': d_i * (d_i - 1),
        'fan-out': d_o * (d_o - 1),
    }

    CC_by_motif = {}
    for motif in motifs:
        actual_motif_count = np.diagonal(np.linalg.multi_dot(numerator_factors[motif]))
        with np.errstate(divide='ignore', invalid='ignore'):
            CC = actual_motif_count / possible_counts[motif]
        CC = np.nan_to_num(CC, nan=0, posinf=0, neginf=0)
        CC_by_motif[motif] = {'meanCC': np.mean(CC), 'byNode': CC}

    return CC_by_motif

In [None]:
def LaplacianFlatten(FN):
    """Directed Laplacian of a flattened network, returned flattened.

    FN is reshaped to square with flat2squareNetwork (C order), turned into
    a DiGraph, and its ``directed_laplacian_matrix`` is flattened with
    flattenNetwork (column-major).
    NOTE(review): the C-order reshape and the F-order flatten differ, so an FN
    produced by flattenNetwork is read transposed. Confirm which orientation
    is intended.
    """
    FN_reshaped = flat2squareNetwork(FN)
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is its replacement
    G = nx.from_numpy_array(FN_reshaped, create_using=nx.DiGraph())
    G_L = np.asarray(nx.linalg.directed_laplacian_matrix(G))
    G_L_flattened = flattenNetwork(G_L)
    return G_L_flattened

In [None]:
def flat2squareNetwork(FN):
    """Reshape a flattened n*n network back to an (n, n) matrix (row-major / C order).

    NOTE(review): flattenNetwork flattens in column-major ('F') order, so
    ``flat2squareNetwork(flattenNetwork(M))`` equals ``M.T``, not ``M``.
    Confirm callers expect this.
    """
    side = int(np.sqrt(FN.shape[0]))
    return np.reshape(FN, (side, side))

In [1]:
def jitterSpikes(x,jitter, precision=0.001):
    """Randomly jitter each spike time and return the sorted result.

    Each element of x (spike times, assumed in seconds) gets an independent
    offset drawn uniformly from the symmetric grid ``k * precision`` with
    ``|k * precision| <= jitter``.

    The previous grid ``np.arange(-jitter, jitter, precision)`` excluded
    +jitter, which biased offsets negative (mean -precision/2). It also raised
    when jitter == 0; that case now returns the sorted input.
    """
    n_steps = int(np.floor(jitter / precision + 1e-9))  # tolerance guards e.g. 0.003/0.001 == 2.9999...
    offsets = np.arange(-n_steps, n_steps + 1) * precision
    x = np.asarray(x, dtype=float)
    # one offset per element along the first axis (per spike for 1-D input)
    draws = np.random.choice(offsets, size=len(x)).reshape((-1,) + (1,) * (x.ndim - 1))
    return np.sort(x + draws)

In [None]:
def poissSpkTrain_data(T,dt,r_data):
    """Simulate spikes from a per-step firing probability taken from data.

    ``r_data[k]`` is used directly as the probability of a spike at step k,
    with no multiplication by dt. One ``random.random()`` draw is made per
    step, before ``r_data[k]`` is read. The first ISI is measured from t=0.

    Returns spkTimes, isi, spkTrain (0/1 per step), t_steps (``np.arange(0, T, dt)``).
    """
    t_steps = np.arange(0, T, dt)

    spkTimes, isi, spkTrain = [], [], []
    last_spike = 0
    for k, t in enumerate(t_steps):
        fired = random.random() <= r_data[k]
        if not fired:
            spkTrain.append(0)
            continue
        spkTimes.append(t)
        isi.append(t - last_spike)
        last_spike = t
        spkTrain.append(1)

    return spkTimes, isi, spkTrain, t_steps

In [None]:
def lscov(A, B, w=None):
    """Least-squares solution in presence of known covariance

    :math:`A \\cdot x = B`, that is, :math:`x` minimizes
    :math:`(B - A \\cdot x)^T \\cdot \\text{diag}(w) \\cdot (B - A \\cdot x)`.
    The weights :math:`w` typically contain either counts or inverse
    variances.

    Parameters
    ----------
    A: matrix or 2d ndarray
        input matrix, shape (m, n)
    B: vector or 1d ndarray
        input vector, length m (a 2-D (m, k) right-hand side also works)
    w: array-like of length m, optional
        non-negative per-row weights; None means unweighted least squares

    Returns
    -------
    1-D ndarray: the flattened solution x

    Notes
    --------
    https://de.mathworks.com/help/matlab/ref/lscov.html
    Rows are scaled by sqrt(w) directly instead of multiplying by a dense
    m x m diagonal matrix, which avoids O(m^2) memory.
    """
    # https://stackoverflow.com/questions/27128688/how-to-use-least-squares-with-weight-matrix-in-python
    # https://de.mathworks.com/help/matlab/ref/lscov.html
    if w is None:
        Aw = A.copy()
        Bw = B.copy()
    else:
        sqrt_w = np.sqrt(np.asarray(w, dtype=float).flatten())
        # np.multiply keeps element-wise semantics even if A is an np.matrix
        Aw = np.multiply(A, sqrt_w[:, np.newaxis])
        Bw = np.multiply(B, sqrt_w if np.ndim(B) == 1 else sqrt_w[:, np.newaxis])

    # set rcond=1e-10 to prevent diverging odd indices in x
    # (problem specific to ggf/stress computation)
    x = np.linalg.lstsq(Aw, Bw, rcond=1e-10)[0]
    return np.array(x).flatten()

In [None]:
def weighted_reciprocity(FN):
    """Fraction of total edge weight that is reciprocated.

    For each ordered pair (i, j) the reciprocated weight is
    ``min(FN[i, j], FN[j, i])``; the result is its total divided by the total
    weight of FN. FN must be an ndarray (``.transpose`` is used).
    """
    reciprocated = np.minimum(FN, FN.transpose())
    total_reciprocated = np.sum(reciprocated)
    total_weight = np.sum(FN)
    return total_reciprocated / total_weight

In [None]:
def normalizeNetworkMetric(data,null):
    """Rescale a metric against its null value: 0 at ``null``, 1 at 1.

    Computes ``(data - null) / (1 - null)``; works element-wise on arrays.
    """
    shifted = data - null
    headroom = 1 - null
    return shifted / headroom

In [None]:
# Resample FN
def subsampleFN(FN,nodes = None, num_nodes=None):
    """Extract the sub-network induced by a subset of nodes.

    Parameters
    ----------
    FN : square (N, N) array
    nodes : sequence of node indices to keep (used when num_nodes is None)
    num_nodes : int, optional
        If given, this many distinct nodes are drawn at random (``np.random``)
        and sorted. Takes priority over ``nodes``.

    Returns
    -------
    (resampled_FN, resampled_nodes), or None after printing an error.
    Errors: non-square FN, neither argument given, num_nodes > N, or a node
    index >= N. The last check previously used ``and`` and ``<``, so
    out-of-range indices slipped through and raised IndexError.
    """
    # check that FN is square
    if FN.shape[0] != FN.shape[1]:
        print("Error: Dimensions of FN not accepted. Must be square (NxN)!")
        return None

    n_total = FN.shape[0]

    if nodes is None and num_nodes is None:
        print("Error: you must either put a number of nodes (num_nodes) to subsample or an array with node IDs (nodes)!")
        return None

    if num_nodes is not None:
        # check that num_nodes <= original number of nodes
        if n_total < num_nodes:
            print("Error: num_nodes must be less than the number of the original nodes in your network!")
            return None
        resampled_nodes = np.sort(np.random.choice(np.arange(n_total), num_nodes, replace=False))
    else:
        if np.max(nodes) >= n_total:
            print("Error: nodes must be less than the number of the original nodes in your network!")
            return None
        resampled_nodes = nodes

    # rows and columns restricted to the chosen nodes, in the given order
    resampled_FN = np.asarray(FN)[np.ix_(resampled_nodes, resampled_nodes)]

    return resampled_FN, resampled_nodes

In [None]:
# modified from: https://stackoverflow.com/questions/61760669/numpy-1d-array-find-indices-of-boundaries-of-subsequences-of-the-same-number
def first_and_last_seq(x, n):
    """Start and end indices (inclusive) of every run of consecutive values equal to n.

    Returns ``[starts, ends]`` as integer arrays indexing into x.
    """
    # Pad both ends with a value that cannot equal n, so runs touching an edge still get boundaries.
    is_n = (np.r_[n - 1, x, n - 1] == n).astype(int)
    steps = np.diff(is_n)
    # steps[k] compares padded positions k and k+1; padded k+1 corresponds to x[k]
    starts = np.flatnonzero(steps == 1)
    ends = np.flatnonzero(steps == -1) - 1
    return [starts, ends]

In [None]:
# from https://stackoverflow.com/a/35094823
def autoscale_y(ax,margin=0.1):
    """This function rescales the y-axis based on the data that is visible given the current xlim of the axis.
    ax -- a matplotlib axes object
    margin -- the fraction of the total height of the y-data to pad the upper and lower ylims

    Lines with no points strictly inside the current x-limits are skipped.
    If no line has visible points, the y-limits are left unchanged;
    previously this raised on the empty max/min or on set_ylim(inf, -inf).
    Works with inverted x-axes and with line data stored as plain lists.
    """
    lo, hi = sorted(ax.get_xlim())  # sorted so an inverted x-axis still gives lo < hi

    def get_bottom_top(line):
        # padded (bottom, top) of the visible part of one line, or None if nothing is visible
        xd = np.asarray(line.get_xdata())
        yd = np.asarray(line.get_ydata())
        y_displayed = yd[(xd > lo) & (xd < hi)]
        if y_displayed.size == 0:
            return None
        y_min, y_max = np.min(y_displayed), np.max(y_displayed)
        h = y_max - y_min
        return y_min - margin * h, y_max + margin * h

    bot, top = np.inf, -np.inf
    for line in ax.get_lines():
        bounds = get_bottom_top(line)
        if bounds is None:
            continue
        new_bot, new_top = bounds
        if new_bot < bot:
            bot = new_bot
        if new_top > top:
            top = new_top

    if np.isfinite(bot) and np.isfinite(top):
        ax.set_ylim(bot, top)