In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly
from plotly.subplots import make_subplots
import json
import plotly.express as px
from collections import Counter
import community as community_louvain
import networkx as nx
import pickle

In [2]:
# Colours used, in order, as the Plotly colorscale for the community colour
# codes in plot_chain_with_communities.
COLOR_PALETTE = [
    '#d0e3f5', '#712e67', '#267592', '#5fb12a',
    '#fac800', '#ff7917', '#e23a34',
]

In [4]:
def load_PH(filename):
    """Read the persistent-homology output stored in 'PH/<filename>.json'.

    Parameters
    ----------
    filename : str
        File stem inside the PH/ directory.

    Returns
    -------
    tuple
        (barcode, representatives): `barcode` is the JSON 'barcode' entry
        converted to a numpy array and transposed; `representatives` is the
        JSON 'representatives' entry, returned unchanged.
    """
    json_path = 'PH/' + filename + '.json'
    with open(json_path) as handle:
        payload = json.load(handle)
    # Transposed so that each row of the result is one entry of the barcode
    # (presumably a (birth, death) pair — confirm against the PH writer).
    bars = np.array(payload['barcode']).T
    reps = payload['representatives']
    return bars, reps

def load_curve(filename):
    """Load the point cloud 'pointClouds/<filename>.npy' and keep its first three columns.

    Parameters
    ----------
    filename : str
        File stem inside the pointClouds/ directory.

    Returns
    -------
    numpy.ndarray
        The first three columns of the stored data (used as X, Y, Z by the
        plotting code). Assumes the stored data is 2-D with at least three
        columns.
    """
    curve = np.load('pointClouds/' + filename + '.npy')
    # np.load returns a plain ndarray for .npy files, and ndarray has no
    # `.values` attribute (the original `curve.values` raised AttributeError).
    # Keep `.values` only as a fallback for pandas-like objects, e.g. a pickled
    # DataFrame loaded by an older numpy that allowed pickles by default.
    data = curve if isinstance(curve, np.ndarray) else curve.values
    return np.asarray(data)[:, :3]

In [5]:
####### to compute/save partitions

def create_graph_from_PH(curve):
    """Build a weighted NetworkX graph from the persistent homology of `curve`.

    One node is created per point of the curve (ids 0 .. n-1, where n is the
    length of load_curve(curve)). For every row of the barcode, all distinct
    vertices appearing (first two entries of each element) in the matching
    representative are joined pairwise. The edge weight is |row[0] - row[1]|,
    presumably the bar's persistence — TODO confirm the barcode layout.
    Representative vertex ids are shifted by -1 (they appear to be 1-based).
    When several bars connect the same pair, add_edge overwrites the
    attribute, so the weight of the last such bar is kept.

    Parameters
    ----------
    curve : str
        Name used for both the PH/<curve>.json file and the point cloud file.

    Returns
    -------
    networkx.Graph
    """
    barcode, reps = load_PH(curve)
    n_points = len(load_curve(curve))

    graph = nx.Graph()
    graph.add_nodes_from(list(range(n_points)))

    for bar_idx, bar in enumerate(barcode):
        weight = abs(bar[0] - bar[1])

        endpoints = []
        for simplex in reps[bar_idx]:
            endpoints.extend((simplex[0], simplex[1]))
        vertices = list(set(endpoints))

        # Connect every unordered pair of distinct vertices of this representative.
        for pos, u in enumerate(vertices):
            for v in vertices[pos + 1:]:
                graph.add_edge(u - 1, v - 1, weight=weight)

    return graph

def save_partitions(cells, path='partitions.pickle'):
    """Compute a community partition for every curve and pickle them all to one file.

    Parameters
    ----------
    cells : iterable of str
        Curve names passed to `communities` and `load_curve`.
    path : str, optional
        Output pickle file. Defaults to 'partitions.pickle' in the working
        directory (the original hard-coded location).

    Returns
    -------
    dict
        Maps each curve name to the partition returned by `communities`.
    """
    partitions = {}
    for cell in cells:
        # (The original first stored an empty dict here and immediately
        # overwrote it; that dead store is dropped.)
        partition = communities(cell)
        partitions[cell] = partition
        # Progress line: curve name, number of curve points, number of entries in the partition.
        print(cell, len(load_curve(cell)), len(partition))
    with open(path, 'wb') as f:
        pickle.dump(partitions, f, protocol=pickle.HIGHEST_PROTOCOL)
    return partitions


def communities(curve, resolution=1, random_state=None):
    """Partition the persistent-homology graph of `curve` with Louvain.

    Parameters
    ----------
    curve : str
        Curve name passed to create_graph_from_PH.
    resolution : float, optional
        Louvain resolution forwarded to best_partition (default 1, as before).
    random_state : int, RandomState or None, optional
        Seed forwarded to best_partition for reproducible partitions. When
        None (default) it is not passed at all, so behaviour, and compatibility
        with python-louvain versions lacking this argument, is unchanged.

    Returns
    -------
    dict
        The partition returned by community_louvain.best_partition (node -> community id).
    """
    G = create_graph_from_PH(curve)
    extra = {} if random_state is None else {'random_state': random_state}
    return community_louvain.best_partition(G, resolution=resolution, **extra)

In [6]:
# Compute and pickle the Louvain partitions of the first two curves.
curves = ['curve_%d' % idx for idx in range(2)]
A = save_partitions(curves)

In [9]:
# Rebuild, for every curve in `curves` (defined in an earlier cell), the
# mapping community id -> list of node ids from the partitions written by
# save_partitions (which map node -> community id).
# NOTE(review): pickle.load can execute arbitrary code; only load a
# partitions.pickle produced by this notebook.
COMMS = {}

with open('partitions.pickle', 'rb') as f:
    PARTITIONS = pickle.load(f)

for curve in curves:
    partition = PARTITIONS[curve]
    # Community ids in explicit sorted order: downstream colour codes are
    # assigned in dict iteration order, so this keeps them deterministic.
    COMMS[curve] = {cid: [] for cid in sorted(set(partition.values()))}
    for node, cid in partition.items():
        COMMS[curve][cid].append(node)
                


In [10]:
def community_adjacency(curve, scope='plot', comms=None):
    """Build an n x n matrix marking which nodes share a non-singleton community.

    Parameters
    ----------
    curve : str
        Key into the global COMMS dict; only used when `comms` is None.
    scope : str, optional
        'plot' (default): cells of the k-th community with more than one
        member (1-based, in dict iteration order) are set to k, giving a
        colour-coded matrix. Any other value sets those cells to 1.
    comms : dict, optional
        Mapping community id -> list of node indices. Defaults to COMMS[curve].

    Returns
    -------
    numpy.ndarray
        Float matrix of shape (n, n), with n the total number of nodes over all
        communities. Singleton communities and cross-community cells stay 0.
        Node indices are used directly as positions, so they are assumed to be
        0 .. n-1.
    """
    cs = COMMS[curve] if comms is None else comms
    # Plain Python int: np.sum([]) returns the float 0.0, which np.zeros rejects.
    n = sum(len(members) for members in cs.values())
    matrix = np.zeros((n, n))

    color = 1
    for members in cs.values():
        if len(members) > 1:
            # Fill the whole members x members block (diagonal included) at once.
            matrix[np.ix_(members, members)] = color if scope == 'plot' else 1
            color += 1

    return matrix



def plot_chain_with_communities(filename):
    """Return a 3-D Plotly figure of a curve with points and segments coloured by community.

    Reads the community mapping for `filename` from the global COMMS and the
    coordinates from load_curve(filename) (columns used as X, Y, Z). Colour
    codes come from communities_color_code and are mapped through
    COLOR_PALETTE. Axis backgrounds, grids, zero lines and tick labels are
    hidden, and the per-point hovertemplate is set to the colour code.

    Parameters
    ----------
    filename : str
        Curve name: key into COMMS and file stem for load_curve.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure (not shown; call .show()).
    """
    community = COMMS[filename]
    coords = load_curve(filename)

    df = pd.DataFrame.from_records(coords, columns=['X', 'Y','Z'])
    codes = communities_color_code(community)
    # NOTE(review): this column holds community colour codes, not a centrality.
    df["node centrality"] = codes
    code_series = df["node centrality"]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter3d(
            x=df['X'],
            y=df['Y'],
            z=df['Z'],
            name='Curve',
            marker=dict(
                size=8,
                color=code_series,
                colorscale=COLOR_PALETTE,
                line=dict(width=4, color='DarkSlateGrey'),
            ),
            line=dict(width=8, color=code_series, colorscale=COLOR_PALETTE),
        )
    )

    # Same invisible-axis styling for x, y and z.
    hidden_axis = dict(
        backgroundcolor="rgb(200, 200, 230)",
        gridcolor='rgba(0,0,0,0)',
        showbackground=False,
        zerolinecolor='rgba(0,0,0,0)',
        showticklabels=False,
    )
    camera = dict(
        up=dict(x=-10, y=0, z=30),
        eye=dict(x=0.9, y=0.9, z=1.3),
    )
    fig.update_layout(scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        camera=camera,
    ))
    # Blank (single-space) axis titles.
    fig.update_layout(scene=dict(xaxis_title=' ', yaxis_title=' ', zaxis_title=' '))
    fig.update_traces(hoverinfo="text", hovertemplate=codes)

    return fig


In [11]:
def communities_color_code(comms):
    """Assign an integer colour code to every node from its community.

    Parameters
    ----------
    comms : dict
        Mapping community id -> list of node indices.

    Returns
    -------
    list of int
        Entry i is the code of node i: 0 for nodes in singleton communities,
        otherwise k for nodes of the k-th community (1-based, in dict iteration
        order) that has more than one member. Node indices are assumed to be
        0 .. n-1, with n the total number of nodes.
    """
    # Plain Python int: np.sum([]) returns the float 0.0, which breaks
    # list repetition for an empty mapping.
    n = sum(len(members) for members in comms.values())
    nodes = [0] * n
    color = 1
    for members in comms.values():
        if len(members) > 1:
            for node in members:
                nodes[node] = color
            color += 1
    return nodes

In [20]:
# Render curve_1 coloured by its Louvain communities.
fig = plot_chain_with_communities(filename='curve_1')
fig.show()