In [None]:
import nibabel as nib
import pandas as pd
import numpy as np
import community as louvain
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import networkx as nx
from itertools import combinations

# Root folder for every input file; change this single line when the data moves.
data_dir = '/Users/SEAlab/Documents/PPM/Data'

# Connectivity file: opened once and reused for both the row/column names and
# the matrix values (it used to be read from disk twice).
pconn_path = data_dir + '/1196_All_Timepoints/Merged_rest_Atlas_rescaled.GORDparcel.32k_fs_LR.pconn.nii'
pconn_img = nib.load(pconn_path)
labels = list(pconn_img.header.get_axis(0).name)
conn_map = pconn_img.get_fdata()
# Square, labelled view of the matrix: the same names index rows and columns.
conn_df = pd.DataFrame(conn_map, columns=labels, index=labels)

# Atlas file handed to label_cifti_with_nets() when writing network labels.
cifti_parcels = data_dir + '/Gordon333_FreesurferSubcortical.32k_fs_LR.dlabel.nii'

# NOTE(review): the pipeline cell re-assigns `iterations`, so this value is
# overridden there - confirm which one is intended.
iterations = 10
density_range = [0.001, 0.1, 0.2, 0.3, 0.4, 0.05] # top fraction of connections to keep (0.05 = strongest 5%)

# Green-Armytage, Paul. (2010). A Colour Alphabet and the Limits of Colour Coding. Color: Design & Creativity. 5. 1-23.
# Colour name -> (R, G, B) triple on the 0-255 scale; insertion order is the
# order in which colours are handed out to networks.
kelly_colors = {
    'vivid_yellow': (255, 179, 0),
    'strong_purple': (128, 62, 117),
    'vivid_orange': (255, 104, 0),
    'very_light_blue': (166, 189, 215),
    'vivid_red': (193, 0, 32),
    'grayish_yellow': (206, 162, 98),
    'medium_gray': (129, 112, 102),
    'vivid_green': (0, 125, 52),
    'strong_purplish_pink': (246, 118, 142),
    'strong_blue': (0, 83, 138),
    'strong_yellowish_pink': (255, 122, 92),
    'strong_violet': (83, 55, 122),
    'vivid_orange_yellow': (255, 142, 0),
    'strong_purplish_red': (179, 40, 81),
    'vivid_greenish_yellow': (244, 200, 0),
    'strong_reddish_brown': (127, 24, 13),
    'vivid_yellowish_green': (147, 170, 0),
    'deep_yellowish_brown': (89, 51, 21),
    'vivid_reddish_orange': (241, 58, 19),
    'dark_olive_green': (35, 44, 22),
}

## Community Detection

In [None]:
# functions

def threshold_map(conn_df, thresh):
    """Keep only the strongest lower-triangle entries of a square matrix.

    The diagonal and the upper triangle are discarded, the ``1 - thresh``
    quantile of the remaining finite values is used as the cut-off, and every
    entry below the cut-off is set to 0.

    Parameters
    ----------
    conn_df : pandas.DataFrame
        Square matrix to threshold. It is left untouched: masking is done on a
        new frame so the caller's matrix can be thresholded again at another
        density.
    thresh : float
        Fraction of lower-triangle values to keep (``0.05`` keeps the values at
        or above the 95th percentile).

    Returns
    -------
    pandas.DataFrame
        Frame with the same shape and labels as ``conn_df``; surviving
        lower-triangle values are kept, every other cell is 0.
    """
    keep_quantile = 1 - thresh
    # True on and above the diagonal: those cells are dropped so each pair is
    # represented once (by its lower-triangle cell) and the diagonal is ignored.
    upper_mask = np.triu(np.ones(conn_df.shape)).astype(bool)
    # Without inplace=True, mask() returns a new frame instead of editing conn_df.
    lower_df = conn_df.mask(upper_mask, np.nan)
    lower_values = lower_df.to_numpy()
    quant = np.quantile(lower_values[np.isfinite(lower_values)], q=keep_quantile)
    # NaN >= quant is False, so masked cells become 0 together with the
    # sub-threshold ones.
    thresholded_conn_df = lower_df.where(lower_df >= quant, 0)
    return thresholded_conn_df

def adj_to_edgelist(map_df):
    """List the undirected edges encoded in an adjacency DataFrame.

    Every unordered pair of column labels is visited once, in
    ``itertools.combinations`` order. A pair becomes an edge when the cell in
    either orientation (``[a, b]`` or ``[b, a]``) is strictly positive, so a
    matrix filled in only one triangle is handled as well.

    Parameters
    ----------
    map_df : pandas.DataFrame
        Matrix looked up by label with ``.loc`` in both orientations.

    Returns
    -------
    list of tuple
        ``(label_a, label_b)`` pairs, ordered as the labels appear in
        ``map_df.columns``.
    """
    node_names = map_df.columns
    return [
        (first, second)
        for first, second in combinations(node_names, 2)
        if map_df.loc[first, second] > 0 or map_df.loc[second, first] > 0
    ]

def create_graph(node_list, edge_list):
    """Build an undirected ``networkx.Graph`` from explicit nodes and edges.

    Nodes are added before edges, so nodes that appear in no edge are still
    part of the returned graph.

    Parameters
    ----------
    node_list : iterable
        Node identifiers.
    edge_list : iterable of tuple
        Pairs of node identifiers to connect.

    Returns
    -------
    networkx.Graph
    """
    network = nx.Graph()
    network.add_nodes_from(node_list)
    network.add_edges_from(edge_list)
    return network

def louvain_comm_detection(graph, iterations, agreement=0.8, seed=42):
    """Consensus Louvain partition of ``graph``.

    Louvain is run ``iterations`` times with seeds ``seed, seed + 1, ...``.
    Two nodes are then linked in a consensus graph when they landed in the same
    community in at least ``agreement`` of the runs, and a final Louvain pass
    on that consensus graph (seeded with ``seed``) gives the returned
    partition.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to partition.
    iterations : int
        Number of independent Louvain runs.
    agreement : float, optional
        Minimum fraction of runs in which two nodes must share a community to
        be connected in the consensus graph. Defaults to 0.8.
    seed : int, optional
        Base random seed. Defaults to 42.

    Returns
    -------
    dict
        Node -> community id, as returned by ``best_partition`` on the
        consensus graph.
    """
    nodes = list(graph.nodes)

    # One row per run, one column per node. Each partition is looked up by node
    # so the column order never depends on the ordering of the returned dict.
    # The module is imported as `louvain` at the top of the notebook.
    runs = []
    for run in range(iterations):
        partition = louvain.best_partition(graph, random_state=seed + run)
        runs.append([partition[node] for node in nodes])
    assignments = np.array(runs, dtype=int).reshape(len(runs), len(nodes))

    consensus_graph = nx.Graph()
    consensus_graph.add_nodes_from(nodes)
    for a, b in combinations(range(len(nodes)), 2):
        # Fraction of runs in which both nodes received the same community id.
        same_net = np.mean(assignments[:, a] == assignments[:, b])
        if same_net >= agreement:
            consensus_graph.add_edge(nodes[a], nodes[b])

    final_communities = louvain.best_partition(consensus_graph, random_state=seed)
    return final_communities

def label_adj_with_nets(communities, network_labels, conn_df):
    """Placeholder: not implemented yet.

    The body was never written; the previous stub returned two names that are
    not defined anywhere, so any call failed with a ``NameError``. Until the
    logic exists, calling this function fails explicitly instead.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # TODO: build and return (network_df, net_parc_df) once their contents
    # are specified.
    raise NotImplementedError('label_adj_with_nets has not been implemented yet')

def _rgba_unit(color):
    """Return ``color`` as a 4-tuple of floats in [0, 1] (RGB or RGBA input, 0-255 or 0-1 scale)."""
    channels = [float(c) for c in color]
    if max(channels) > 1.0:
        # 0-255 colour such as the kelly_colors entries.
        channels = [c / 255.0 for c in channels]
    if len(channels) == 3:
        channels.append(1.0)  # no alpha supplied: fully opaque
    return tuple(channels)


def label_cifti_with_nets(cifti_parcels, communities, out_file, colors = kelly_colors, network_names=None):
    """Write a CIFTI label file that paints every parcel with its network.

    The atlas at ``cifti_parcels`` is loaded, each entry of its data is mapped
    from parcel key to the network of that parcel, and the result is saved to
    ``out_file`` with a label table holding one named, coloured entry per
    network.

    Assumes the parcel names in the atlas label table are the keys of
    ``communities`` and that the atlas holds a single label map - TODO confirm
    against the files in use.

    Parameters
    ----------
    cifti_parcels : str
        Path of the atlas label file; axis 0 is used as the label axis and
        axis 1 as the second (brain-model) axis, as read from its header.
    communities : dict
        Parcel name -> network id.
    out_file : str
        Path of the file to write.
    colors : dict, optional
        Colour name -> RGB/RGBA tuple, on either the 0-255 or the 0-1 scale.
        Colours are assigned to networks in dictionary order.
    network_names : list of str, optional
        One name per network, ordered by ascending network id. Defaults to
        ``Network_<id>``.

    Raises
    ------
    ValueError
        If the number of names does not match the number of networks, if there
        are more networks than colours, or if no atlas parcel name is found in
        ``communities``.
    """
    palette = [_rgba_unit(c) for c in colors.values()]
    networks = sorted(set(communities.values()))
    if network_names is None:
        network_names = ['Network_{0}'.format(a) for a in networks]
    if len(network_names) != len(networks):
        raise ValueError('expected {0} network names, got {1}'.format(len(networks), len(network_names)))
    if len(networks) > len(palette):
        raise ValueError('{0} networks but only {1} colours'.format(len(networks), len(palette)))

    atlas = nib.load(cifti_parcels)
    label_ax = atlas.header.get_axis(0)
    bm = atlas.header.get_axis(1)

    # Output keys start at 1 so that key 0 stays free for entries whose parcel
    # has no network; names and colours are matched to networks by position
    # rather than by using the network id as a list index.
    net_key = {net: position + 1 for position, net in enumerate(networks)}

    # The atlas data holds one parcel key per entry; translate each key to the
    # key of the network its parcel belongs to.
    parcel_keys = np.asarray(atlas.get_fdata()).astype(int)
    data = np.zeros(parcel_keys.shape, dtype=np.int32)
    matched = 0
    for key, (parcel_name, _) in label_ax.label[0].items():
        if parcel_name in communities:
            data[parcel_keys == key] = net_key[communities[parcel_name]]
            matched += 1
    if matched == 0:
        raise ValueError('no parcel name in {0} matches a key of `communities`'.format(cifti_parcels))

    # nibabel label tables map key -> (name, (R, G, B, A)) with floats in [0, 1].
    label_table = {0: ('???', (0.0, 0.0, 0.0, 0.0))}
    for position, net in enumerate(networks):
        label_table[net_key[net]] = (network_names[position], palette[position])
    label_ax.label[0] = label_table

    # Axes are passed in the same order they were read from the atlas so the
    # header matches the shape of `data`.
    label_img = nib.cifti2.cifti2.Cifti2Image(data, (label_ax, bm))
    nib.save(label_img, out_file)

In [None]:
# pipeline: threshold -> edge list -> graph -> consensus Louvain -> label file
iterations = 5   # number of Louvain runs (re-assigns the value set in the config cell)
density = 0.05   # fraction of strongest connections to keep

thresholded_conn_df = threshold_map(conn_df, thresh=density)
edge_list = adj_to_edgelist(map_df=thresholded_conn_df)
graph = create_graph(node_list=labels, edge_list=edge_list)
communities = louvain_comm_detection(graph=graph, iterations=iterations)
label_cifti_with_nets(
    cifti_parcels=cifti_parcels,
    communities=communities,
    out_file='test_nets.dlabel.nii',
    colors=kelly_colors,
    network_names=None,
)

## Infomap Community Detection

In [None]:
# 

In [None]:
# 