In [1]:
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pymaid

def plot_nx(x, plot_connectors=True, highlight_connectors=None, prog='dot'):
    """Plot a neuron as a dendrogram using networkx and its graphviz bindings.

    The skeleton is converted into a networkx ``DiGraph`` with one edge per
    non-root treenode (treenode -> parent, edge attribute ``len`` set to the
    parent distance), laid out with graphviz, and drawn onto the current
    matplotlib figure. Call ``plt.show()`` afterwards to display it.

    Parameters
    ----------
    x :                     CatmaidNeuron | CatmaidNeuronList
                            Neuron to plot. A CatmaidNeuronList is accepted
                            only if it holds exactly one neuron. Strongly
                            recommended to downsample the neuron first!
                            NOTE: if the neuron has a soma that is not its
                            root, the passed neuron is rerooted to the soma
                            in place.
    plot_connectors :       bool, optional
                            If True, connectors will be plotted (relation 0
                            in red, relation 1 in blue).
    highlight_connectors :  list of int, optional
                            These connectors (or more precisely, the
                            treenodes they connect to) will be highlighted
                            in green.
    prog :                  {'dot', 'neato', 'fdp'}
                            Graphviz layout to use. Be aware that neato and
                            fdp are extremely slow!

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``x`` is not a single CatmaidNeuron (or a one-element
        CatmaidNeuronList), or if ``prog`` is not a supported layout.

    Examples
    --------
    >>> import pymaid
    >>> import matplotlib.pyplot as plt
    >>> rm = pymaid.CatmaidInstance(server,user,password,token)
    >>> # Retrieve neuron
    >>> x = pymaid.get_neuron(16)
    >>> # Downsample to just the essential treenodes (will speed up processing A LOT)
    >>> x.downsample(100000, preserve_cn_treenodes=True)
    >>> plot_nx( x, plot_connectors=True )
    >>> plt.show()
    """

    if isinstance(x, pymaid.CatmaidNeuronList):
        # Also rejects an empty list, which previously raised IndexError on x[0]
        if len(x) != 1:
            raise ValueError('Need to pass a SINGLE CatmaidNeuron')
        x = x[0]
    elif not isinstance(x, pymaid.CatmaidNeuron):
        raise ValueError('Need to pass a CatmaidNeuron')

    valid_progs = ('fdp', 'dot', 'neato')
    if prog not in valid_progs:
        raise ValueError('Unknown program parameter!')

    # Save start time
    start = time.time()

    # Neurons without a soma keep their current root: rerooting to None, or
    # looking up a position for None later, would fail.
    soma = x.soma
    has_soma = soma is not None

    # Reroot neuron to soma if necessary (mutates the caller's neuron)
    if has_soma and x.root != soma:
        x.reroot(soma)

    # This is only relevant if we use the 'neato' layout as it preserves segment lengths
    if 'parent_dist' not in x.nodes:
        x = pymaid.calc_cable(x, return_skdata=True)

    # Generate and populate networkX graph representation of the neuron
    g = nx.DiGraph()
    g.add_nodes_from(x.nodes.treenode_id.values)
    for tn, parent, dist in x.nodes[['treenode_id', 'parent_id', 'parent_dist']].values:
        # Skip root node. Its missing parent may come through as None or, for a
        # float column, as NaN (the only value not equal to itself).
        if parent is None or parent != parent:
            continue
        g.add_edge(tn, parent, len=dist)

    # Calculate layout
    print('Calculating node positions...')
    pos = nx.nx_agraph.graphviz_layout(g, prog=prog)

    # Plot tree with above layout
    print('Plotting tree...')
    nx.draw(g, pos, node_size=0, arrows=False)

    # Add soma
    if has_soma:
        plt.scatter([pos[soma][0]], [pos[soma][1]], s=40, c=(0, 0, 0), zorder=1)

    if plot_connectors:
        print('Plotting connectors...')
        cn = x.connectors
        # One scatter per connector relation, each with its own color
        for relation, color in ((0, (.8, .2, .2)), (1, (.2, .2, .8))):
            tns = cn[cn.relation == relation].treenode_id.values
            plt.scatter([pos[tn][0] for tn in tns],
                        [pos[tn][1] for tn in tns],
                        c=color,
                        zorder=2,
                        s=5)

    # `is not None` so array-like selections don't trigger an ambiguous
    # element-wise comparison
    if highlight_connectors is not None:
        cn = x.connectors
        hl_tns = cn[cn.connector_id.isin(highlight_connectors)].treenode_id.values
        # An empty selection would give a 1-D array that cannot be sliced [:, 0]
        if len(hl_tns) > 0:
            hl_cn_coords = np.array([pos[tn] for tn in hl_tns])
            plt.scatter(hl_cn_coords[:, 0], hl_cn_coords[:, 1], s=5, c=(0, 1, 0), zorder=3)

    print('Done in %is' % int(time.time() - start))

