Skip to content

Commit

Permalink
Merge pull request #64 from jni/fix-learning-mode
Browse files Browse the repository at this point in the history
Fix #63
  • Loading branch information
jni committed Nov 24, 2015
2 parents f35e7d6 + 32a80b5 commit 47cecdc
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ install:
- conda update -q conda
- conda info -a

- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib networkx cython h5py pillow scikit-image scikit-learn setuptools pip
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib networkx cython h5py pillow scikit-image scikit-learn=0.16 setuptools pip
- source activate test-environment

# custom package not available from conda
Expand Down
35 changes: 13 additions & 22 deletions gala/agglo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
def contingency_table(a, b):
ct = ev_contingency_table(a, b)
nx, ny = ct.shape
ctout = np.zeros((2 * nx, ny), ct.dtype)
ctout = np.zeros((2*nx + 1, ny), ct.dtype)
ct.todense(out=ctout[:nx, :])
return ctout

Expand Down Expand Up @@ -202,7 +202,7 @@ def classifier_probability(feature_extractor, classifier):
def predict(g, n1, n2):
if n1 == g.boundary_body or n2 == g.boundary_body:
return inf
features = feature_extractor(g, n1, n2)
features = np.atleast_2d(feature_extractor(g, n1, n2))
try:
prediction = classifier.predict_proba(features)
prediction_arr = np.array(prediction, copy=False)
Expand Down Expand Up @@ -616,9 +616,6 @@ def set_probabilities(self, probs=array([]), normalize=False):
self.probabilities = morpho.pad(probs, padding)
self.probabilities_r = self.probabilities.ravel()[:,newaxis]
elif p_ndim == w_ndim+1:
if sp[1:] == sw:
sp = sp[1:]+[sp[0]]
probs = probs.transpose(sp)
axes = list(range(p_ndim-1))
self.probabilities = morpho.pad(probs, padding, axes)
self.probabilities_r = self.probabilities.reshape(
Expand Down Expand Up @@ -1087,10 +1084,7 @@ def learn_agglomerate(self, gts, feature_map,
g.merge_priority_function = boundary_mean
elif num_epochs > 0 and priority_mode == 'active' or \
num_epochs % 2 == 1 and priority_mode == 'mixed':
if random_state == None:
cl = get_classifier(classifier)
else:
cl = get_classifier(classifier, random_state=random_state)
cl = get_classifier(classifier, random_state=random_state)
feat, lab = classify.sample_training_data(
data[0], data[1][:, label_type_keys[labeling_mode]],
max_num_samples)
Expand All @@ -1104,8 +1098,9 @@ def learn_agglomerate(self, gts, feature_map,
g.show_progress = False # bug in MergeQueue usage causes
# progressbar crash.
g.rebuild_merge_queue()
alldata.append(g._learn_agglomerate(ctables, feature_map,
learning_mode, labeling_mode))
alldata.append(g.learn_epoch(ctables, feature_map,
learning_mode=learning_mode,
labeling_mode=labeling_mode))
if memory:
if unique:
data = unique_learning_data_elements(alldata)
Expand Down Expand Up @@ -1210,8 +1205,8 @@ def learn_edge(self, edge, ctables, assignments, feature_map):
return features, labels, weights, (n1, n2)


def _learn_agglomerate(self, ctables, feature_map, gt_dts,
learning_mode='strict', labeling_mode='assignment'):
def learn_epoch(self, ctables, feature_map,
learning_mode='permissive', labeling_mode='assignment'):
"""Learn the agglomeration process using various strategies.
Parameters
Expand All @@ -1223,10 +1218,11 @@ def _learn_agglomerate(self, ctables, feature_map, gt_dts,
The map from node pairs to a feature vector. This must
consist either of uncached features or of the cache used
when building the graph.
learning_mode : {'strict', 'loose'}
learning_mode : {'strict', 'permissive'}, optional
If ``'strict'``, don't proceed with a merge when it goes against
the ground truth.
labeling_mode : {'assignment', 'vi-sign', 'rand-sign'}
the ground truth. For historical reasons, 'loose' is allowed as
a synonym for 'permissive'.
labeling_mode : {'assignment', 'vi-sign', 'rand-sign'}, optional
Which label to use for `learning_mode`. Note that all labels
are saved in the end.
Expand Down Expand Up @@ -1682,12 +1678,7 @@ def traversing_bodies(self):
def non_traversing_bodies(self):
"""List bodies that are not orphans and do not traverse the volume."""
return [n for n in self.nodes() if self.at_volume_boundary(n) and
not self.is_traversed_by_node(n)]


def compute_non_traversing_bodies(self):
"""Same as agglo.Rag.non_traversing_bodies, but doesn't use graph."""
return morpho.non_traversing_bodies(self.get_segmentation())
not self.is_traversed_by_node(n) and n != self.boundary_body]


def raveler_body_annotations(self, traverse=False):
Expand Down
9 changes: 3 additions & 6 deletions gala/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
# libraries
import h5py
import numpy as np
from numpy.testing import assert_raises
np.seterr(divide='ignore')

from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib

Expand All @@ -26,9 +23,6 @@
else:
vigra_available = True

# local imports
from . import iterprogress as ip


def default_classifier_extension(cl, use_joblib=True):
"""
Expand Down Expand Up @@ -166,6 +160,7 @@ def get_classifier(name='random forest', *args, **kwargs):
True
>>> cl.n_estimators
47
>>> from numpy.testing import assert_raises
>>> assert_raises(NotImplementedError, get_classifier, 'perfect class')
"""
name = name.lower()
Expand All @@ -177,6 +172,8 @@ def get_classifier(name='random forest', *args, **kwargs):
return DefaultRandomForest(*args, **kwargs)
elif is_naive_bayes:
from sklearn.naive_bayes import GaussianNB
if 'random_state' in kwargs:
del kwargs['random_state']
return GaussianNB(*args, **kwargs)
else:
raise NotImplementedError('Classifier "%s" is either not installed '
Expand Down
2 changes: 1 addition & 1 deletion gala/imio.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def write_png_image_stack(npy_vol, fn, axis=-1, bitdepth=None):
npy_vol = uint32(npy_vol)
for z, pl in enumerate(npy_vol):
im = Image.new(mode_base, pl.T.shape)
im.fromstring(pl.tostring(), 'raw', mode)
im.frombytes(pl.tostring(), 'raw', mode)
im.save(fn % z)

### VTK structured points array format
Expand Down
9 changes: 9 additions & 0 deletions tests/test_agglo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ def test_mask():
assert (2, 4) in g.edges()


def test_traverse():
labels = [[0, 1, 2],
[0, 1, 2],
[0, 1, 2]]
g = agglo.Rag(np.array(labels))
assert g.traversing_bodies() == [1]
assert g.non_traversing_bodies() == [0, 2]


if __name__ == '__main__':
from numpy import testing
testing.run_module_suite()
Expand Down
23 changes: 21 additions & 2 deletions tests/test_gala.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import sys
import glob
from contextlib import contextmanager

Expand Down Expand Up @@ -67,7 +66,7 @@ def train_and_save_classifier(training_data_file, filename,

### tests

def test_generate_examples_1_channel():
def test_generate_lash_examples_1_channel():
"""Run a flat epoch and an active epoch of learning, compare learned sets.
The *order* of the edges learned by learn_flat is not guaranteed, so we
Expand All @@ -78,6 +77,7 @@ def test_generate_examples_1_channel():
"""
g_train = agglo.Rag(ws_train, pr_train, feature_manager=fc)
_, alldata = g_train.learn_agglomerate(gt_train, fc,
learning_mode='permissive',
classifier='naive bayes')
testfn = 'example-data/train-naive-bayes-merges1-py3.pck'
exp0, exp1 = load_pickle(os.path.join(rundir, testfn))
Expand All @@ -94,6 +94,24 @@ def test_generate_examples_1_channel():
assert_allclose(nb.class_prior_, nbexp.class_prior_, atol=1e-7)


def test_generate_gala_examples_1_channel():
"""As `test_generate_lash_examples_1_channel`, but using strict learning.
"""
g_train = agglo.Rag(ws_train, pr_train, feature_manager=fc)
_, alldata = g_train.learn_agglomerate(gt_train, fc,
learning_mode='strict',
classifier='naive bayes')
testfn = 'example-data/train-naive-bayes-merges1-py3.pck'
exp0, exp1 = load_pickle(os.path.join(rundir, testfn))
expected_edges = set(map(tuple, exp0))
edges = set(map(tuple, alldata[0][3]))
merges = alldata[1][3]
# expect same edges in flat learning
assert edges == expected_edges # flat learning epoch
# expect more edges in strict training than permissive
assert np.shape(merges)[0] > np.shape(exp1)[0]


def test_segment_with_classifer_1_channel():
fn = os.path.join(rundir, 'example-data/rf1-py3.joblib')
with tar_extract(fn) as fn:
Expand All @@ -119,6 +137,7 @@ def test_generate_examples_4_channel():
"""
g_train = agglo.Rag(ws_train, p4_train, feature_manager=fc)
_, alldata = g_train.learn_agglomerate(gt_train, fc,
learning_mode='permissive',
classifier='naive bayes')
testfn = 'example-data/train-naive-bayes-merges4-py3.pck'
exp0, exp1 = load_pickle(os.path.join(rundir, testfn))
Expand Down

0 comments on commit 47cecdc

Please sign in to comment.