Skip to content

Commit

Permalink
Make sure that user supplied requirement functions are always checked (
Browse files Browse the repository at this point in the history
…#725)

* Make sure that a user supplied requirement function is always checked

* Remove unused reference

Co-authored-by: Dimitri RODARIE <d.rodarie@gmail.com>

* value needs to be None for regular bool requirements to proc

* added tests

* added call test

---------

Co-authored-by: Dimitri RODARIE <d.rodarie@gmail.com>
  • Loading branch information
Helveg and drodarie committed May 31, 2023
1 parent f31c8a1 commit 2e7b833
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
4 changes: 2 additions & 2 deletions bsb/config/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _set_pk(obj, parent, key):
_setattr(obj, a.attr_name, key)


def _check_required(instance, attr, kwargs):
def _missing_requirements(instance, attr, kwargs):
# We use `self.__class__`, not `cls`, to get the proper child class.
cls = instance.__class__
dynamic_root = getattr(cls, "_config_dynamic_root", None)
Expand Down Expand Up @@ -261,7 +261,7 @@ def __post_new__(self, _parent=None, _key=None, **kwargs):
name = attr.attr_name
value = values[name] = leftovers.pop(name, None)
try:
if value is None and _check_required(self, attr, kwargs):
if _missing_requirements(self, attr, kwargs) and value is None:
raise RequirementError(f"Missing required attribute '{name}'")
except RequirementError as e:
# Catch both our own and possible `attr.required` RequirementErrors
Expand Down
54 changes: 50 additions & 4 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def test_requirement(self):
class Test:
name = config.attr(type=str, required=True)

def regular(value):
return "timmy" in value

def special(value):
raise RequirementError("special")

Expand All @@ -156,13 +159,56 @@ class Test2:
class Test3:
name = config.attr(type=str, required=lambda x: True)

@config.node
class Test4:
name = config.attr(type=str, required=regular)

Test(name="required")
with self.assertRaises(RequirementError):
Test({}, _parent=TestRoot())
Test()
with self.assertRaisesRegex(RequirementError, r"special"):
Test2({}, _parent=TestRoot())
Test2()
with self.assertRaises(RequirementError):
Test3({}, _parent=TestRoot())
t = Test({"name": "hello"}, _parent=TestRoot())
Test3()
Test4()
Test4(timmy="x", name="required")
with self.assertRaises(RequirementError):
Test4(timmy="x")

def test_requirement_proc(self):
fcalled = False

def fspy(value):
nonlocal fcalled
fcalled = True
return False

@config.node
class Test:
name = config.attr(type=str, required=fspy)

tcalled = False

def tspy(value):
nonlocal tcalled
tcalled = True
return True

@config.node
class Test2:
name = config.attr(type=str, required=tspy)

Test()
self.assertTrue(fcalled, "Requirement functions should always be called.")
Test2(name="required")
self.assertTrue(tcalled, "Requirement functions should always be called.")

def test_precast_identity(self):
@config.node
class Test:
name = config.attr(type=str, required=True)

t = Test(name="hello")
self.assertEqual(t, Test(t), "Already cast object should not be altered")


Expand Down

0 comments on commit 2e7b833

Please sign in to comment.