Skip to content

Commit

Permalink
Fixed feature extractor and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
goord committed Jan 25, 2018
1 parent 51e7b3f commit 5e31150
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 75 deletions.
33 changes: 20 additions & 13 deletions laserchicken/feature_extractor/__init__.py
@@ -1,10 +1,12 @@
"""Feature extractor module."""
import importlib
import re
from laserchicken import keys
import numpy as np
from laserchicken import keys,utils
#from .gijs_elena_feature import MyFeatureExtractor


def _feature_map(module_name=__name__):
def _feature_map(module_name = __name__):
"""Construct a mapping from feature names to feature extractor classes."""
module = importlib.import_module(module_name)
return {
Expand All @@ -17,23 +19,29 @@ def _feature_map(module_name=__name__):
FEATURES = _feature_map()


def compute_features(env_point_cloud, neighborhoods, target_point_cloud, feature_names):
def compute_features(env_point_cloud, neighborhoods, target_point_cloud, feature_names, overwrite = False):
ordered_features = _make_feature_list(feature_names)
targetsize = len(target_point_cloud[keys.point]["x"]["data"])
for feature in ordered_features:
if(feature in target_point_cloud): continue
if((not overwrite) and (feature in target_point_cloud[keys.point])):
continue # Skip feature calc if it is already there and we do not overwrite
extractor = FEATURES[feature]()
providedfeatures = extractor.provides()
numfeatures = len(providedfeatures)
featurevalues = [np.empty([target_indices],dtype = np.float64) for i in range(numfeatures)]
for target_index in range(target_point_cloud[keys.point]["x"]["data"]):
featurevalues = [np.empty(targetsize,dtype = np.float64) for i in range(numfeatures)]
for target_index in range(targetsize):
pointvalues = extractor.extract(env_point_cloud, neighborhoods[target_index], target_point_cloud, target_index)
for f in numfeatures:
featurevalues[f][target_index] = pointvalues[f]
if(numfeatures > 1):
for i in range(numfeatures):
featurevalues[i][target_index] = pointvalues[i]
else:
featurevalues[0][target_index] = pointvalues
for i in range(numfeatures):
if "features" not in target_point_cloud:
target_point_cloud["features"] = {}
fname = providedfeatures[i]
target_point_cloud["features"][fname] = {"type" : np.float64, "data" : featurevalues[i]}
if(overwrite or (not fname in target_point_cloud[keys.point])):
# Set feature values if it is not there (or we want to overwrite)
target_point_cloud[keys.point][fname] = {"type" : np.float64, "data" : featurevalues[i]}
utils.add_metadata(target_point_cloud,type(extractor).__module__,extractor.get_params())


def _make_feature_list(feature_names):
Expand All @@ -42,10 +50,9 @@ def _make_feature_list(feature_names):
return [f for f in feature_list if not (f in seen or seen.add(f))]



def _make_feature_list_helper(feature_names):
feature_list = feature_names
for f in feature_names:
for feature_name in feature_names:
extractor = FEATURES[feature_name]()
dependencies = extractor.requires()
feature_list.extend(dependencies)
Expand Down
12 changes: 10 additions & 2 deletions laserchicken/feature_extractor/abc.py
Expand Up @@ -29,13 +29,21 @@ def provides(cls):
"""
raise NotImplementedError("Class %s doesn't implement get_names()" % (cls.__name__))

def extract(self, point_cloud, target_point_cloud, target_index):
def extract(self, point_cloud, neighborhood, target_point_cloud, target_index):
"""
Extract the feature value(s) of the point cloud at location of the target.
:param point_cloud: environment (search space) point cloud
:param target_point_cloud: pointcloud that contains target point
:param neighborhood: array of indices of points within the point_cloud argument
:param target_point_cloud: point cloud that contains target point
:target_index: index of the target point in the target pointcloud
:return: feature value
"""
raise NotImplementedError("Class %s doesn't implement extract_features()" % (self.__class__.__name__))

def get_params(self):
"""
Returns a tuple of parameters involved in the current feature extractor
object. Needed for provenance.
"""
return ()
9 changes: 4 additions & 5 deletions laserchicken/test_feature_extractor/feature_test1.py
@@ -1,6 +1,6 @@
"""Test1 feature extractor."""
from laserchicken.feature_extractor.abc import AbstractFeatureExtractor

from laserchicken import utils

class Test1FeatureExtractor(AbstractFeatureExtractor):
@classmethod
Expand All @@ -11,7 +11,6 @@ def requires(cls):
def provides(cls):
return ['test1_a', 'test1_b']

def extract(self, _, target):
for feature_name in self.provides():
if feature_name not in target:
target[feature_name] = len(target)
def extract(self,sourcepc,neighborhood,targetpc,targetindex):
x,y,z = utils.get_point(targetpc,targetindex)
return [0.5 * z,1.5 * z]
18 changes: 9 additions & 9 deletions laserchicken/test_feature_extractor/feature_test23.py
@@ -1,6 +1,6 @@
"""Test2 and Test3 feature extractors."""
from laserchicken.feature_extractor.abc import AbstractFeatureExtractor

from laserchicken import utils

class Test2FeatureExtractor(AbstractFeatureExtractor):
@classmethod
Expand All @@ -11,10 +11,10 @@ def requires(cls):
def provides(cls):
return ['test2_a', 'test2_b', 'test2_c']

def extract(self, _, target):
for feature_name in self.provides():
if feature_name not in target:
target[feature_name] = len(target)
def extract(self,sourcepc,neighborhood,targetpc,targetindex):
t1b = utils.get_feature(targetpc,targetindex,self.requires()[0])
x,y,z = utils.get_point(targetpc,targetindex)
return [x + t1b,y + t1b,z + t1b] # x + 3z/2, y + 3z/2, 5z/2


class Test3FeatureExtractor(AbstractFeatureExtractor):
Expand All @@ -26,7 +26,7 @@ def requires(cls):
def provides(cls):
return ['test3_a']

def extract(self, _, target):
for feature_name in self.provides():
if feature_name not in target:
target[feature_name] = len(target)
def extract(self,sourcepc,neighborhood,targetpc,targetindex):
t2a,t2c = utils.get_features(targetpc,targetindex,self.requires())
x,y,z = utils.get_point(targetpc,targetindex)
return t2c - t2a - z # z
2 changes: 1 addition & 1 deletion laserchicken/test_feature_extractor/feature_test_broken.py
Expand Up @@ -13,5 +13,5 @@ def requires(cls):
def provides(cls):
return ['test_broken']

def extract(self, _, target):
def extract(self,sourcepc,neighborhood,targetpc,targetindex):
pass
71 changes: 27 additions & 44 deletions laserchicken/test_feature_extractor/test_extract_features.py
@@ -1,59 +1,42 @@
"""Test feature extraction."""
import pytest

from laserchicken import feature_extractor
import numpy as np
from laserchicken import feature_extractor,keys,test_tools

from . import __name__ as test_module_name

# Overwrite the available feature extractors with test feature extractors
feature_extractor.FEATURES = feature_extractor._feature_map(test_module_name)


def _extract_features(feature_names):
point_cloud = None
target = {}
feature_extractor.extract_features(point_cloud, target, feature_names)
def _compute_features(target,feature_names,overwrite = False):
neighborhoods = [[] for i in range(len(target["vertex"]["x"]["data"]))]
feature_extractor.compute_features({}, neighborhoods, target, feature_names, overwrite)
return target


def test_extract_single_feature():
target = _extract_features(['test3_a'])
assert target['test3_a'] == 5

target = test_tools.ComplexTestData.get_point_cloud()
_compute_features(target,['test3_a'])
assert ('test1_b' in target[keys.point])
assert all(target[keys.point]['test3_a']['data'] == target[keys.point]['z']['data'])

def test_extract_multiple_features():
result = {
'test1_a': 0,
'test1_b': 1,
'test2_a': 2,
'test2_b': 3,
'test2_c': 4,
}
feature_names = ['test2_c']
target = _extract_features(feature_names)
assert target == result


def test_no_overwrite_existing_feature():
result = {
'x': 10,
'test1_a': 20,
'test1_b': 2,
}
point_cloud = None
target = {
'x': 10,
'test1_a': 20,
}
feature_extractor.extract_features(point_cloud, target, ['test1_b'])
assert target == result


def test_extract_broken_feature():
point_cloud = None
target = {}
feature_names = ['test_broken']
msg = "TestBrokenFeatureExtractor failed to add feature test_broken to target {}"
with pytest.raises(AssertionError) as exc:
feature_extractor.extract_features(point_cloud, target, feature_names)
assert str(exc.value) == msg
target = test_tools.ComplexTestData.get_point_cloud()
feature_names = ['test3_a','test2_b']
target = _compute_features(target,feature_names)
assert ('test3_a' in target[keys.point] and 'test2_b' in target[keys.point])

def test_extract_does_not_overwrite():
target = test_tools.ComplexTestData.get_point_cloud()
target[keys.point]['test2_b'] = {"type":np.float64,"data":[0.9,0.99,0.999,0.9999]}
feature_names = ['test3_a','test2_b']
target = _compute_features(target,feature_names)
assert (target[keys.point]['test2_b']['data'][2] == 0.999)

def test_extract_can_overwrite():
target = test_tools.ComplexTestData.get_point_cloud()
target[keys.point]['test2_b'] = {"type":np.float64,"data":[0.9,0.99,0.999,0.9999]}
feature_names = ['test3_a','test2_b']
target = _compute_features(target,feature_names,overwrite = True)
assert (target[keys.point]['test2_b']['data'][2] == 11.5)
4 changes: 3 additions & 1 deletion laserchicken/utils.py
Expand Up @@ -37,7 +37,9 @@ def add_metadata(pc,module,params):
"""
Adds module metadata to pointcloud provenance
"""
msg = {"time" : datetime.datetime.utcnow(),"module" : module.__name__, "parameters" : params}
msg = {"time" : datetime.datetime.utcnow()}
msg["module"] = module.__name__ if hasattr(module,"__name__") else str(module)
if(any(params)): msg["parameters"] = params
msg["version"] = _version.__version__
if(keys.provenance not in pc):
pc[keys.provenance] = []
Expand Down

0 comments on commit 5e31150

Please sign in to comment.