Skip to content

Commit

Permalink
Merge pull request janelia-flyem#78 from jni/minor-fixes
Browse files Browse the repository at this point in the history
Some bug fixes and speed improvements
  • Loading branch information
jni committed Mar 11, 2016
2 parents add5ff2 + ad4322e commit ebc25b7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
5 changes: 3 additions & 2 deletions gala/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,11 @@ def get_classifier(name='random forest', *args, **kwargs):

class DefaultRandomForest(RandomForestClassifier):
def __init__(self, n_estimators=100, criterion='entropy', max_depth=20,
bootstrap=False, random_state=None):
bootstrap=False, random_state=None, n_jobs=-1):
super(DefaultRandomForest, self).__init__(
n_estimators=n_estimators, criterion=criterion,
max_depth=max_depth, bootstrap=bootstrap, random_state=random_state)
max_depth=max_depth, bootstrap=bootstrap,
random_state=random_state, n_jobs=n_jobs)


class VigraRandomForest(object):
Expand Down
15 changes: 10 additions & 5 deletions gala/morpho.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from skimage import measure, util
import skimage.morphology

from sklearn.externals import joblib

zero3d = array([0,0,0])

def complement(a):
Expand Down Expand Up @@ -228,8 +230,7 @@ def watershed(a, seeds=None, connectivity=1, mask=None, smooth_thresh=0.0,
if not seeded:
seeds = regional_minima(a, connectivity)
if minimum_seed_size > 0:
seeds = remove_small_connected_components(seeds, minimum_seed_size,
in_place=True)
seeds = remove_small_connected_components(seeds, minimum_seed_size)
seeds = relabel_from_one(seeds)[0]
if smooth_seeds:
seeds = binary_opening(seeds, sel)
Expand Down Expand Up @@ -269,7 +270,7 @@ def watershed(a, seeds=None, connectivity=1, mask=None, smooth_thresh=0.0,
(br[nidxs] == level)).astype(bool) ])
return juicy_center(ws)

def watershed_sequence(a, seeds=None, mask=None, axis=0, **kwargs):
def watershed_sequence(a, seeds=None, mask=None, axis=0, n_jobs=1, **kwargs):
"""Perform a watershed on a plane-by-plane basis.
See documentation for `watershed` for available kwargs.
Expand All @@ -294,6 +295,9 @@ def watershed_sequence(a, seeds=None, mask=None, axis=0, **kwargs):
Which axis defines the plane sequence. For example, if the input image
is 3D and axis=1, then the output will be the watershed on a[:, 0, :],
a[:, 1, :], a[:, 2, :], ... and so on.
n_jobs : int, optional
Use joblib to distribute each plane over given number of processing
cores. If -1, `multiprocessing.cpu_count` is used.
Returns
-------
Expand All @@ -314,8 +318,9 @@ def watershed_sequence(a, seeds=None, mask=None, axis=0, **kwargs):
seeds = it.repeat(None)
if mask is None:
mask = it.repeat(None)
ws = [watershed(i, seeds=s, mask=m, **kwargs)
for i, s, m in zip(a, seeds, mask)]
ws = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(watershed)(i, seeds=s, mask=m, **kwargs)
for i, s, m in zip(a, seeds, mask))
counts = list(map(np.max, ws[:-1]))
counts = np.concatenate((np.array([0]), counts))
counts = np.cumsum(counts)
Expand Down

0 comments on commit ebc25b7

Please sign in to comment.