Skip to content

Commit

Permalink
Fix invalid specification of default_value without regard to default_…
Browse files Browse the repository at this point in the history
…value_type (#1631)

This is a two-part fix for #1629:

- Part 1: fix `ctraits.c` to validate the input to `set_default_value` in the case where the given default value type is `DefaultValue.callable_and_args`. This fixes the segfault reported in #1629, replacing it with a more scrutable Python-level exception.
- Part 2: in the `TraitType.clone` base class implementation, when we set a default value, also set the default value type to be consistent with that default value. This still leaves open the possibility for `TraitType` subclasses to do their own thing in their `clone` methods (and in particular, the `List`, `Dict` and `Set` trait types will probably want to take advantage of that - see #1630).
  • Loading branch information
mdickinson committed Apr 20, 2022
1 parent 1de5ac2 commit dc59b5c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
16 changes: 16 additions & 0 deletions traits/ctraits.c
Original file line number Diff line number Diff line change
Expand Up @@ -3064,6 +3064,22 @@ _trait_set_default_value(trait_object *trait, PyObject *args)
return NULL;
}

/* Validate the value */
switch (value_type) {
/* We only do sufficient validation to avoid segfaults when
unwrapping the value in `default_value_for`. */
case CALLABLE_AND_ARGS_DEFAULT_VALUE:
if (!PyTuple_Check(value) || PyTuple_GET_SIZE(value) != 3) {
PyErr_SetString(
PyExc_ValueError,
"default value for type DefaultValue.callable_and_args "
"must be a tuple of the form (callable, args, kwds)"
);
return NULL;
}
break;
}

trait->default_value_type = value_type;

/* The DECREF on the old value can call arbitrary code, so take care not to
Expand Down
16 changes: 16 additions & 0 deletions traits/tests/test_ctraits.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ def test_set_and_get_default_value(self):
trait.default_value(), (DefaultValue.list_copy, [1, 2, 3])
)

def test_validate_default_value_for_callable_and_args(self):

bad_values = [
None,
123,
(int, (2,)),
(int, 2, 3, 4),
]

trait = CTrait(TraitKind.trait)
for value in bad_values:
with self.subTest(value=value):
with self.assertRaises(ValueError):
trait.set_default_value(
DefaultValue.callable_and_args, value)

def test_default_value_for_set_is_deprecated(self):
trait = CTrait(TraitKind.trait)
with warnings.catch_warnings(record=True) as warn_msgs:
Expand Down
14 changes: 14 additions & 0 deletions traits/tests/test_trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ class MyTraitType(TraitType):
with self.assertRaises(ValueError):
ctrait.default_value_for(None, "<dummy>")

def test_call_sets_default_value_type(self):
class FooTrait(TraitType):
default_value_type = DefaultValue.callable_and_args

def __init__(self, default_value=NoDefaultSpecified, **metadata):
default_value = (pow, (3, 4), {})
super().__init__(default_value, **metadata)

trait = FooTrait()
ctrait = trait.as_ctrait()
self.assertEqual(ctrait.default_value_for(None, "dummy"), 81)
cloned_ctrait = trait(30)
self.assertEqual(cloned_ctrait.default_value_for(None, "dummy"), 30)


class TestDeprecatedTraitTypes(unittest.TestCase):
def test_function_deprecated(self):
Expand Down
11 changes: 7 additions & 4 deletions traits/trait_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,18 @@ def clone(self, default_value=NoDefaultSpecified, **metadata):
new._metadata.update(metadata)

if default_value is not NoDefaultSpecified:
new.default_value = default_value
if self.validate is not None:
try:
new.default_value = self.validate(
None, None, default_value
)
default_value = self.validate(None, None, default_value)
except Exception:
pass

# Known issue: this doesn't do the right thing for
# List, Dict and Set, where we really want to make a copy.
# xref: enthought/traits#1630
new.default_value_type = DefaultValue.constant
new.default_value = default_value

return new

def get_value(self, object, name, trait=None):
Expand Down

0 comments on commit dc59b5c

Please sign in to comment.