Skip to content

Commit

Permalink
Merge pull request #3262 from CSSFrancis/fix_serialization
Browse files Browse the repository at this point in the history
Fix #3261
  • Loading branch information
ericpre committed Nov 8, 2023
2 parents 4cafc47 + 41856f1 commit 6163e53
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 5 deletions.
11 changes: 7 additions & 4 deletions hyperspy/component.py
Expand Up @@ -721,11 +721,12 @@ def default_traits_view(self):
@add_gui_method(toolkey="hyperspy.Component")
class Component(t.HasTraits):
__axes_manager = None
# setting dtype for t.Property(t.Bool) causes serialization error with cloudpickle
active = t.Property()
name = t.Property()

active = t.Property(t.CBool(True))
name = t.Property(t.Str(''))

def __init__(self, parameter_name_list, linear_parameter_list=None):
def __init__(self, parameter_name_list, linear_parameter_list=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.events = Events()
self.events.active_changed = Event("""
Event that triggers when the `Component.active` changes.
Expand Down Expand Up @@ -806,6 +807,8 @@ def _get_name(self):
return self._name

def _set_name(self, value):
if not isinstance(value, str):
raise ValueError('Only string values are permitted')
old_value = self._name
if old_value == value:
return
Expand Down
9 changes: 8 additions & 1 deletion hyperspy/model.py
Expand Up @@ -125,7 +125,14 @@ def reconstruct_component(comp_dictionary, **init_args):
elif "_class_dump" in comp_dictionary:
# When a component is not registered using the extension mechanism,
# it is serialized using cloudpickle.
_class = cloudpickle.loads(comp_dictionary['_class_dump'])
try:
_class = cloudpickle.loads(comp_dictionary['_class_dump'])
except TypeError: # pragma: no cover
# https://github.com/cloudpipe/cloudpickle/blob/master/README.md
raise TypeError("Pickling is not (always) supported between python "
"versions. As a result the custom class cannot be "
"loaded. Consider adding a custom Component using the "
"extension mechanism.")
else:
# For component saved with hyperspy <2.0 and moved to exspy
if comp_dictionary["_id_name"] in EXSPY_HSPY_COMPONENTS:
Expand Down
Binary file not shown.
67 changes: 67 additions & 0 deletions hyperspy/tests/component/test_component.py
Expand Up @@ -18,14 +18,18 @@

import pytest
from unittest import mock
import pathlib

import numpy as np

import hyperspy.api as hs
from hyperspy.axes import AxesManager
from hyperspy.component import Component, Parameter, _get_scaling_factor
from hyperspy._signals.signal1d import Signal1D


DIRPATH = pathlib.Path(__file__).parent / "data"

class TestMultidimensionalActive:

def setup_method(self, method):
Expand Down Expand Up @@ -298,3 +302,66 @@ def test_linear_parameter_initialisation():
assert C.one._linear
assert not C.two._linear
assert not P._linear


def test_set_name():
c = Component(['one', 'two'], ['one'])
c.name = 'test'
assert c.name == 'test'
assert c._name == 'test'


def test_set_name_error():
c = Component(['one', 'two'], ['one'])
with pytest.raises(ValueError):
c.name = 1


def test_loading_non_expression_custom_component(tmp_path):
# non-expression based custom component uses serialisation
# to save the components.

import hyperspy.api as hs
from hyperspy.component import Component

class CustomComponent(Component):

def __init__(self, p1=1, p2=2):
Component.__init__(self, ('p1', 'p2'))

self.p1.value = p1
self.p2.value = p2

self.p1.grad = self.grad_p1
self.p2.grad = self.grad_p2

def function(self, x):
p1 = self.p1.value
p2 = self.p2.value
return p1 + x * p2

def grad_p1(self, x):
return 0

def grad_p2(self, x):
return x

s = hs.signals.Signal1D(range(10))
m = s.create_model()

c = CustomComponent()
m.append(c)
m.store('a')

s.save(tmp_path / "hs2.0_custom_component.hspy")

s = hs.load(tmp_path / "hs2.0_custom_component.hspy")
_ = s.models.restore('a')


def test_load_component_previous_python():
s = hs.load(DIRPATH / "hs2.0_custom_component.hspy")
import sys
if sys.version_info[0] == 3.11:
with pytest.raises(TypeError):
_ = s.models.restore('a')
2 changes: 2 additions & 0 deletions upcoming_changes/3262.bugfix.rst
@@ -0,0 +1,2 @@
Fix serialization error due to :py:class:`traits.api.Property` not being serializable if a dtype is specified.
See #3261 for more details.

0 comments on commit 6163e53

Please sign in to comment.