# In [5]:
import numpy as np
import pandas as pd

def run_cascade_single_population(adj_matrix, thr, seed_node_index, max_iter=30):
    """Simulate a threshold cascade spreading from a single seed node.

    At each step, a node becomes infected when the fraction of its total
    incoming weight (its column sum) that comes from currently infected nodes
    (rows) is strictly greater than ``thr``. Infected nodes stay infected.
    Nodes whose total incoming weight is zero can never become infected.

    Parameters
    ----------
    adj_matrix : numpy.ndarray
        Square (n x n) weighted adjacency matrix. Row ``i`` holds the weights
        node ``i`` sends, column ``j`` the weights node ``j`` receives.
    thr : float
        Activation threshold on the infected-input fraction.
    seed_node_index : int
        Index of the initially infected node.
    max_iter : int, optional
        The simulation gives up once more than ``max_iter`` propagation steps
        have run (i.e. after at most ``max_iter + 1`` steps). Defaults to 30.

    Returns
    -------
    list of list of int
        The first entry holds the seed's index. Each following entry lists the
        nodes whose infected-input fraction exceeded ``thr`` at that step.
        The seed only appears in later entries if its own infected-input
        fraction exceeds ``thr``. The last entry can therefore miss the seed
        even when every node is infected.
    """
    n_nodes = adj_matrix.shape[0]
    infected_nodes = np.zeros(n_nodes)
    # Total incoming weight per node (column sums); invariant across steps.
    input_to_node = np.sum(adj_matrix, axis=0)
    # Nodes with no incoming weight can never cross the threshold. Exclude them
    # explicitly instead of relying on 0/0 -> nan (and its RuntimeWarning).
    has_input = input_to_node != 0

    infected_nodes[seed_node_index] = 1
    list_of_infected_nodes_per_iter = [np.where(infected_nodes == 1)[0].tolist()]

    counter = 0
    while int(np.sum(infected_nodes)) < n_nodes:
        if counter > max_iter:
            # Realistically we converge in ~10 steps (see Misic et al., Neuron);
            # hitting the cap means the threshold is too high to spread further.
            break

        indices_of_infected_nodes = np.where(infected_nodes == 1)[0]
        # Weight arriving at each node from infected sources. Sum only the
        # infected rows instead of copying and masking the whole matrix.
        infected_inputs = adj_matrix[indices_of_infected_nodes, :].sum(axis=0)
        with np.errstate(divide="ignore", invalid="ignore"):
            fraction = infected_inputs / input_to_node
            infected_nodes_indices = np.where(has_input & (fraction > thr))[0]

        list_of_infected_nodes_per_iter.append(infected_nodes_indices.tolist())
        infected_nodes[infected_nodes_indices] = 1
        counter += 1

    return list_of_infected_nodes_per_iter
        
def _cascade_reaches_all(adj_matrix, thr, seed_node_index):
    """Return True when a cascade from ``seed_node_index`` at ``thr`` infects every node."""
    list_of_infected_nodes_per_iter = run_cascade_single_population(
        adj_matrix, thr, seed_node_index)
    # The per-iteration lists are the seed (first entry) plus the nodes that
    # crossed the threshold at each step, so their union is the infected set.
    # The last list alone can omit the seed (e.g. when the seed infects every
    # other node in a single step), which would be misread as a failure.
    infected = set().union(*list_of_infected_nodes_per_iter)
    return len(infected) == adj_matrix.shape[0]


def _bracket_threshold(adj_matrix, starting_thr, seed_node_index, max_steps):
    """Return ``(last_success, first_failure)`` thresholds for one seed node.

    Starting from ``starting_thr``, the threshold is doubled while the cascade
    from ``seed_node_index`` still reaches every node. If the starting
    threshold already fails, it is divided by 100, repeatedly if needed, until
    a threshold succeeds.

    Raises
    ------
    RuntimeError
        If no success/failure pair is found within ``max_steps`` cascade runs.
    """
    thr = starting_thr
    last_success = None
    for _ in range(max_steps):
        if _cascade_reaches_all(adj_matrix, thr, seed_node_index):
            last_success = thr
            thr = thr * 2
        elif last_success is None:
            # Even the current threshold is too high: shrink it and retry.
            thr = thr / 100.
        else:
            return last_success, thr
    raise RuntimeError(
        f"could not bracket the threshold for seed node {seed_node_index} "
        f"within {max_steps} steps (starting_thr={starting_thr})")


def find_thr(adj_matrix, starting_thr, max_steps=1000, n_refine=100):
    """Find the highest threshold at which the bottleneck seed still infects every node.

    For every node, a coarse doubling search brackets the threshold at which
    a cascade seeded there stops reaching all nodes. The bracket is the last
    threshold that succeeded and the first that failed. The bottleneck node
    has the lowest failing threshold. Its bracket is then scanned on an
    ``n_refine``-point grid, assuming full infection is monotone in the
    threshold (NOTE(review): not checked here).

    Parameters
    ----------
    adj_matrix : numpy.ndarray
        Square weighted adjacency matrix passed to
        ``run_cascade_single_population``.
    starting_thr : float
        First threshold tried for every node.
    max_steps : int, optional
        Maximum cascade runs per node in the coarse search (default 1000).
    n_refine : int, optional
        Number of grid points in the fine search (default 100).

    Returns
    -------
    float
        The largest grid threshold before the bottleneck node's cascade first
        fails to reach every node. If that happens at the first grid point,
        the lower end of its bracket is returned.

    Raises
    ------
    ValueError
        If ``adj_matrix`` has no nodes.
    RuntimeError
        If a node's threshold cannot be bracketed within ``max_steps``.
    """
    n_nodes = adj_matrix.shape[0]
    if n_nodes == 0:
        raise ValueError("adj_matrix has no nodes")

    # Coarse search: per node, (last threshold reaching every node, first
    # threshold that does not).
    brackets = [
        _bracket_threshold(adj_matrix, starting_thr, seed_node_index, max_steps)
        for seed_node_index in range(n_nodes)
    ]
    first_failures = np.asarray([first_failure for _, first_failure in brackets])

    # Bottleneck: the seed whose cascade stops reaching everyone earliest
    # (lowest index on ties).
    bottleneck_node = int(np.argmin(first_failures))
    low, high = brackets[bottleneck_node]

    # Fine search inside the bottleneck's bracket: keep the last grid value
    # before the cascade first fails to reach every node.
    thrs = np.linspace(low, high, n_refine, endpoint=True)
    best_thr = thrs[0]
    for final_thr in thrs:
        if not _cascade_reaches_all(adj_matrix, final_thr, bottleneck_node):
            break
        best_thr = final_thr
    return best_thr
            
def main(csv_path='/home/gabridele/Desktop/connectome_sub-100206.csv',
         starting_thr=0.0015, random_seed=None):
    """Calibrate a cascade threshold on a connectome and pit two random seeds against each other.

    Parameters
    ----------
    csv_path : str, optional
        Header-less CSV holding the adjacency matrix.
    starting_thr : float, optional
        Initial threshold handed to ``find_thr`` (default 0.0015).
    random_seed : int or None, optional
        Seed for the generator that picks the two competing nodes.
        ``None`` gives a different pair on each run.

    Prints which seed infected more nodes (or ``tie``), followed by both
    per-iteration infection lists.
    """
    adj_matrix = pd.read_csv(csv_path, header=None).to_numpy().astype(float)

    # Drop nodes whose column sum is zero (no incoming weight) together with
    # their row. In a non-symmetric matrix that removal can leave another node
    # without input, so repeat until none remain.
    adj_matrix_clean = adj_matrix
    zero_input_nodes = np.where(np.sum(adj_matrix_clean, 0) == 0)[0]
    while zero_input_nodes.size:
        adj_matrix_clean = np.delete(adj_matrix_clean, zero_input_nodes, axis=0)
        adj_matrix_clean = np.delete(adj_matrix_clean, zero_input_nodes, axis=1)
        zero_input_nodes = np.where(np.sum(adj_matrix_clean, 0) == 0)[0]

    thr = find_thr(adj_matrix_clean, starting_thr)

    # Choose two *distinct* random seed nodes; a node cannot compete with itself.
    rng = np.random.default_rng(random_seed)
    seed_node_index_1, seed_node_index_2 = (
        int(i) for i in rng.choice(adj_matrix_clean.shape[0], size=2, replace=False))

    list_of_infected_nodes_per_iter_1 = run_cascade_single_population(adj_matrix_clean, thr, seed_node_index_1)
    list_of_infected_nodes_per_iter_2 = run_cascade_single_population(adj_matrix_clean, thr, seed_node_index_2)

    # Score each seed by the total number of nodes it infected (union over all
    # steps). The last per-step list alone can leave out the seed itself.
    n_infected_1 = len(set().union(*list_of_infected_nodes_per_iter_1))
    n_infected_2 = len(set().union(*list_of_infected_nodes_per_iter_2))

    if n_infected_1 > n_infected_2:
        print(f"Seed node {seed_node_index_1} wins the competition")
    elif n_infected_1 < n_infected_2:
        print(f"Seed node {seed_node_index_2} wins the competition")
    else:
        print("tie")
    print('list seed node A:', list_of_infected_nodes_per_iter_1)
    print('list seed node B:', list_of_infected_nodes_per_iter_2)
# Entry point: run the full threshold-calibration and seed-competition analysis.
if __name__ == "__main__":
    main()

tie
list seed node A: [[398], [4, 5, 7, 10, 25, 28, 32, 33, 34, 35, 36, 38, 40, 41, 44, 45, 46, 47, 48, 51, 52, 55, 58, 64, 66, 67, 68, 69, 70, 71, 72, 81, 84, 86, 107, 114, 116, 117, 118, 119, 120, 124, 125, 136, 137, 142, 144, 145, 153, 154, 160, 164, 165, 166, 171, 178, 179, 181, 221, 222, 223, 224, 225, 226, 227, 228, 230, 231, 233, 243, 244, 250, 252, 253, 254, 257, 260, 262, 263, 264, 265, 266, 268, 269, 270, 285, 286, 291, 294, 295, 296, 298, 300, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 344, 346, 347, 349, 350, 351, 352, 354, 357, 358, 361, 364, 365, 367, 371, 372, 383, 386, 389, 392, 394, 397, 399, 426, 447, 501, 502, 503, 507, 508, 511, 513, 515, 518, 521, 524, 525, 531, 559, 560, 561, 568, 569, 570, 574, 580, 581, 582, 583, 602, 635, 642, 643, 645, 646, 647, 651, 652, 653, 654, 664, 669, 671, 674, 678, 679, 681, 683, 684, 685, 690, 693, 697, 706, 746, 747, 748, 753, 755, 756, 757, 758, 759, 7