Skip to content

Commit

Permalink
Add Py3 RFs for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jni committed Jan 28, 2015
1 parent 855948f commit 5a1456c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion gala/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def load_classifier(fn):
with open(fn, 'r') as f:
cl = pck.load(f)
return cl
except pck.UnpicklingError:
except (pck.UnpicklingError, UnicodeDecodeError):
pass
if sklearn_available:
try:
Expand Down
Binary file added tests/example-data/rf1-py3.joblib.tar.gz
Binary file not shown.
Binary file added tests/example-data/rf4-py3.joblib.tar.gz
Binary file not shown.
24 changes: 22 additions & 2 deletions tests/test_gala.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from __future__ import absolute_import
import os
import sys

PYTHON_VERSION = sys.version_info[0]

from numpy.testing import assert_allclose
import numpy as np
from sklearn.externals import joblib
import subprocess as sp
from gala import imio, classify, features, agglo, evaluate as ev
from six.moves import map


def tar_extract(fn):
sp.call(['tar', '-xzf', fn + '.tar.gz'])


rundir = os.path.dirname(__file__)

# load example data
Expand Down Expand Up @@ -52,8 +62,13 @@ def test_generate_examples_1_channel():


def test_segment_with_classifer_1_channel():
rf = classify.load_classifier(
if PYTHON_VERSION == 2:
rf = classify.load_classifier(
os.path.join(rundir, 'example-data/rf-1.joblib'))
else:
fn = os.path.join(rundir, 'example-data/rf1-py3.joblib')
tar_extract(fn)
rf = joblib.load(os.path.basename(fn))
learned_policy = agglo.classifier_probability(fc, rf)
g_test = agglo.Rag(ws_test, pr_test, learned_policy, feature_manager=fc)
g_test.agglomerate(0.5)
Expand All @@ -75,8 +90,13 @@ def test_generate_examples_4_channel():


def test_segment_with_classifier_4_channel():
rf = classify.load_classifier(
if PYTHON_VERSION == 2:
rf = classify.load_classifier(
os.path.join(rundir, 'example-data/rf-4.joblib'))
else:
fn = os.path.join(rundir, 'example-data/rf4-py3.joblib')
tar_extract(fn)
rf = joblib.load(os.path.basename(fn))
learned_policy = agglo.classifier_probability(fc, rf)
g_test = agglo.Rag(ws_test, p4_test, learned_policy, feature_manager=fc)
g_test.agglomerate(0.5)
Expand Down

0 comments on commit 5a1456c

Please sign in to comment.