Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Adding non-default loggables explicitly. #1331

Merged
merged 11 commits into from
May 31, 2022
19 changes: 11 additions & 8 deletions hoomd/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,11 +624,12 @@ def only_default(self):
quantities."""
return self._only_default

def _filter_quantities(self, quantities):
def _filter_quantities(self, quantities, force_quantities=False):
for quantity in quantities:
if self._only_default and not quantity.default:
if quantity.category not in self._categories:
continue
elif quantity.category in self._categories:
# Must be before default check to overwrite _only_default
if not self._only_default or quantity.default or force_quantities:
yield quantity

def _get_loggables_by_name(self, obj, quantities):
Expand All @@ -643,7 +644,7 @@ def _get_loggables_by_name(self, obj, quantities):
"object {} has not loggable quantities {}.".format(
obj, bad_keys))
yield from self._filter_quantities(
map(lambda q: obj._export_dict[q], quantities))
map(lambda q: obj._export_dict[q], quantities), True)

def add(self, obj, quantities=None, user_name=None):
"""Add loggables from obj to logger.
Expand Down Expand Up @@ -727,10 +728,12 @@ def __setitem__(self, namespace, value):
name and category. If using a method it should not take
arguments or have defaults for all arguments.
"""
if isinstance(value, _LoggerEntry):
super().__setitem__(namespace, value)
else:
super().__setitem__(namespace, _LoggerEntry.from_tuple(value))
if not isinstance(value, _LoggerEntry):
value = _LoggerEntry.from_tuple(value)
if value.category not in self.categories:
raise ValueError(
"User specified loggable is not of an accepted category.")
super().__setitem__(namespace, value)

def __iadd__(self, obj):
"""Add quantities from object or list of objects to logger.
Expand Down
198 changes: 125 additions & 73 deletions hoomd/pytest/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Part of HOOMD-blue, released under the BSD 3-Clause License.

from hoomd.conftest import pickling_check
from pytest import raises, fixture
from pytest import raises, fixture, mark
from hoomd.logging import (_LoggerQuantity, _SafeNamespaceDict, Logger,
dict_map, Loggable, LoggerCategories, log)

Expand Down Expand Up @@ -56,6 +56,10 @@ def prop(self):
def proplist(self):
return [1, 2, 3]

@log(category="string", default=False)
def prop_nondefault(self):
return "foo"

def __eq__(self, other):
return isinstance(other, type(self))

Expand All @@ -81,7 +85,7 @@ def propnotinherented(self):
not_dummy_loggable_inher = NotInherentedDummy

def test_logger_functor_application(self):
loggable_list = ['prop', 'proplist']
loggable_list = ['prop', 'proplist', "prop_nondefault"]
assert set(
self.dummy_loggable._export_dict.keys()) == set(loggable_list)
expected_namespace = _LoggerQuantity._generate_namespace(
Expand Down Expand Up @@ -109,7 +113,11 @@ def test_loggable_inherentence(self):

def test_loggables(self):
dummy_obj = self.dummy_loggable()
assert dummy_obj.loggables == {'prop': 'scalar', 'proplist': 'sequence'}
assert dummy_obj.loggables == {
'prop': 'scalar',
'proplist': 'sequence',
'prop_nondefault': 'string'
}


# ------- Test dict_map function
Expand Down Expand Up @@ -187,9 +195,21 @@ def test_len(self, namespace_dict, blank_namespace_dict):


# ------ Test Logger
@fixture
def blank_logger():
return Logger()
@fixture(params=(
{},
{
"only_default": False
},
{
"categories": ("scalar", "string")
},
{
"only_default": False,
"categories": ("scalar",)
},
))
def blank_logger(request):
return Logger(**request.param)


@fixture
Expand All @@ -207,20 +227,54 @@ def base_namespace():
return ('pytest', 'test_logging', 'DummyLoggable')


def nested_getitem(obj, namespace):
for k in namespace:
obj = obj[k]
return obj


class TestLogger:

def get_filter(self, logger, overwrite_default=False):

def filter(loggable):
with_default = not logger.only_default or loggable.default
return (loggable.category in logger.categories
and (with_default or overwrite_default))

return filter

def test_setitem(self, blank_logger):
logger = blank_logger
logger['a'] = (5, '__eq__', 'scalar')
logger[('b', 'c')] = (5, '__eq__', 'scalar')
logger['c'] = (lambda: [1, 2, 3], 'sequence')

def check(logger, namespace, loggable):
if LoggerCategories[loggable[-1]] not in logger.categories:
with raises(ValueError):
logger[namespace] = loggable
return
logger[namespace] = loggable
assert namespace in logger
log_quantity = nested_getitem(logger, namespace)
assert log_quantity.obj == loggable[0]
if len(loggable) == 3:
assert log_quantity.attr == loggable[1]
assert log_quantity.category == LoggerCategories[loggable[-1]]

# Valid values with potentially incompatible categories
check(blank_logger, 'a', (5, '__eq__', 'scalar'))
check(blank_logger, ('b', 'c'), (5, '__eq__', 'scalar'))
check(blank_logger, 'c', (lambda: [1, 2, 3], 'sequence'))
# Invalid values
for value in [dict(), list(), None, 5, (5, 2), (5, 2, 1)]:
with raises(ValueError):
logger[('c', 'd')] = value
blank_logger[('c', 'd')] = value
# Existent key
extant_key = next(iter(blank_logger.keys()))
# Requires that scalar category accepted or raises a ValueError
with raises(KeyError):
logger['a'] = (lambda: [1, 2, 3], 'sequence')
blank_logger[extant_key] = (lambda: 1, 'scalar')

def test_add_single_quantity(self, blank_logger, log_quantity):
# Assumes "scalar" is always accepted
blank_logger._add_single_quantity(None, log_quantity, None)
namespace = log_quantity.namespace + (log_quantity.name,)
assert namespace in blank_logger
Expand All @@ -235,19 +289,33 @@ def test_add_single_quantity(self, blank_logger, log_quantity):

def test_get_loggables_by_names(self, blank_logger, logged_obj):
# Check when quantities is None
log_quanities = blank_logger._get_loggables_by_name(logged_obj, None)
logged_names = ['prop', 'proplist']
log_quantities = list(
blank_logger._get_loggables_by_name(logged_obj, None))
log_filter = self.get_filter(blank_logger)
logged_names = [
loggable.name
for loggable in logged_obj._export_dict.values()
if log_filter(loggable)
]
assert len(log_quantities) == len(logged_names)
assert all([
log_quantity.name in logged_names for log_quantity in log_quanities
log_quantity.name in logged_names for log_quantity in log_quantities
])

# Check when quantities is given
accepted_quantities = ['prop', 'proplist']
log_quanities = blank_logger._get_loggables_by_name(
logged_obj, accepted_quantities)
accepted_quantities = ['proplist', "prop_nondefault"]
log_filter = self.get_filter(blank_logger, overwrite_default=True)
log_quantities = list(
blank_logger._get_loggables_by_name(logged_obj,
accepted_quantities))
logged_names = [
loggable.name
for loggable in logged_obj._export_dict.values()
if loggable.name in accepted_quantities and log_filter(loggable)
]
assert len(log_quantities) == len(logged_names)
assert all([
log_quantity.name in accepted_quantities
for log_quantity in log_quanities
log_quantity.name in logged_names for log_quantity in log_quantities
])

# Check when quantities has a bad value
Expand All @@ -256,60 +324,38 @@ def test_get_loggables_by_names(self, blank_logger, logged_obj):
a = blank_logger._get_loggables_by_name(logged_obj, bad_quantities)
list(a)

def test_add(self, blank_logger, logged_obj, base_namespace):

# Test adding everything
blank_logger.add(logged_obj)
expected_namespaces = [
base_namespace + ('prop',), base_namespace + ('proplist',)
]
assert all(ns in blank_logger for ns in expected_namespaces)
assert len(blank_logger) == 2
@mark.parametrize("quantities", ([], [
"prop",
], ['prop', 'proplist', "prop_nondefault"]))
def test_add(self, blank_logger, logged_obj, base_namespace, quantities):

# Test adding specific quantity
blank_logger._dict = dict()
blank_logger.add(logged_obj, 'prop')
expected_namespace = base_namespace + ('prop',)
assert expected_namespace in blank_logger
assert len(blank_logger) == 1
if len(quantities) != 0:
blank_logger.add(logged_obj, quantities)
log_filter = self.get_filter(blank_logger, overwrite_default=True)
else:
blank_logger.add(logged_obj)
log_filter = self.get_filter(blank_logger)

# Test multiple quantities
blank_logger._dict = dict()
blank_logger.add(logged_obj, ['prop', 'proplist'])
expected_namespaces = [
base_namespace + ('prop',), base_namespace + ('proplist',)
base_namespace + (loggable.name,)
for loggable in logged_obj._export_dict.values()
if log_filter(loggable)
]
assert all([ns in blank_logger for ns in expected_namespaces])
assert len(blank_logger) == 2

# Test with category
blank_logger._dict = dict()
blank_logger._categories = LoggerCategories['scalar']
blank_logger.add(logged_obj)
expected_namespace = base_namespace + ('prop',)
assert expected_namespace in blank_logger
assert len(blank_logger) == 1
if len(quantities) != 0:
expected_namespaces = [
name for name in expected_namespaces
if any(name[-1] in q for q in quantities)
]
assert all(ns in blank_logger for ns in expected_namespaces)
assert len(blank_logger) == len(expected_namespaces)

def test_add_with_user_names(self, blank_logger, logged_obj,
base_namespace):
def test_add_with_user_names(self, logged_obj, base_namespace):
logger = Logger()
# Test adding a user specified identifier into the namespace
user_name = 'UserName'
blank_logger.add(logged_obj, user_name=user_name)
assert base_namespace[:-1] + (user_name, 'prop') in blank_logger
assert base_namespace[:-1] + (user_name, 'proplist') in blank_logger

def test_add_with_categories(self, blank_logger, logged_obj,
base_namespace):
blank_logger._categories = LoggerCategories['scalar']
# Test adding everything should filter non-scalar
blank_logger.add(logged_obj)
expected_namespace = base_namespace + ('prop',)
assert expected_namespace in blank_logger
blank_logger._categories = LoggerCategories['sequence']
expected_namespace = base_namespace + ('proplist',)
blank_logger.add(logged_obj)
assert expected_namespace in blank_logger
assert len(blank_logger) == 2
logger.add(logged_obj, user_name=user_name)
assert base_namespace[:-1] + (user_name, 'prop') in logger
assert base_namespace[:-1] + (user_name, 'proplist') in logger

def test_remove(self, logged_obj, base_namespace):

Expand Down Expand Up @@ -359,21 +405,27 @@ def test_remove(self, logged_obj, base_namespace):
assert prop_namespace[:-2] + (prop_namespace[-2] + '_1',
prop_namespace[-1]) not in log

def test_remove_with_user_name(self, blank_logger, logged_obj,
base_namespace):
def test_remove_with_user_name(self, logged_obj, base_namespace):
# Test remove using a user specified namespace identifier
logger = Logger()
user_name = 'UserName'
blank_logger.add(logged_obj, user_name=user_name)
assert base_namespace[:-1] + (user_name, 'prop') in blank_logger
assert base_namespace[:-1] + (user_name, 'proplist') in blank_logger
logger.add(logged_obj, user_name=user_name)
assert base_namespace[:-1] + (user_name, 'prop') in logger
assert base_namespace[:-1] + (user_name, 'proplist') in logger

def test_iadd(self, blank_logger, logged_obj):
blank_logger.add(logged_obj)
add_log = blank_logger._dict
blank_logger._dict = dict()
blank_logger += logged_obj
assert add_log == blank_logger._dict
assert len(blank_logger) == 2
log_filter = self.get_filter(blank_logger)
expected_loggables = [
loggable.name
for loggable in logged_obj._export_dict.values()
if log_filter(loggable)
]
assert len(blank_logger) == len(expected_loggables)

def test_isub(self, logged_obj, base_namespace):

Expand Down
2 changes: 1 addition & 1 deletion hoomd/pytest/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __eq__(self, other):

@pytest.fixture
def logger():
logger = hoomd.logging.Logger(categories=['scalar'])
logger = hoomd.logging.Logger(categories=['scalar', "string"])
tommy-waltmann marked this conversation as resolved.
Show resolved Hide resolved
logger[('dummy', 'loggable', 'int')] = (Identity(42000000), 'scalar')
logger[('dummy', 'loggable', 'float')] = (Identity(3.1415), 'scalar')
logger[('dummy', 'loggable', 'string')] = (Identity("foobarbaz"), 'string')
Expand Down