diff --git a/laserchicken/feature_extractor/__init__.py b/laserchicken/feature_extractor/__init__.py index b199c75..8b31dd7 100644 --- a/laserchicken/feature_extractor/__init__.py +++ b/laserchicken/feature_extractor/__init__.py @@ -1,10 +1,12 @@ """Feature extractor module.""" import importlib import re -from laserchicken import keys +import numpy as np +from laserchicken import keys,utils +#from .gijs_elena_feature import MyFeatureExtractor -def _feature_map(module_name=__name__): +def _feature_map(module_name = __name__): """Construct a mapping from feature names to feature extractor classes.""" module = importlib.import_module(module_name) return { @@ -17,23 +19,29 @@ def _feature_map(module_name=__name__): FEATURES = _feature_map() -def compute_features(env_point_cloud, neighborhoods, target_point_cloud, feature_names): +def compute_features(env_point_cloud, neighborhoods, target_point_cloud, feature_names, overwrite = False): ordered_features = _make_feature_list(feature_names) + targetsize = len(target_point_cloud[keys.point]["x"]["data"]) for feature in ordered_features: - if(feature in target_point_cloud): continue + if((not overwrite) and (feature in target_point_cloud[keys.point])): + continue # Skip feature calc if it is already there and we do not overwrite extractor = FEATURES[feature]() providedfeatures = extractor.provides() numfeatures = len(providedfeatures) - featurevalues = [np.empty([target_indices],dtype = np.float64) for i in range(numfeatures)] - for target_index in range(target_point_cloud[keys.point]["x"]["data"]): + featurevalues = [np.empty(targetsize,dtype = np.float64) for i in range(numfeatures)] + for target_index in range(targetsize): pointvalues = extractor.extract(env_point_cloud, neighborhoods[target_index], target_point_cloud, target_index) - for f in numfeatures: - featurevalues[f][target_index] = pointvalues[f] + if(numfeatures > 1): + for i in range(numfeatures): + featurevalues[i][target_index] = pointvalues[i] + else: + featurevalues[0][target_index] = pointvalues for i in range(numfeatures): - if "features" not in target_point_cloud: - target_point_cloud["features"] = {} fname = providedfeatures[i] - target_point_cloud["features"][fname] = {"type" : np.float64, "data" : featurevalues[i]} + if(overwrite or (not fname in target_point_cloud[keys.point])): + # Set feature values if it is not there (or we want to overwrite) + target_point_cloud[keys.point][fname] = {"type" : np.float64, "data" : featurevalues[i]} + utils.add_metadata(target_point_cloud,type(extractor).__module__,extractor.get_params()) def _make_feature_list(feature_names): @@ -42,10 +50,9 @@ def _make_feature_list(feature_names): return [f for f in feature_list if not (f in seen or seen.add(f))] - def _make_feature_list_helper(feature_names): feature_list = feature_names - for f in feature_names: + for feature_name in feature_names: extractor = FEATURES[feature_name]() dependencies = extractor.requires() feature_list.extend(dependencies) diff --git a/laserchicken/feature_extractor/abc.py b/laserchicken/feature_extractor/abc.py index 153e8e6..91afed9 100644 --- a/laserchicken/feature_extractor/abc.py +++ b/laserchicken/feature_extractor/abc.py @@ -29,13 +29,21 @@ def provides(cls): """ raise NotImplementedError("Class %s doesn't implement get_names()" % (cls.__name__)) - def extract(self, point_cloud, target_point_cloud, target_index): + def extract(self, point_cloud, neighborhood, target_point_cloud, target_index): """ Extract the feature value(s) of the point cloud at location of the target. :param point_cloud: environment (search space) point cloud - :param target_point_cloud: pointcloud that contains target point + :param neighborhood: array of indices of points within the point_cloud argument + :param target_point_cloud: point cloud that contains target point :target_index: index of the target point in the target pointcloud :return: feature value """ raise NotImplementedError("Class %s doesn't implement extract_features()" % (self.__class__.__name__)) + + def get_params(self): + """ + Returns a tuple of parameters involved in the current feature extractor + object. Needed for provenance. + """ + return () diff --git a/laserchicken/test_feature_extractor/feature_test1.py b/laserchicken/test_feature_extractor/feature_test1.py index 5341fd9..b520b4c 100644 --- a/laserchicken/test_feature_extractor/feature_test1.py +++ b/laserchicken/test_feature_extractor/feature_test1.py @@ -1,6 +1,6 @@ """Test1 feature extractor.""" from laserchicken.feature_extractor.abc import AbstractFeatureExtractor - +from laserchicken import utils class Test1FeatureExtractor(AbstractFeatureExtractor): @classmethod @@ -11,7 +11,6 @@ def requires(cls): def provides(cls): return ['test1_a', 'test1_b'] - def extract(self, _, target): - for feature_name in self.provides(): - if feature_name not in target: - target[feature_name] = len(target) + def extract(self,sourcepc,neighborhood,targetpc,targetindex): + x,y,z = utils.get_point(targetpc,targetindex) + return [0.5 * z,1.5 * z] diff --git a/laserchicken/test_feature_extractor/feature_test23.py b/laserchicken/test_feature_extractor/feature_test23.py index 9498526..d76bcae 100644 --- a/laserchicken/test_feature_extractor/feature_test23.py +++ b/laserchicken/test_feature_extractor/feature_test23.py @@ -1,6 +1,6 @@ """Test2 and Test3 feature extractors.""" from laserchicken.feature_extractor.abc import AbstractFeatureExtractor - +from laserchicken import utils class Test2FeatureExtractor(AbstractFeatureExtractor): @classmethod @@ -11,10 +11,10 @@ def requires(cls): def provides(cls): return ['test2_a', 'test2_b', 'test2_c'] - def extract(self, _, target): - for feature_name in self.provides(): - if feature_name not in target: - target[feature_name] = len(target) + def extract(self,sourcepc,neighborhood,targetpc,targetindex): + t1b = utils.get_feature(targetpc,targetindex,self.requires()[0]) + x,y,z = utils.get_point(targetpc,targetindex) + return [x + t1b,y + t1b,z + t1b] # x + 3z/2, y + 3z/2, 5z/2 class Test3FeatureExtractor(AbstractFeatureExtractor): @@ -26,7 +26,7 @@ def requires(cls): def provides(cls): return ['test3_a'] - def extract(self, _, target): - for feature_name in self.provides(): - if feature_name not in target: - target[feature_name] = len(target) + def extract(self,sourcepc,neighborhood,targetpc,targetindex): + t2a,t2c = utils.get_features(targetpc,targetindex,self.requires()) + x,y,z = utils.get_point(targetpc,targetindex) + return t2c - t2a - z # z diff --git a/laserchicken/test_feature_extractor/feature_test_broken.py b/laserchicken/test_feature_extractor/feature_test_broken.py index 20e1265..039e8cc 100644 --- a/laserchicken/test_feature_extractor/feature_test_broken.py +++ b/laserchicken/test_feature_extractor/feature_test_broken.py @@ -13,5 +13,5 @@ def requires(cls): def provides(cls): return ['test_broken'] - def extract(self, _, target): + def extract(self,sourcepc,neighborhood,targetpc,targetindex): pass diff --git a/laserchicken/test_feature_extractor/test_extract_features.py b/laserchicken/test_feature_extractor/test_extract_features.py index 4bd495e..2569ccc 100644 --- a/laserchicken/test_feature_extractor/test_extract_features.py +++ b/laserchicken/test_feature_extractor/test_extract_features.py @@ -1,7 +1,7 @@ """Test feature extraction.""" import pytest - -from laserchicken import feature_extractor +import numpy as np +from laserchicken import feature_extractor,keys,test_tools from . import __name__ as test_module_name @@ -9,51 +9,34 @@ feature_extractor.FEATURES = feature_extractor._feature_map(test_module_name) -def _extract_features(feature_names): - point_cloud = None - target = {} - feature_extractor.extract_features(point_cloud, target, feature_names) +def _compute_features(target,feature_names,overwrite = False): + neighborhoods = [[] for i in range(len(target["vertex"]["x"]["data"]))] + feature_extractor.compute_features({}, neighborhoods, target, feature_names, overwrite) return target def test_extract_single_feature(): - target = _extract_features(['test3_a']) - assert target['test3_a'] == 5 - + target = test_tools.ComplexTestData.get_point_cloud() + _compute_features(target,['test3_a']) + assert ('test1_b' in target[keys.point]) + assert all(target[keys.point]['test3_a']['data'] == target[keys.point]['z']['data']) def test_extract_multiple_features(): - result = { - 'test1_a': 0, - 'test1_b': 1, - 'test2_a': 2, - 'test2_b': 3, - 'test2_c': 4, - } - feature_names = ['test2_c'] - target = _extract_features(feature_names) - assert target == result - - -def test_no_overwrite_existing_feature(): - result = { - 'x': 10, - 'test1_a': 20, - 'test1_b': 2, - } - point_cloud = None - target = { - 'x': 10, - 'test1_a': 20, - } - feature_extractor.extract_features(point_cloud, target, ['test1_b']) - assert target == result - - -def test_extract_broken_feature(): - point_cloud = None - target = {} - feature_names = ['test_broken'] - msg = "TestBrokenFeatureExtractor failed to add feature test_broken to target {}" - with pytest.raises(AssertionError) as exc: - feature_extractor.extract_features(point_cloud, target, feature_names) - assert str(exc.value) == msg + target = test_tools.ComplexTestData.get_point_cloud() + feature_names = ['test3_a','test2_b'] + target = _compute_features(target,feature_names) + assert ('test3_a' in target[keys.point] and 'test2_b' in target[keys.point]) + +def test_extract_does_not_overwrite(): + target = test_tools.ComplexTestData.get_point_cloud() + target[keys.point]['test2_b'] = {"type":np.float64,"data":[0.9,0.99,0.999,0.9999]} + feature_names = ['test3_a','test2_b'] + target = _compute_features(target,feature_names) + assert (target[keys.point]['test2_b']['data'][2] == 0.999) + +def test_extract_can_overwrite(): + target = test_tools.ComplexTestData.get_point_cloud() + target[keys.point]['test2_b'] = {"type":np.float64,"data":[0.9,0.99,0.999,0.9999]} + feature_names = ['test3_a','test2_b'] + target = _compute_features(target,feature_names,overwrite = True) + assert (target[keys.point]['test2_b']['data'][2] == 11.5) diff --git a/laserchicken/utils.py b/laserchicken/utils.py index 0c3f059..bf8444e 100644 --- a/laserchicken/utils.py +++ b/laserchicken/utils.py @@ -37,7 +37,9 @@ def add_metadata(pc,module,params): """ Adds module metadata to pointcloud provenance """ - msg = {"time" : datetime.datetime.utcnow(),"module" : module.__name__, "parameters" : params} + msg = {"time" : datetime.datetime.utcnow()} + msg["module"] = module.__name__ if hasattr(module,"__name__") else str(module) + if(any(params)): msg["parameters"] = params msg["version"] = _version.__version__ if(keys.provenance not in pc): pc[keys.provenance] = []