Skip to content

Commit

Permalink
new functions: despike_neuron and remove_tagged_branches
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Apr 30, 2018
1 parent 50112d5 commit d9e65f1
Showing 1 changed file with 249 additions and 4 deletions.
253 changes: 249 additions & 4 deletions pymaid/morpho.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pandas as pd
import numpy as np
import scipy
import networkx as nx

from pymaid import fetch, core, graph_utils, graph, utils

Expand Down Expand Up @@ -51,7 +52,8 @@

__all__ = sorted([ 'calc_cable','strahler_index', 'prune_by_strahler','stitch_neurons','arbor_confidence',
'split_axon_dendrite', 'bending_flow', 'flow_centrality',
'segregation_index', 'to_dotproduct', 'average_neurons', 'tortuosity'])
'segregation_index', 'to_dotproduct', 'average_neurons', 'tortuosity',
'remove_tagged_branches', 'despike_neuron'])

# Default settings for progress bars
pbar_hide = False
Expand Down Expand Up @@ -751,6 +753,7 @@ def segregation_index(x, centrality_method='centrifugal'):

return H


def bending_flow(x, polypre=False):
""" Variation of the algorithm for calculating synapse flow from
Schneider-Mizell et al. (eLife, 2016).
Expand Down Expand Up @@ -1170,6 +1173,7 @@ def average_neurons( x, limit=10, base_neuron=None):
>>> # Plot
>>> da1.plot3d()
>>> da1_avg.plot3d()
"""

if not isinstance(x, core.CatmaidNeuronList):
Expand Down Expand Up @@ -1252,9 +1256,10 @@ def tortuosity( x, seg_length=10, skip_remainder=False):
----------
x : {CatmaidNeuron,CatmaidNeuronList}
seg_length : {int, float, list}, optional
Segment length(s) L in microns [um]. Please note that
this is only a guidance and actual segment length is
restricted by the neuron's resolution.
Target Segment length(s) L in microns [um]. Will try
resampling neuron to this resolution. Please note that
the final segment length is restricted by the neuron's
original resolution.
skip_remainder : bool, optional
Segments can turn out to be smaller than desired if a
branch point or end point is hit before `seg_length`
Expand Down Expand Up @@ -1329,3 +1334,243 @@ def tortuosity( x, seg_length=10, skip_remainder=False):
return T.mean()


def remove_tagged_branches(x, tag, how='segment', preserve_connectors=True, inplace=False):
""" Removes branches from neuron(s) that have been tagged with a given
treenode tag (e.g. 'not a branch').
Parameters
----------
x : {CatmaidNeuron, CatmaidNeuronList}
Neuron(s) to be processed.
tag : {str}
Treeode tag to use.
how : {'segment', 'distal', 'proximal'}, optional
Method of removal:
1. `segment` removes entire segment
2. `distal`/`proximal` removes everything
distal/proximal to tagged nodes
preserve_connectors : bool, optional
If True, connectors that got disconnected during
branch removal will be reattached to the closest
surviving node parent.
inplace : bool, optional
If False, a copy of the neuron is returned.
Returns
-------
CatmaidNeuron/CatmaidNeuronList (if `inplace=False`)
Examples
--------
1. Remove not-a-branch terminal branches
>>> x = pymaid.get_neuron(16)
>>> x_prun = pymaid.remove_tagged_branches(x,
... 'not a branch',
... how='segment',
... preserve_connectors=True )
2. Prune neuron to microtubule-containing backbone
>>> x_prun = pymaid.remove_tagged_branches(x,
... 'microtubule ends',
... how='distal',
... preserve_connectors=False )
"""

def _find_next_remaining_parent(tn):
""" Helper function that walks from a treenode to the neurons root and
returns the first parent that will not be removed.
"""
this_nodes = x.nodes.set_index('treenode_id')
while True:
this_parent = this_nodes.loc[tn,'parent_id']
if this_parent not in to_remove:
return tn
tn = this_parent

if isinstance(x, core.CatmaidNeuronList):
if not inplace:
x = x.copy()

for n in tqdm(x, desc='Removing', disable=pbar_hide, leave=pbar_leave ):
remove_tagged_branches(n, tag,
how=how,
preserve_connectors=preserve_connectors,
inplace=True)

if not inplace:
return x
elif not isinstance(x, core.CatmaidNeuron):
raise TypeError('Can only process CatmaidNeuron or CatmaidNeuronList, not "{0}"'.format(type(x)))

# Check if method is valid
VALID_METHODS = ['segment', 'distal', 'proximal']
if how not in VALID_METHODS:
raise ValueError('Invalid value for "how": {0}. Valid methods are: {1}'.format(how, ', '.join(VALID_METHODS)))

# Skip if tag not present
if tag not in x.tags:
module_logger.info('No "{0}" tag found on neuron #{1}... skipped'.format(tag, x.skeleton_id))
if not inplace:
return x
return

if not inplace:
x = x.copy()

tagged_nodes = set( x.tags[tag] )

if how == 'segment':
# Find segments that have a tagged node
tagged_segs = [ s for s in x.segments if set(s) & tagged_nodes ]

# Sanity check: are any of these segments non-terminals?
non_term = [ s for s in tagged_segs if x.graph.degree( s[0] ) > 1 ]
if non_term:
module_logger.warning('Pruning {0} non-terminal segment(s)'.format(len(non_term)))

# Get node to be removed
to_remove = [ t for s in tagged_segs for t in s[:-1] ]

# Rewire connectors before we subset
if preserve_connectors:
# Get connectors that will be disconnected
lost_cn = x.connectors[ x.connectors.treenode_id.isin(to_remove) ]

# Map to a remaining treenode
# IMPORTANT: we do currently not account for the possibility that we might be removing the root segment
new_tn = [ _find_next_remaining_parent(tn) for tn in lost_cn.treenode_id.values ]
x.connectors.loc[ x.connectors.treenode_id.isin(to_remove), 'treenode_id' ] = new_tn


# Subset to remaining nodes - skip the last node in each segment
graph_utils.subset_neuron(x,
subset=x.nodes[ ~x.nodes.treenode_id.isin(to_remove) ].treenode_id.values,
keep_connectors=preserve_connectors,
inplace=True )

if not inplace:
return x
return

elif how in ['distal', 'proximal']:
# Keep pruning until no more treenodes with our tag are left
while tag in x.tags:
# Find nodes distal to this tagged node (includes the tagged node)
dist_graph = nx.bfs_tree( x.graph, x.tags[tag][0], reverse=True )

if how == 'distal':
to_remove = list( dist_graph.nodes )
elif how == 'proximal':
# Invert dist_graph
to_remove = x.nodes[ ~x.nodes.treenode_id.isin(dist_graph.nodes) ].treenode_id.tolist()
# Make sure the tagged treenode is there too
to_remove += [ x.tags[tag][0] ]

# Rewire connectors before we subset
if preserve_connectors:
# Get connectors that will be disconnected
lost_cn = x.connectors[ x.connectors.treenode_id.isin(to_remove) ]

# Map to a remaining treenode
# IMPORTANT: we do currently not account for the possibility that we might be removing the root segment
new_tn = [ _find_next_remaining_parent(tn) for tn in lost_cn.treenode_id.values ]
x.connectors.loc[ x.connectors.treenode_id.isin(to_remove), 'treenode_id' ] = new_tn

# Subset to remaining nodes
graph_utils.subset_neuron(x,
subset=x.nodes[ ~x.nodes.treenode_id.isin(to_remove) ].treenode_id.values,
keep_connectors=preserve_connectors,
inplace=True )

if not inplace:
return x
return


def despike_neuron(x, sigma=5, inplace=False):
""" Removes spikes in neuron traces (e.g. from jumps in image data).
Notes
-----
For each treenode A, the (euclidean) distance to its immediate parent
B and to that node's parent C is computed.
If `dist(A->B)/dist(A->C) > sigma`, node B is considered a spike and
realigned between A and C.
Parameters
----------
x : {CatmaidNeuron, CatmaidNeuronList}
Neuron(s) to be processed.
sigma : {float, int}, optional
Threshold for spike detection. Smaller sigma = more
promiscuous spike detection. See notes.
inplace : bool, optional
If False, a copy of the neuron is returned.
Returns
-------
CatmaidNeuron/CatmaidNeuronList (if `inplace=False`)
"""

# TODO:
# - flattening all segments first before Spike detection should speed up
# quite a lot
# -> as intermediate step: assign all new positions at once

if isinstance(x, core.CatmaidNeuronList):
if not inplace:
x = x.copy()

for n in tqdm(x, desc='Despiking', disable=pbar_hide, leave=pbar_leave ):
despike_neuron(n, sigma=sigma, inplace=True)

if not inplace:
return x
elif not isinstance(x, core.CatmaidNeuron):
raise TypeError('Can only process CatmaidNeuron or CatmaidNeuronList, not "{0}"'.format(type(x)))

if not inplace:
x = x.copy()

# Index treenodes table by treenode ID
this_treenodes = x.nodes.set_index('treenode_id')

# Go over all segments
for seg in x.segments:
# Get nodes A, B and C of this segment
this_A = this_treenodes.loc[ seg[:-2] ]
this_B = this_treenodes.loc[ seg[1:-1] ]
this_C = this_treenodes.loc[ seg[2:] ]

# Get coordinates
A = this_A[['x','y','z']].values
B = this_B[['x','y','z']].values
C = this_C[['x','y','z']].values

# Calculate euclidian distances A->B and A->C
dist_AB = np.linalg.norm( A - B,
axis=1 )
dist_AC = np.linalg.norm( A - C,
axis=1 )

# Get the spikes
spikes_ix = np.where( ( dist_AB / dist_AC ) > sigma )[0]
spikes = this_B.iloc[ spikes_ix ]

if not spikes.empty:
# Interpolate new position(s) between A and C
new_positions = A[ spikes_ix ] + ( C[ spikes_ix ] - A[ spikes_ix ] ) / 2

this_treenodes.loc[ spikes.index, ['x','y','z'] ] = new_positions

# Reassign treenode table
x.nodes = this_treenodes.reset_index(drop=False)

if not inplace:
return x

0 comments on commit d9e65f1

Please sign in to comment.