Skip to content

Commit

Permalink
Merge pull request #30 from ealcobaca/postprocessing
Browse files Browse the repository at this point in the history
Postprocessing (code coverage)
  • Loading branch information
FelSiq committed May 15, 2019
2 parents 9556210 + d0ac1f5 commit 0d05b06
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 16 deletions.
60 changes: 44 additions & 16 deletions pymfe/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _get_all_prefixed_mtds(
groups: t.Tuple[str, ...],
update_groups_by: t.Optional[t.Union[t.FrozenSet[str],
t.Set[str]]] = None,
custom_class_: t.Any = None,
) -> t.Dict[str, t.Tuple]:
"""Get all methods prefixed with ``prefix`` in predefined feature ``groups``.
Expand All @@ -281,6 +282,10 @@ def _get_all_prefixed_mtds(
precomputation methods from feature groups not related with user
selected features.
custom_class_ (Class, optional): used for inner testing purposes. If
not None, the given class will be used as reference to extract
the prefixed methods.
Returns:
If ``filter_groups_by`` argument is :obj:`NoneType` or empty:
tuple: with all filtered methods by ``group``.
Expand All @@ -293,16 +298,23 @@ def _get_all_prefixed_mtds(
"""
groups = tuple(set(VALID_GROUPS).intersection(groups))

if not groups:
if not groups and custom_class_ is None:
return {"methods": tuple(), "groups": tuple()}

if custom_class_ is None:
verify_groups = VALID_GROUPS
verify_classes = VALID_MFECLASSES

else:
verify_groups = ("test_methods", )
verify_classes = (custom_class_, )

methods_by_group = {
ft_type_id: _get_prefixed_mtds_from_class(
class_obj=mfe_class,
prefix=prefix)

for ft_type_id, mfe_class in zip(VALID_GROUPS, VALID_MFECLASSES)
if ft_type_id in groups
for ft_type_id, mfe_class in zip(verify_groups, verify_classes)
if ft_type_id in groups or custom_class_ is not None
}

gathered_methods = [] # type: t.List[TypeMtdTuple]
Expand Down Expand Up @@ -793,6 +805,7 @@ def process_features(
groups: t.Tuple[str, ...],
wildcard: str = "all",
suppress_warnings: bool = False,
custom_class_: t.Any = None,
) -> t.Tuple[t.Tuple[str, ...],
t.Tuple[TypeExtMtdTuple, ...],
t.Tuple[str, ...]]:
Expand All @@ -818,6 +831,10 @@ def process_features(
suppress_warnings (:obj:`bool`, optional): if True, hide all warnings
raised during this method processing.
custom_class_ (Class, optional): used for inner testing purposes. If
not None, the given class will be used as reference to extract
the metafeature extraction methods.
Returns:
tuple(tuple, tuple): A pair of tuples. The first Tuple is all feature
names extracted from this method, to give the user easy access to
Expand All @@ -835,8 +852,12 @@ def process_features(
if not features:
raise ValueError('"features" can not be None nor empty.')

if groups is None:
groups = tuple()
if not groups:
if custom_class_ is None:
groups = tuple()

else:
groups = ("custom", )

processed_ft = _preprocess_iterable_arg(features) # type: t.List[str]

Expand All @@ -848,6 +869,7 @@ def process_features(
prefix=MTF_PREFIX,
update_groups_by=reference_values,
groups=groups,
custom_class_=custom_class_,
) # type: t.Dict[str, t.Tuple]

ft_mtds_filtered = mtds_metadata.get(
Expand Down Expand Up @@ -916,8 +938,8 @@ def process_precomp_groups(
groups: t.Optional[t.Tuple[str, ...]] = None,
wildcard: str = "all",
suppress_warnings: bool = False,
**kwargs
) -> t.Dict[str, t.Any]:
custom_class_: t.Any = None,
**kwargs) -> t.Dict[str, t.Any]:
"""Process ``precomp_groups`` argument while fitting into a MFE model.
This function is expected to be used after ``process_groups`` function,
Expand All @@ -940,6 +962,10 @@ def process_precomp_groups(
suppress_warnings (:obj:`bool`, optional): if True, suppress warnings
invoked while processing precomputation option.
custom_class_ (Class, optional): used for inner testing purposes. If
not None, the given class will be used as reference to extract
the preprocomputing methods.
**kwargs: used to pass extra custom arguments to precomputation metho-
ds.
Expand All @@ -952,7 +978,7 @@ def process_precomp_groups(

precomp_groups = _patch_precomp_groups(precomp_groups, groups)

if not precomp_groups:
if not precomp_groups and custom_class_ is None:
return {}

processed_precomp_groups = _preprocess_iterable_arg(
Expand All @@ -961,7 +987,7 @@ def process_precomp_groups(
if wildcard in processed_precomp_groups:
processed_precomp_groups = groups

else:
elif custom_class_ is None:
if not suppress_warnings:
unknown_groups = set(processed_precomp_groups).difference(groups)

Expand All @@ -975,7 +1001,8 @@ def process_precomp_groups(

mtds_metadata = _get_all_prefixed_mtds(
prefix=PRECOMPUTE_PREFIX,
groups=processed_precomp_groups,
groups=tuple(processed_precomp_groups),
custom_class_=custom_class_,
) # type: t.Dict[str, t.Tuple]

precomp_mtds_filtered = mtds_metadata.get(
Expand Down Expand Up @@ -1313,6 +1340,7 @@ def check_score(score: str, groups: t.Tuple[str, ...]):
def post_processing(results: np.ndarray,
groups: t.Tuple[str, ...],
suppress_warnings: bool = False,
custom_class_: t.Any = None,
**kwargs) -> None:
"""Detect and apply post-processing methods in metafeatures.
Expand All @@ -1328,15 +1356,17 @@ def post_processing(results: np.ndarray,
suppress_warnings (:obj:`bool`, optional): if True, suppress warnings
invoked while processing precomputation option.
custom_class_ (Class, optional): used for inner testing purposes. If
not None, the given class will be used as reference to extract
the postprocessing methods.
**kwargs: used to pass extra custom arguments to precomputation metho-
ds.
"""
if groups is None:
return

mtds_metadata = _get_all_prefixed_mtds(
prefix=POSTPROCESS_PREFIX,
groups=groups,
custom_class_=custom_class_,
) # type: t.Dict[str, t.Tuple]

postprocess_mtds = mtds_metadata.get(
Expand Down Expand Up @@ -1382,5 +1412,3 @@ def post_processing(results: np.ndarray,

if remove_groups:
kwargs.pop("groups")

return
162 changes: 162 additions & 0 deletions tests/test_architecture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Module dedicated to framework testing."""
import pytest
import typing as t

import numpy as np

from pymfe import _internal

GNAME = "framework-testing"


class MFETestClass:
"""Some generic methods for testing the MFE Framework."""

@classmethod
def postprocess_return_none(cls, **kwargs) -> None:
"""Postprocess: return None."""
return None

@classmethod
def postprocess_return_new_feature(
cls,
number_of_lists: int = 3,
**kwargs
) -> t.Tuple[t.List, t.List, t.List]:
"""Postprocess: return Tuple of lists."""
return tuple(["test_value"] for _ in range(number_of_lists))

@classmethod
def postprocess_raise_exception(
cls,
raise_exception: bool = False,
**kwargs) -> None:
"""Posprocess: raise exception."""
if raise_exception:
raise ValueError("Expected exception (postprocess).")

return None

@classmethod
def precompute_return_empty(cls, **kwargs) -> t.Dict[str, t.Any]:
"""Precompute: return empty dictionary."""
precomp_vals = {}

return precomp_vals

@classmethod
def precompute_return_something(cls, **kwargs) -> t.Dict[str, t.Any]:
"""Precompute: return empty dictionary."""
precomp_vals = {
"test_param_1": 0,
"test_param_2": "euclidean",
"test_param_3": list,
"test_param_4": abs,
}

return precomp_vals

@classmethod
def precompute_raise_exception(
cls,
raise_exception: bool = False,
**kwargs) -> t.Dict[str, t.Any]:
"""Precompute: raise exception."""
precomp_vals = {}

if raise_exception:
raise ValueError("Expected exception (precompute).")

return precomp_vals

@classmethod
def ft_valid_number(
cls,
X: np.ndarray,
y: np.ndarray) -> float:
"""Metafeature: float type."""
return 0.0

@classmethod
def ft_valid_array(
cls,
X: np.ndarray,
y: np.ndarray) -> np.ndarray:
"""Metafeature: float type."""
return np.zeros(5)

@classmethod
def ft_raise_expection(
cls,
X: np.ndarray,
y: np.ndarray,
raise_exception: False) -> float:
"""Metafeature: float type."""
if raise_exception:
raise ValueError("Expected exception (feature).")

return -1.0


class TestArchitecture:
"""Tests for the framework architecture."""
def test_postprocessing_valid(self):
"""Test valid postprocessing and its automatic detection."""
results = [], [], []

_internal.post_processing(
results=results,
groups=tuple(),
custom_class_=MFETestClass)

assert all(map(lambda l: len(l) > 0, results))

def test_postprocessing_invalid_1(self):
"""Test exception handling in invalid postprocessing."""
results = [], [], []

with pytest.warns(UserWarning):
_internal.post_processing(
results=results,
groups=tuple(),
custom_class_=MFETestClass,
raise_exception=True)

def test_postprocessing_invalid_2(self):
"""Test incorrect return value in postprocessing methods."""
results = [], [], []

with pytest.warns(UserWarning):
_internal.post_processing(
results=results,
groups=tuple(),
custom_class_=MFETestClass,
number_of_lists=2)

def test_preprocessing_valid(self):
"""Test valid precomputation and its automatic detection."""
precomp_args = _internal.process_precomp_groups(
precomp_groups=tuple(),
groups=tuple(),
custom_class_=MFETestClass)

assert len(precomp_args) > 0

def test_preprocessing_invalid(self):
"""Test exception handling of precomputation."""
with pytest.warns(UserWarning):
_internal.process_precomp_groups(
precomp_groups=tuple(),
groups=tuple(),
custom_class_=MFETestClass,
raise_exception=True)

def test_feature_detection(self):
"""Test automatic dectection of metafeature extraction method."""
name, mtd, groups = _internal.process_features(
features="all",
groups=tuple(),
suppress_warnings=True,
custom_class_=MFETestClass)

assert len(name) == 3 and len(mtd) == 3 and len(groups) == 1

0 comments on commit 0d05b06

Please sign in to comment.