diff --git a/hyperspy/component.py b/hyperspy/component.py index d00d16ff24..676f24c079 100644 --- a/hyperspy/component.py +++ b/hyperspy/component.py @@ -721,11 +721,12 @@ def default_traits_view(self): @add_gui_method(toolkey="hyperspy.Component") class Component(t.HasTraits): __axes_manager = None + # setting dtype for t.Property(t.Bool) causes serialization error with cloudpickle + active = t.Property() + name = t.Property() - active = t.Property(t.CBool(True)) - name = t.Property(t.Str('')) - - def __init__(self, parameter_name_list, linear_parameter_list=None): + def __init__(self, parameter_name_list, linear_parameter_list=None, *args, **kwargs): + super().__init__(*args, **kwargs) self.events = Events() self.events.active_changed = Event(""" Event that triggers when the `Component.active` changes. @@ -806,6 +807,8 @@ def _get_name(self): return self._name def _set_name(self, value): + if not isinstance(value, str): + raise ValueError('Only string values are permitted') old_value = self._name if old_value == value: return diff --git a/hyperspy/model.py b/hyperspy/model.py index fc6da82b19..1015dc4ae4 100644 --- a/hyperspy/model.py +++ b/hyperspy/model.py @@ -125,7 +125,14 @@ def reconstruct_component(comp_dictionary, **init_args): elif "_class_dump" in comp_dictionary: # When a component is not registered using the extension mechanism, # it is serialized using cloudpickle. - _class = cloudpickle.loads(comp_dictionary['_class_dump']) + try: + _class = cloudpickle.loads(comp_dictionary['_class_dump']) + except TypeError: # pragma: no cover + # https://github.com/cloudpipe/cloudpickle/blob/master/README.md + raise TypeError("Pickling is not (always) supported between python " + "versions. As a result the custom class cannot be " + "loaded. Consider adding a custom Component using the " + "extension mechanism.") else: # For component saved with hyperspy <2.0 and moved to exspy if comp_dictionary["_id_name"] in EXSPY_HSPY_COMPONENTS: diff --git a/hyperspy/tests/component/data/hs2.0_custom_component.hspy b/hyperspy/tests/component/data/hs2.0_custom_component.hspy new file mode 100644 index 0000000000..076d980def Binary files /dev/null and b/hyperspy/tests/component/data/hs2.0_custom_component.hspy differ diff --git a/hyperspy/tests/component/test_component.py b/hyperspy/tests/component/test_component.py index 9a8913c4ec..347f25cc88 100644 --- a/hyperspy/tests/component/test_component.py +++ b/hyperspy/tests/component/test_component.py @@ -18,14 +18,18 @@ import pytest from unittest import mock +import pathlib import numpy as np +import hyperspy.api as hs from hyperspy.axes import AxesManager from hyperspy.component import Component, Parameter, _get_scaling_factor from hyperspy._signals.signal1d import Signal1D +DIRPATH = pathlib.Path(__file__).parent / "data" + class TestMultidimensionalActive: def setup_method(self, method): @@ -298,3 +302,66 @@ def test_linear_parameter_initialisation(): assert C.one._linear assert not C.two._linear assert not P._linear + + +def test_set_name(): + c = Component(['one', 'two'], ['one']) + c.name = 'test' + assert c.name == 'test' + assert c._name == 'test' + + +def test_set_name_error(): + c = Component(['one', 'two'], ['one']) + with pytest.raises(ValueError): + c.name = 1 + + +def test_loading_non_expression_custom_component(tmp_path): + # non-expression based custom component uses serialisation + # to save the components. + + import hyperspy.api as hs + from hyperspy.component import Component + + class CustomComponent(Component): + + def __init__(self, p1=1, p2=2): + Component.__init__(self, ('p1', 'p2')) + + self.p1.value = p1 + self.p2.value = p2 + + self.p1.grad = self.grad_p1 + self.p2.grad = self.grad_p2 + + def function(self, x): + p1 = self.p1.value + p2 = self.p2.value + return p1 + x * p2 + + def grad_p1(self, x): + return 0 + + def grad_p2(self, x): + return x + + s = hs.signals.Signal1D(range(10)) + m = s.create_model() + + c = CustomComponent() + m.append(c) + m.store('a') + + s.save(tmp_path / "hs2.0_custom_component.hspy") + + s = hs.load(tmp_path / "hs2.0_custom_component.hspy") + _ = s.models.restore('a') + + +def test_load_component_previous_python(): + s = hs.load(DIRPATH / "hs2.0_custom_component.hspy") + import sys + if sys.version_info[0] == 3.11: + with pytest.raises(TypeError): + _ = s.models.restore('a') \ No newline at end of file diff --git a/upcoming_changes/3262.bugfix.rst b/upcoming_changes/3262.bugfix.rst new file mode 100644 index 0000000000..448b0c14ad --- /dev/null +++ b/upcoming_changes/3262.bugfix.rst @@ -0,0 +1,2 @@ +Fix serialization error due to :py:class:`traits.api.Property` not being serializable if a dtype is specified. +See #3261 for more details. \ No newline at end of file