In [None]:
from ipynb.fs.full.functions_header import *
from ipynb.fs.full.functions_cluster_analysis import *
from matplotlib.patches import Rectangle
from tqdm.notebook import tqdm

In [None]:
def plt_graph_clusters(G, pos, clustering, label_dict, title, save_file, plt_size, font_size, dpi, color_bool):
    '''Plot a network graph with nodes styled by cluster, save it to file and close the figure.

    Use plt_size = 120 for LCC, and plt_size = 200 for full network.

    Parameters
    ----------
    G : networkx graph to draw.
    pos : node positions passed to the networkx drawing functions.
    clustering : Series-like mapping node (index) -> integer cluster label.
    label_dict : node labels passed to nx.draw_networkx_labels.
    title : plot title.
    save_file : output path passed to plt.savefig.
    plt_size : figure height in inches (width is 1.2 * plt_size); also scales
        node, font, legend and colorbar sizes.
    font_size : base node-label font size, scaled by plt_size / 270.
    dpi : resolution passed to plt.savefig.
    color_bool : if True, colour edges by their 'weight' attribute on the
        RdYlGn colormap over [0, 1] and add a colorbar; otherwise draw
        light-gray edges.

    Cluster labels 0..62 each get a unique (marker, colour) pair; labels >= 63
    are drawn as gray diamonds and share a single legend entry.

    Example: plt_graph_clusters(H, pos, graph_clusters, label_dict = label_dict,
                  title = "Clustering", save_file = "test_clauset.png", plt_size = 200,
                  font_size = 40, dpi = 300, color_bool = False)'''
    # 7 markers x 9 colours = 63 distinct node styles
    node_colours = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple',
           'tab:brown','tab:pink','tab:olive','tab:cyan']
    my_markers = ['o', 'v', '*', 'h','s','p', 'P']
    col_list = list(itertools.product(*[my_markers, node_colours]))
    n_styles = len(col_list)

    # plot graph
    fig = plt.figure(figsize = (plt_size*1.2, plt_size))
    # scale factor for node, font and legend sizes (relative to plt_size = 270)
    fraction = plt_size/270

    # Draw nodes: one call per cluster so each cluster gets its own legend entry
    gray_labelled = False
    for i in range(clustering.min(), clustering.max()+1):
        nodelist = list(clustering.index[clustering == i])
        if not nodelist:
            # unused label: networkx draws nothing for an empty nodelist anyway
            continue

        if i >= n_styles:
            # out of distinct styles: gray diamonds; only the first such
            # cluster is labelled so they share one legend entry
            label = '' if gray_labelled else '>' + str(n_styles)
            gray_labelled = True
            my_mark, my_col = 'D', 'gray'
        else:
            label = str(i) + " (" + str(len(nodelist)) + ")"
            my_mark, my_col = col_list[i]

        # Draw nodes within specific cluster
        nx.draw_networkx_nodes(G, pos = pos, nodelist = nodelist,
                               node_color = my_col, node_shape = my_mark, node_size = 2350*fraction,
                               label = label) # edgecolors = 'tan'

    # Draw node labels
    nx.draw_networkx_labels(G, pos = pos, labels = label_dict, font_size = font_size*fraction)

    # Draw edges
    if (color_bool):
        # Edge colours follow get_edge_attributes' iteration order; this assumes
        # every edge carries a 'weight' so the list aligns with G's edges.
        weights_edge = list((nx.get_edge_attributes(G, 'weight').values()))
        edge_cmap = plt.get_cmap('RdYlGn')
        nx.draw_networkx_edges(G, pos = pos, width = 0.4, edge_color = weights_edge, alpha = 0.3,
                             edge_cmap = edge_cmap, edge_vmin = 0, edge_vmax = 1)

        # add colorbar
        sm = plt.cm.ScalarMappable(cmap = edge_cmap, norm = plt.Normalize(vmin = 0, vmax = 1))
        sm.set_array([])
        # sm is not attached to any Axes, so tell matplotlib which Axes to take
        # space from; newer matplotlib (3.6+) deprecates, then rejects, guessing it.
        cbar = plt.colorbar(sm, ax = plt.gca(), shrink = 0.3, pad = 0.06)
        cbar.ax.tick_params(labelsize = plt_size*0.45)
        cbar.ax.set_ylabel('Probability of edge', rotation = 270,
                           fontsize = plt_size*0.65, labelpad = plt_size*1.1)
    else:
        # otherwise leave gray
        nx.draw_networkx_edges(G, pos = pos, width = 0.5, edge_color = '#DCDCDC')

    # Add legend on the side
    lg = plt.legend(scatterpoints = 1, markerscale = 4.25*fraction, prop = {'size': plt_size*0.55},
               title = "Clusters (size)", title_fontsize = plt_size*0.55,
               bbox_to_anchor = (1.004, 1.0), loc = 'upper left')
    plt.tight_layout()
    plt.title(title, fontsize = plt_size, pad = plt_size*1.1)

    # save plot; the legend is passed explicitly so bbox_inches='tight' keeps it
    plt.savefig(save_file, dpi = dpi, bbox_extra_artists = (lg,), bbox_inches = 'tight')
    # For LZW-compressed TIFF output instead, add: format="tiff",
    # pil_kwargs={"compression": "tiff_lzw"}
    plt.close(fig)
    
def plt_graph_clusters_all(G, pos, cluster_df, ind_start, label_dict, plt_size,
                           str_save, str_title, dpi, font_size, color_bool):
    '''Plot and save one cluster graph per clustering column of cluster_df.

    Every column from position ind_start onward is treated as a clustering
    (node index -> cluster label) and drawn with plt_graph_clusters. Each plot is
    saved to str_save + <column name> + ".pdf" with title str_title + <column name>.
    Column names are converted with str(), so integer column labels are supported.
    The remaining arguments are forwarded unchanged to plt_graph_clusters.
    '''
    for col in cluster_df.columns[ind_start:]:
        col_name = str(col)
        plt_graph_clusters(G, pos, cluster_df[col], label_dict = label_dict,
                           title = str_title + col_name,
                           save_file = str_save + col_name + ".pdf",
                           plt_size = plt_size, dpi = dpi, font_size = font_size,
                           color_bool = color_bool)
    
def get_str_title(G, title):
    '''Return the title prefix for a network plot.

    Produces "Clustering of <title> (<n> nodes, <m> edges) using ", where n and m
    are len(G.nodes) and len(G.edges). The trailing "using " is presumably
    completed by the caller with the clustering method's name.
    '''
    node_count = str(len(G.nodes))
    edge_count = str(len(G.edges))
    pieces = ["Clustering of ", title,
              " (", node_count, " nodes, ",
              edge_count, " edges) using "]
    return "".join(pieces)

In [None]:
def get_pi(cluster, pairs):
    '''Return the posterior similarity matrix (PSM) of an MCMC clustering output.

    cluster : DataFrame with one column per MCMC sample; each column maps node ->
        cluster label and is indexed with the node ids stored in pairs.
    pairs : (n_pairs, 2) integer array of node pairs. Its rows presumably follow
        scipy's condensed order (i < j, as from itertools.combinations) -- TODO
        confirm, since squareform assumes that order.

    Returns the square matrix whose (i, j) entry is the fraction of samples in
    which nodes i and j share a cluster. squareform leaves the diagonal at 0.
    '''
    left = pairs[:, 0]
    right = pairs[:, 1]
    co_counts = np.zeros(len(pairs)).astype(int)

    for sample in tqdm(cluster.columns):
        labels = cluster[sample]
        same_cluster = labels[left].values == labels[right].values
        co_counts = co_counts + same_cluster.astype(int)

    n_samples = cluster.shape[1]
    return squareform(co_counts) / n_samples

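# Alternative pairwise loop: sim_mat[i, j] counts the samples (columns) in which nodes i and j share a cluster (not normalised)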
# for (i,j) in tqdm(node_pairs):
#     sim_mat[i,j] = (cluster1.iloc[i,:] == cluster1.iloc[j,:]).sum()

def get_scores(psm, clustering, pairs, no_of_samples):
    '''Score each MCMC clustering against the posterior similarity matrix (Dahl loss).

    psm : square posterior similarity matrix (e.g. from get_pi). It is converted to
        condensed vector form with squareform, which requires it to be symmetric
        with a zero diagonal.
    clustering : DataFrame whose columns 0 .. no_of_samples-1 are MCMC samples,
        each mapping node -> cluster label.
    pairs : (n_pairs, 2) array of node pairs, in the same order as the condensed psm.
    no_of_samples : number of sample columns to score.

    Returns a float array whose entry s is
        sum over pairs (a, b) of (1[c_a == c_b] - psm[a, b]) ** 2
    for sample s. Lower means closer to the PSM.
    '''
    psm_condensed = squareform(psm)
    left = pairs[:, 0]
    right = pairs[:, 1]
    losses = np.zeros(no_of_samples)

    for s in tqdm(range(no_of_samples)):
        labels = clustering[s]
        together = labels[left].values == labels[right].values
        losses[s] = ((together - psm_condensed) ** 2).sum()

    return losses

def plt_adj_heat(clustering, A_adj, title, cmap):
    '''Show a heatmap of a square matrix with rows and columns grouped by cluster.

    clustering : per-node cluster labels; rows/columns are reordered by np.argsort(clustering).
    A_adj : square 2-D array indexed by node position (e.g. adjacency matrix or PSM).
    title : plot title.
    cmap : colormap passed to sns.heatmap.

    Tick labels 0, 1, 2, ... are placed at the centre of each cluster's block on the
    top and left axes, and each diagonal block is outlined in blue. Block sizes come
    from cluster_order(clustering, style = 'index')[:, 1]; this assumes they are
    listed in the same (ascending label) order as the argsort -- TODO confirm.
    '''
    order = np.argsort(clustering)
    block_sizes = cluster_order(clustering, style = 'index')[:,1]

    fig, ax = plt.subplots(figsize=(15,15))
    reordered = A_adj[:,order][order,:]
    sns.heatmap(reordered, cmap = cmap, ax = ax)

    # block k starts where blocks 0..k-1 end; its tick sits at the block's middle
    n_blocks = len(block_sizes)
    block_starts = np.concatenate((np.zeros(1), np.cumsum(block_sizes)[:n_blocks - 1]))
    centres = block_starts + block_sizes/2
    tick_labels = range(len(centres))
    plt.xticks(centres, labels = tick_labels)
    plt.yticks(centres, labels = tick_labels)
    ax.tick_params(labelbottom=False,labeltop=True)

    plt.title(title)

    # outline every diagonal block
    offset = 0
    for size in block_sizes:
        outline = Rectangle((offset, offset), size, size, fill = False, edgecolor='blue', lw=3)
        ax.add_patch(outline)
        offset += size

    plt.show()
    

def disp_results(G_adj, clustering, analytics, psm, scores):
    '''Report and plot the Dahl and MAP point estimates of an MCMC run.

    Prints two sample indices:
      - the one minimising scores (Dahl estimate);
      - the one maximising analytics['Log_lik'] (MAP).
    Then plots the score trace. Finally it draws four heatmaps with
    plt_adj_heat: the PSM and G_adj, each ordered by the Dahl clustering and
    by the MAP clustering. Each clustering is taken as clustering[<index>].
    '''
    best_dahl = np.argmin(scores)
    best_map = np.argmax(analytics['Log_lik'])

    print("Dahl minimum MCMC: ", best_dahl)
    print("MAP:", best_map)

    # trace of the Dahl score over MCMC samples
    plt.figure(figsize = (16,5))
    plt.plot(scores)
    plt.title("Dahl's score of MCMC samples compared to psm")
    plt.show()

    heatmaps = (
        (best_dahl, psm, "PSM with Dahl estimate", 'gray_r'),
        (best_map, psm, "PSM with MAP estimate", 'gray_r'),
        (best_dahl, G_adj, "Adjacency matrix with Dahl estimate", 'RdYlGn'),
        (best_map, G_adj, "Adjacency matrix with MAP estimate", 'RdYlGn'),
    )
    for sample_ind, matrix, heading, colour_map in heatmaps:
        plt_adj_heat(clustering[sample_ind], matrix, title = heading, cmap = colour_map)