Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/easyscience/global_object/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool:
return False

def _clear(self):
"""Reset the map to an empty state."""
"""Reset the map to an empty state. Only to be used for testing"""
for vertex in self.vertices():
self.prune(vertex)
gc.collect()
Expand Down
12 changes: 12 additions & 0 deletions src/easyscience/variable/descriptor_number.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numbers
import uuid
from typing import Any
from typing import Dict
from typing import List
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
url: Optional[str] = None,
display_name: Optional[str] = None,
parent: Optional[Any] = None,
**kwargs: Any # Additional keyword arguments (used for (de)serialization)
):
"""Constructor for the DescriptorNumber class

Expand All @@ -65,6 +67,10 @@ def __init__(
"""
self._observers: List[DescriptorNumber] = []

# Extract dependency_id if provided during deserialization
if '__dependency_id' in kwargs:
self.__dependency_id = kwargs.pop('__dependency_id')

if not isinstance(value, numbers.Number) or isinstance(value, bool):
raise TypeError(f'{value=} must be a number')
if variance is not None:
Expand Down Expand Up @@ -112,10 +118,14 @@ def from_scipp(cls, name: str, full_value: Variable, **kwargs) -> DescriptorNumb
def _attach_observer(self, observer: DescriptorNumber) -> None:
"""Attach an observer to the descriptor."""
self._observers.append(observer)
if not hasattr(self, '_DescriptorNumber__dependency_id'):
self.__dependency_id = str(uuid.uuid4())

def _detach_observer(self, observer: DescriptorNumber) -> None:
"""Detach an observer from the descriptor."""
self._observers.remove(observer)
if not self._observers:
del self.__dependency_id

def _notify_observers(self) -> None:
"""Notify all observers of a change."""
Expand Down Expand Up @@ -323,6 +333,8 @@ def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]:
raw_dict['value'] = self._scalar.value
raw_dict['unit'] = str(self._scalar.unit)
raw_dict['variance'] = self._scalar.variance
if hasattr(self, '_DescriptorNumber__dependency_id'):
raw_dict['__dependency_id'] = self.__dependency_id
return raw_dict

def __add__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorNumber:
Expand Down
140 changes: 40 additions & 100 deletions src/easyscience/variable/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import weakref
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

Expand All @@ -36,14 +37,6 @@ class Parameter(DescriptorNumber):
# We copy the parent's _REDIRECT and modify it to avoid altering the parent's class dict
_REDIRECT = DescriptorNumber._REDIRECT.copy()
_REDIRECT['callback'] = None
# Skip these attributes during normal serialization as they are handled specially
_REDIRECT['_dependency_interpreter'] = None
_REDIRECT['_clean_dependency_string'] = None
# Skip the new serialization parameters - they'll be handled by _convert_to_dict
_REDIRECT['_dependency_string'] = None
_REDIRECT['_dependency_map_unique_names'] = None
_REDIRECT['_dependency_map_dependency_ids'] = None
_REDIRECT['__dependency_id'] = None

def __init__(
self,
Expand Down Expand Up @@ -84,11 +77,8 @@ def __init__(
""" # noqa: E501
# Extract and ignore serialization-specific fields from kwargs
kwargs.pop('_dependency_string', None)
kwargs.pop('_dependency_map_unique_names', None)
kwargs.pop('_dependency_map_dependency_ids', None)
kwargs.pop('_independent', None)
# Extract dependency_id if provided during deserialization
provided_dependency_id = kwargs.pop('__dependency_id', None)

if not isinstance(min, numbers.Number):
raise TypeError('`min` must be a number')
Expand Down Expand Up @@ -119,6 +109,7 @@ def __init__(
url=url,
display_name=display_name,
parent=parent,
**kwargs, # Additional keyword arguments (used for (de)serialization)
)

self._callback = callback # Callback is used by interface to link to model
Expand All @@ -128,13 +119,6 @@ def __init__(
# Create additional fitting elements
self._initial_scalar = copy.deepcopy(self._scalar)

# Generate unique dependency ID for serialization/deserialization
# Use provided dependency_id if available (during deserialization), otherwise generate new one
if provided_dependency_id is not None:
self.__dependency_id = provided_dependency_id
else:
import uuid
self.__dependency_id = str(uuid.uuid4())

@classmethod
def from_dependency(cls, name: str, dependency_expression: str, dependency_map: Optional[dict] = None, **kwargs) -> Parameter: # noqa: E501
Expand All @@ -147,15 +131,15 @@ def from_dependency(cls, name: str, dependency_expression: str, dependency_map:
:param kwargs: Additional keyword arguments to pass to the Parameter constructor.
:return: A new dependent Parameter object.
""" # noqa: E501
# Set default values for required parameters if not provided in kwargs
# Set default values for required parameters for the constructor, they get overwritten by the dependency anyways
default_kwargs = {
'value': 0.0,
'unit': '',
'variance': 0.0,
'min': -np.inf,
'max': np.inf
}
# Update with user-provided kwargs, giving precedence to user values
# Update with user-provided kwargs, to avoid errors.
default_kwargs.update(kwargs)
parameter = cls(name=name, **default_kwargs)
parameter.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map)
Expand Down Expand Up @@ -331,15 +315,6 @@ def dependency_map(self) -> Dict[str, DescriptorNumber]:
def dependency_map(self, new_map: Dict[str, DescriptorNumber]) -> None:
raise AttributeError('Dependency map is read-only. Use `make_dependent_on` to change the dependency map.')

@property
def dependency_id(self) -> str:
"""
Get the unique dependency ID of this parameter used for serialization.

:return: The dependency ID of this parameter.
"""
return self.__dependency_id

@property
def value_no_call_back(self) -> numbers.Number:
"""
Expand Down Expand Up @@ -553,6 +528,26 @@ def free(self) -> bool:
def free(self, value: bool) -> None:
self.fixed = not value

def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]:
""" Overwrite the as_dict method to handle dependency information. """
raw_dict = super().as_dict(skip=skip)


# Add dependency information for dependent parameters
if not self._independent:
# Save the dependency expression
raw_dict['_dependency_string'] = self._clean_dependency_string

# Mark that this parameter is dependent
raw_dict['_independent'] = self._independent

# Convert dependency_map to use dependency_ids
raw_dict['_dependency_map_dependency_ids'] = {}
for key, obj in self._dependency_map.items():
raw_dict['_dependency_map_dependency_ids'][key] = obj._DescriptorNumber__dependency_id

return raw_dict

def _revert_dependency(self, skip_detach=False) -> None:
"""
Revert the dependency to the old dependency. This is used when an error is raised during setting the dependency.
Expand Down Expand Up @@ -595,63 +590,26 @@ def _process_dependency_unique_names(self, dependency_expression: str):
raise ValueError(f'The object with unique_name {stripped_name} is not a Parameter or DescriptorNumber. Please check your dependency expression.') # noqa: E501
self._clean_dependency_string = clean_dependency_string

def _convert_to_dict(self, d: dict, encoder, skip=None, **kwargs) -> dict:
"""Custom serialization to handle parameter dependencies."""
if skip is None:
skip = []

# Add dependency information for dependent parameters
if not self._independent:
# Save the dependency expression
d['_dependency_string'] = self._dependency_string

# Convert dependency_map to use dependency_ids (preferred) and unique_names (fallback)
d['_dependency_map_dependency_ids'] = {}
d['_dependency_map_unique_names'] = {}

for key, dep_obj in self._dependency_map.items():
# Store dependency_id if available (more reliable)
if hasattr(dep_obj, '__dependency_id'):
d['_dependency_map_dependency_ids'][key] = dep_obj.__dependency_id
# Also store unique_name as fallback
if hasattr(dep_obj, 'unique_name'):
d['_dependency_map_unique_names'][key] = dep_obj.unique_name
else:
# This is quite impossible - throw an error
raise ValueError(f'The object with unique_name {key} does not have a unique_name attribute.')

# Always include dependency_id for this parameter
d['__dependency_id'] = self.__dependency_id

# Mark that this parameter is dependent
d['_independent'] = self._independent

return d

@classmethod
@classmethod
def from_dict(cls, obj_dict: dict) -> 'Parameter':
"""
Custom deserialization to handle parameter dependencies.
Override the parent method to handle dependency information.
"""
# Extract dependency information before creating the parameter
d = obj_dict.copy() # Don't modify the original dict
dependency_string = d.pop('_dependency_string', None)
dependency_map_unique_names = d.pop('_dependency_map_unique_names', None)
dependency_map_dependency_ids = d.pop('_dependency_map_dependency_ids', None)
is_independent = d.pop('_independent', True)
# Note: Keep __dependency_id in the dict so it gets passed to __init__
raw_dict = obj_dict.copy() # Don't modify the original dict
dependency_string = raw_dict.pop('_dependency_string', None)
dependency_map_dependency_ids = raw_dict.pop('_dependency_map_dependency_ids', None)
is_independent = raw_dict.pop('_independent', True)
# Note: Keep _dependency_id in the dict so it gets passed to __init__

# Create the parameter using the base class method (dependency_id is now handled in __init__)
param = super().from_dict(d)
param = super().from_dict(raw_dict)

# Store dependency information for later resolution
if not is_independent and dependency_string is not None:
if not is_independent:
param._pending_dependency_string = dependency_string
if dependency_map_dependency_ids:
param._pending_dependency_map_dependency_ids = dependency_map_dependency_ids
if dependency_map_unique_names:
param._pending_dependency_map_unique_names = dependency_map_unique_names
param._pending_dependency_map_dependency_ids = dependency_map_dependency_ids
# Keep parameter as independent initially - will be made dependent after all objects are loaded
param._independent = True

Expand Down Expand Up @@ -995,13 +953,12 @@ def resolve_pending_dependencies(self) -> None:
"""Resolve pending dependencies after deserialization.

This method should be called after all parameters have been deserialized
to establish dependency relationships using dependency_ids or unique_names as fallback.
to establish dependency relationships using dependency_ids.
"""
if hasattr(self, '_pending_dependency_string'):
dependency_string = self._pending_dependency_string
dependency_map = {}

# Try dependency IDs first (more reliable)
if hasattr(self, '_pending_dependency_map_dependency_ids'):
dependency_map_dependency_ids = self._pending_dependency_map_dependency_ids

Expand All @@ -1013,22 +970,6 @@ def resolve_pending_dependencies(self) -> None:
else:
raise ValueError(f"Cannot find parameter with dependency_id '{dependency_id}'")

# Fallback to unique_names if dependency IDs not available or incomplete
if hasattr(self, '_pending_dependency_map_unique_names'):
dependency_map_unique_names = self._pending_dependency_map_unique_names

for key, unique_name in dependency_map_unique_names.items():
if key not in dependency_map: # Only add if not already resolved by dependency_id
try:
# Look up the parameter by unique_name in the global map
dep_obj = self._global_object.map.get_item_by_key(unique_name)
if dep_obj is not None:
dependency_map[key] = dep_obj
else:
raise ValueError(f"Cannot find parameter with unique_name '{unique_name}'")
except Exception as e:
raise ValueError(f"Error resolving dependency '{key}' -> '{unique_name}': {e}")

# Establish the dependency relationship
try:
self.make_dependent_on(dependency_expression=dependency_string, dependency_map=dependency_map)
Expand All @@ -1037,14 +978,13 @@ def resolve_pending_dependencies(self) -> None:

# Clean up temporary attributes
delattr(self, '_pending_dependency_string')
if hasattr(self, '_pending_dependency_map_dependency_ids'):
delattr(self, '_pending_dependency_map_dependency_ids')
if hasattr(self, '_pending_dependency_map_unique_names'):
delattr(self, '_pending_dependency_map_unique_names')
delattr(self, '_pending_dependency_map_dependency_ids')

def _find_parameter_by_dependency_id(self, dependency_id: str) -> Optional['Parameter']:
def _find_parameter_by_dependency_id(self, dependency_id: str) -> Optional['DescriptorNumber']:
"""Find a parameter by its dependency_id from all parameters in the global map."""
for obj in self._global_object.map._store.values():
if isinstance(obj, Parameter) and hasattr(obj, '__dependency_id') and obj.__dependency_id == dependency_id:
return obj
if isinstance(obj, DescriptorNumber) and hasattr(obj, '_DescriptorNumber__dependency_id'):
if obj._DescriptorNumber__dependency_id == dependency_id:
return obj
return None

2 changes: 1 addition & 1 deletion src/easyscience/variable/parameter_dependency_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _collect_parameters(item: Any, parameters: List[Parameter]) -> None:
resolved_count += 1
except Exception as e:
error_count += 1
dependency_id = getattr(param, '__dependency_id', 'unknown')
dependency_id = getattr(param, '_DescriptorNumber__dependency_id', 'unknown')
errors.append(f"Failed to resolve dependencies for parameter '{param.name}'" \
f" (unique_name: '{param.unique_name}', dependency_id: '{dependency_id}'): {e}")

Expand Down
Loading
Loading