diff --git a/gala/classify.py b/gala/classify.py index 71a0869..5bf600f 100644 --- a/gala/classify.py +++ b/gala/classify.py @@ -87,7 +87,7 @@ def load_classifier(fn): with open(fn, 'r') as f: cl = pck.load(f) return cl - except pck.UnpicklingError: + except (pck.UnpicklingError, UnicodeDecodeError): pass if sklearn_available: try: diff --git a/tests/example-data/rf1-py3.joblib.tar.gz b/tests/example-data/rf1-py3.joblib.tar.gz new file mode 100644 index 0000000..6e99953 Binary files /dev/null and b/tests/example-data/rf1-py3.joblib.tar.gz differ diff --git a/tests/example-data/rf4-py3.joblib.tar.gz b/tests/example-data/rf4-py3.joblib.tar.gz new file mode 100644 index 0000000..cd41706 Binary files /dev/null and b/tests/example-data/rf4-py3.joblib.tar.gz differ diff --git a/tests/test_gala.py b/tests/test_gala.py index 7593187..82acf78 100644 --- a/tests/test_gala.py +++ b/tests/test_gala.py @@ -1,11 +1,21 @@ from __future__ import absolute_import import os +import sys + +PYTHON_VERSION = sys.version_info[0] from numpy.testing import assert_allclose import numpy as np +from sklearn.externals import joblib +import subprocess as sp from gala import imio, classify, features, agglo, evaluate as ev from six.moves import map + +def tar_extract(fn): + sp.call(['tar', '-xzf', fn + '.tar.gz']) + + rundir = os.path.dirname(__file__) # load example data @@ -52,8 +62,13 @@ def test_generate_examples_1_channel(): def test_segment_with_classifer_1_channel(): - rf = classify.load_classifier( + if PYTHON_VERSION == 2: + rf = classify.load_classifier( os.path.join(rundir, 'example-data/rf-1.joblib')) + else: + fn = os.path.join(rundir, 'example-data/rf1-py3.joblib') + tar_extract(fn) + rf = joblib.load(os.path.basename(fn)) learned_policy = agglo.classifier_probability(fc, rf) g_test = agglo.Rag(ws_test, pr_test, learned_policy, feature_manager=fc) g_test.agglomerate(0.5) @@ -75,8 +90,13 @@ def test_generate_examples_4_channel(): def test_segment_with_classifier_4_channel(): - rf = classify.load_classifier( + if PYTHON_VERSION == 2: + rf = classify.load_classifier( os.path.join(rundir, 'example-data/rf-4.joblib')) + else: + fn = os.path.join(rundir, 'example-data/rf4-py3.joblib') + tar_extract(fn) + rf = joblib.load(os.path.basename(fn)) learned_policy = agglo.classifier_probability(fc, rf) g_test = agglo.Rag(ws_test, p4_test, learned_policy, feature_manager=fc) g_test.agglomerate(0.5)