Skip to content

Commit

Permalink
Make Complex trait type validation duck-typed. (#1594)
Browse files Browse the repository at this point in the history
* Do duck-typed Complex validation

* Fix skip condition

* Rename for consistency with #1595
  • Loading branch information
mdickinson committed Nov 15, 2021
1 parent 9fd358b commit f6e00f2
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 10 deletions.
3 changes: 3 additions & 0 deletions traits/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class ValidateTrait(IntEnum):
#: A callable check.
callable = 22

#: A complex number check.
complex_number = 23


class ComparisonMode(IntEnum):
""" Comparison mode.
Expand Down
101 changes: 100 additions & 1 deletion traits/ctraits.c
Original file line number Diff line number Diff line change
Expand Up @@ -3390,6 +3390,77 @@ validate_trait_float(
return result;
}

/*
Convert a complex-number-like Python object to a complex number.
Returns a new object of exact type complex, or raises TypeError
if the given object cannot be converted to a complex number.
Here complex-number-like means:
- is an instance of complex, or
- can be converted to a a complex number via its type's __complex__ method,
or
- can be converted to a float via its type's __float__ method, or
- (for Python >= 3.8) can be converted to an int via its type's __index__
method.
In other words, these should be exactly the Python objects that are
accepted by the standard functions in the cmath module.
*/

static PyObject *
validate_complex_number(PyObject *value)
{
Py_complex value_as_complex;

/* Fast path for common case. */
if (PyComplex_CheckExact(value)) {
Py_INCREF(value);
return value;
}

/* General case: defer to the machinations of PyComplex_AsCComplex. */
value_as_complex = PyComplex_AsCComplex(value);
if (value_as_complex.real == -1.0 && PyErr_Occurred()) {
return NULL;
}
return PyComplex_FromCComplex(value_as_complex);
}

static PyObject *
_ctraits_validate_complex_number(PyObject *self, PyObject *value)
{
return validate_complex_number(value);
}

/*-----------------------------------------------------------------------------
| Verifies that a Python value is convertible to a complex number.
|
| Will convert anything whose type has a __complex__, __float__ or (for
| Python >= 3.8) __index__ method to a Python complex number. Returns a Python
| object of exact type "complex". Raises TraitError with a suitable message if
| the given value isn't convertible to a complex number.
|
| Any exception other than TypeError raised by any of the special methods
| will be propagated. A TypeError will be caught and turned into TraitError.
+----------------------------------------------------------------------------*/

static PyObject *
validate_trait_complex_number(
trait_object *trait, has_traits_object *obj, PyObject *name,
PyObject *value)
{
PyObject *result = validate_complex_number(value);
/* A TypeError represents a type validation failure, and should be
re-raised as a TraitError. Other exceptions should be propagated. */
if (result == NULL && PyErr_ExceptionMatches(PyExc_TypeError)) {
PyErr_Clear();
return raise_trait_error(trait, obj, name, value);
}
return result;
}

/*
Determine whether `value` lies in the range specified by `range_info`.
Expand Down Expand Up @@ -4159,6 +4230,18 @@ validate_trait_complex(
break;
}

case 23: /* Complex number check */
/* A TypeError indicates that we don't have a match.
Clear the error and continue with the next item
in the complex sequence. */
result = validate_complex_number(value);
if (result == NULL
&& PyErr_ExceptionMatches(PyExc_TypeError)) {
PyErr_Clear();
break;
}
return result;

default: /* Should never happen...indicates an internal error: */
assert(0); /* invalid validation type */
goto error;
Expand Down Expand Up @@ -4198,6 +4281,7 @@ static trait_validate validate_handlers[] = {
validate_trait_integer, /* case 20: Integer check */
validate_trait_float, /* case 21: Float check */
validate_trait_callable, /* case 22: Callable check */
validate_trait_complex_number, /* case 23: Complex number check */
};

static PyObject *
Expand Down Expand Up @@ -4357,6 +4441,12 @@ _trait_set_validate(trait_object *trait, PyObject *args)
goto done;
}
break;

case 23: /* Complex check: */
if (n == 1) {
goto done;
}
break;
}
}
}
Expand Down Expand Up @@ -5553,7 +5643,6 @@ _ctraits_ctrait(PyObject *self, PyObject *args)
| 'CTrait' instance methods:
+----------------------------------------------------------------------------*/


PyDoc_STRVAR(
_ctraits_validate_float_doc,
"_validate_float(number)\n"
Expand All @@ -5562,6 +5651,14 @@ PyDoc_STRVAR(
"conversion is not possible.\n"
);

PyDoc_STRVAR(
_ctraits_validate_complex_number_doc,
"_validate_complex_number(number)\n"
"\n"
"Return *number* converted to a complex number. Raise TypeError if \n"
"conversion is not possible.\n"
);

static PyMethodDef ctraits_methods[] = {
{"_list_classes", (PyCFunction)_ctraits_list_classes, METH_VARARGS,
PyDoc_STR(
Expand All @@ -5572,6 +5669,8 @@ static PyMethodDef ctraits_methods[] = {
PyDoc_STR("_ctrait(CTrait_class)")},
{"_validate_float", (PyCFunction)_ctraits_validate_float, METH_O,
_ctraits_validate_float_doc},
{"_validate_complex_number", (PyCFunction)_ctraits_validate_complex_number,
METH_O, _ctraits_validate_complex_number_doc},
{NULL, NULL},
};

Expand Down
175 changes: 175 additions & 0 deletions traits/tests/test_complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

"""
Tests for the Complex trait type.
"""

import unittest

from traits.api import BaseComplex, Complex, Either, HasTraits, TraitError
from traits.testing.optional_dependencies import numpy, requires_numpy


class IntegerLike:
def __init__(self, value):
self._value = value

def __index__(self):
return self._value


# Python versions < 3.8 don't support conversion of something with __index__
# to complex.
try:
complex(IntegerLike(3))
except TypeError:
complex_accepts_index = False
else:
complex_accepts_index = True


class FloatLike:
def __init__(self, value):
self._value = value

def __float__(self):
return self._value


class ComplexLike:
def __init__(self, value):
self._value = value

def __complex__(self):
return self._value


class HasComplexTraits(HasTraits):
value = Complex()

# Assignment to the `Either` trait exercises a different C code path (see
# validate_trait_complex in ctraits.c). This use of "Either" should not
# be replaced with "Union", since "Union" does not exercise that same
# code path.
value_or_none = Either(None, Complex())


class HasBaseComplexTraits(HasTraits):
value = BaseComplex()

value_or_none = Either(None, BaseComplex())


class CommonComplexTests(object):
""" Common tests for Complex and BaseComplex. """

def test_default_value(self):
a = self.test_class()
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(0.0, 0.0))

def test_rejects_str(self):
a = self.test_class()
with self.assertRaises(TraitError):
a.value = "3j"

def test_accepts_int(self):
a = self.test_class()
a.value = 7
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(7.0, 0.0))

def test_accepts_float(self):
a = self.test_class()
a.value = 7.0
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(7.0, 0.0))

def test_accepts_complex(self):
a = self.test_class()
a.value = 7j
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(0.0, 7.0))

def test_accepts_complex_subclass(self):
class ComplexSubclass(complex):
pass

a = self.test_class()
a.value = ComplexSubclass(5.0, 12.0)
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(5.0, 12.0))

@unittest.skipUnless(
complex_accepts_index,
"complex does not support __index__ for this Python version",
)
def test_accepts_integer_like(self):
a = self.test_class()
a.value = IntegerLike(3)
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(3.0, 0.0))

def test_accepts_float_like(self):
a = self.test_class()
a.value = FloatLike(3.2)
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(3.2, 0.0))

def test_accepts_complex_like(self):
a = self.test_class()
a.value = ComplexLike(3.0 + 4j)
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(3.0, 4.0))

@requires_numpy
def test_accepts_numpy_values(self):
test_values = [
numpy.int32(23),
numpy.float32(3.7),
numpy.float64(2.3),
numpy.complex64(1.2 - 3.8j),
numpy.complex128(3.1 + 4.7j),
]
for value in test_values:
with self.subTest(value=value):
a = self.test_class()
a.value = value
self.assertIs(type(a.value), complex)
self.assertEqual(a.value, complex(value))

def test_validate_trait_complex_code_path(self):
a = self.test_class()
a.value_or_none = 3.0 + 4j
self.assertIs(type(a.value_or_none), complex)
self.assertEqual(a.value_or_none, complex(3.0, 4.0))

def test_exceptions_propagated(self):
class CustomException(Exception):
pass

class BadComplexLike:
def __complex__(self):
raise CustomException("something went wrong")

a = self.test_class()
with self.assertRaises(CustomException):
a.value = BadComplexLike()


class TestComplex(unittest.TestCase, CommonComplexTests):
def setUp(self):
self.test_class = HasComplexTraits


class TestBaseComplex(unittest.TestCase, CommonComplexTests):
def setUp(self):
self.test_class = HasBaseComplexTraits
15 changes: 6 additions & 9 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings

from .constants import DefaultValue, TraitKind, ValidateTrait
from .ctraits import _validate_float
from .ctraits import _validate_complex_number, _validate_float
from .trait_base import (
strx,
get_module_name,
Expand Down Expand Up @@ -374,13 +374,10 @@ def validate(self, object, name, value):
Note: The 'fast validator' version performs this check in C.
"""
if isinstance(value, complex):
return value

if isinstance(value, (float, int)):
return complex(value)

self.error(object, name, value)
try:
return _validate_complex_number(value)
except TypeError:
self.error(object, name, value)

def create_editor(self):
""" Returns the default traits UI editor for this type of trait.
Expand All @@ -396,7 +393,7 @@ class Complex(BaseComplex):
"""

#: The C-level fast validator to use:
fast_validate = complex_fast_validate
fast_validate = (ValidateTrait.complex_number,)


class BaseStr(TraitType):
Expand Down

0 comments on commit f6e00f2

Please sign in to comment.