Skip to content

Commit

Permalink
fix: use img_cluster arg in predict_img for cluster-II
Browse files Browse the repository at this point in the history
  • Loading branch information
martibosch committed Mar 29, 2024
1 parent e0f5b47 commit 449e033
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 34 deletions.
31 changes: 16 additions & 15 deletions detectree/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,21 +448,22 @@ def predict_img(self, img_filepath, *, img_cluster=None, output_filepath=None):
y_pred : numpy ndarray
Array with the pixel responses.
"""
# clf = getattr(self, "clf", None)
# if clf is None:
# try:
# clf = self.clf_dict[img_cluster]
# except KeyError:
# raise ValueError(
# f"Classifier for cluster {img_cluster} not found in "
# "`self.clf_dict`."
# )
# return self._classify_img(
# img_filepath, clf, output_filepath=output_filepath
# )
return self._predict_img(
img_filepath, self.clf, output_filepath=output_filepath
)
clf = getattr(self, "clf", None)
if clf is None:
if img_cluster is not None:
try:
clf = self.clf_dict[img_cluster]
except KeyError:
raise ValueError(
f"Classifier for cluster {img_cluster} not found in"
" `self.clf_dict`."
)
else:
raise ValueError(
"A valid `img_cluster` must be provided for classifiers"
" instantiated with `clf_dict`."
)
return self._predict_img(img_filepath, clf, output_filepath=output_filepath)

def predict_imgs(self, split_df, output_dir):
"""
Expand Down
62 changes: 43 additions & 19 deletions tests/test_detectree.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,24 @@ def _test_imgs_exist_and_rm(self, pred_imgs):
# remove it so that the output dir is clean in the tests below
os.remove(pred_img)

def _test_predict_img(self, c, img_filepath, *, img_cluster=None):
# test that `predict_img` returns a ndarray
self.assertIsInstance(
c.predict_img(img_filepath, img_cluster=img_cluster), np.ndarray
)
# test that `predict_img` with `output_filepath` returns a ndarray and dumps it
output_filepath = path.join(self.tmp_output_dir, "foo.tif")
y_pred = c.predict_img(
img_filepath, img_cluster=img_cluster, output_filepath=output_filepath
)
self.assertIsInstance(y_pred, np.ndarray)
self.assertTrue(os.path.exists(output_filepath))
# remove it so that the output dir is clean in the tests below
os.remove(output_filepath)

def test_classifier(self):
# define this here to reuse it below
img_filepath = self.split_i_df.iloc[0]["img_filepath"]
# test init classifier
# TODO: test init arguments of `Classifier`
# test that for the pre-trained classifier (no init `clf`/`clf_dict` arg) and
Expand All @@ -577,23 +594,13 @@ def test_classifier(self):
self.assertFalse(hasattr(c, "clf"))
self.assertTrue(hasattr(c, "clf_dict"))

# test image classification separately for each method
# test image segmentation separately for each method
# "cluster-I"
for c in [
dtr.Classifier(),
dtr.Classifier(clf=self.clf),
]:
img_filepath = self.split_i_df.iloc[0]["img_filepath"]
# test that `classify_img` returns a ndarray
self.assertIsInstance(c.predict_img(img_filepath), np.ndarray)
# test that `classify_img` with `output_filepath` returns a ndarray and
# dumps it
output_filepath = path.join(self.tmp_output_dir, "foo.tif")
y_pred = c.predict_img(img_filepath, output_filepath=output_filepath)
self.assertIsInstance(y_pred, np.ndarray)
self.assertTrue(os.path.exists(output_filepath))
# remove it so that the output dir is clean in the tests below
os.remove(output_filepath)
self._test_predict_img(c, img_filepath)

# test that `classify_imgs` returns a list and that the images have been
# dumped. This works regardless of whether a "img_cluster" column is present
Expand All @@ -608,27 +615,44 @@ def test_classifier(self):
dtr.Classifier(refine=False),
dtr.Classifier(clf=self.clf, refine=False),
]:
img_filepath = self.split_i_df.iloc[0]["img_filepath"]
# test that `classify_img` returns a ndarray
self.assertIsInstance(c.predict_img(img_filepath), np.ndarray)

# "cluster-II"
c = dtr.Classifier(clf_dict=self.clf_dict)
# `predict_img` should raise a `ValueError`:
# - if the `img_cluster` argument is not provided
self.assertRaises(ValueError, c.predict_img, img_filepath)
# - if the provided `img_cluster` is not a key of `clf_dict`
self.assertRaises(ValueError, c.predict_img, img_filepath, img_cluster=-999)
# otherwise, it should work
img_cluster = list(self.clf_dict.keys())[0]
self._test_predict_img(c, img_filepath, img_cluster=img_cluster)
# `predict_imgs` should raise a `ValueError` if `split_df` doesn't have an
# "img_cluster" column
self.assertRaises(
KeyError, c.predict_imgs, self.split_i_df, self.tmp_output_dir
)
# `classify_imgs` should raise a `KeyError` if `split_df` doesn't have a
# "img_cluster" column
self.assertRaises(
KeyError, c.predict_imgs, self.split_i_df, self.tmp_output_dir
)
# otherwise it should return a list and dump the images (regardless of the
# `refine` value
# otherwise, it should work
pred_imgs = c.predict_imgs(self.split_ii_df, self.tmp_output_dir)
self.assertIsInstance(pred_imgs, dict)
for _img_cluster in pred_imgs:
self._test_imgs_exist_and_rm(pred_imgs[_img_cluster])

# thest the `refine` argument
for c in [
dtr.Classifier(clf_dict=self.clf_dict, refine=refine)
for refine in [True, False]
]:
pred_imgs = c.predict_imgs(self.split_ii_df, self.tmp_output_dir)
self.assertIsInstance(pred_imgs, dict)
for img_cluster in pred_imgs:
self._test_imgs_exist_and_rm(pred_imgs[img_cluster])
# test that `classify_img` returns a ndarray
self.assertIsInstance(
c.predict_img(img_filepath, img_cluster=img_cluster), np.ndarray
)


class TestLidarToCanopy(unittest.TestCase):
Expand Down

0 comments on commit 449e033

Please sign in to comment.