From 25c54c33d579f8dc93b589a4c0c443a1c51519e7 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 28 Oct 2025 14:42:17 +0100 Subject: [PATCH] more sophisticated model colour algorithm --- .../model/model_collection.py | 56 ++++++++++++++--- tests/model/test_model_collection.py | 63 +++++++++++++++++++ 2 files changed, 112 insertions(+), 7 deletions(-) diff --git a/src/easyreflectometry/model/model_collection.py b/src/easyreflectometry/model/model_collection.py index 84292f3a..b3c0bd2d 100644 --- a/src/easyreflectometry/model/model_collection.py +++ b/src/easyreflectometry/model/model_collection.py @@ -23,6 +23,7 @@ def __init__( interface=None, unique_name: Optional[str] = None, populate_if_none: bool = True, + next_color_index: Optional[int] = None, **kwargs, ): if not models: @@ -33,8 +34,17 @@ def __init__( # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict # Else collisions might occur in global_object.map self.populate_if_none = False + self._next_color_index = next_color_index - super().__init__(name, interface, unique_name=unique_name, *models, **kwargs) + super().__init__(name, interface, *models, unique_name=unique_name, **kwargs) + + color_count = len(COLORS) + if color_count == 0: + self._next_color_index = 0 + elif self._next_color_index is None: + self._next_color_index = len(self) % color_count + else: + self._next_color_index %= color_count def add_model(self, model: Optional[Model] = None): """Add a model to the collection. @@ -42,8 +52,7 @@ def add_model(self, model: Optional[Model] = None): :param model: Model to add. """ if model is None: - color = COLORS[len(self) % len(COLORS)] - model = Model(name='Model', interface=self.interface, color=color) + model = Model(name='Model', interface=self.interface, color=self._current_color()) self.append(model) def duplicate_model(self, index: int): @@ -59,6 +68,7 @@ def duplicate_model(self, index: int): def as_dict(self, skip: List[str] | None = None) -> dict: this_dict = super().as_dict(skip=skip) this_dict['populate_if_none'] = self.populate_if_none + this_dict['next_color_index'] = self._next_color_index return this_dict @classmethod @@ -69,16 +79,48 @@ def from_dict(cls, this_dict: dict) -> ModelCollection: :param data: The dictionary for the collection """ collection_dict = this_dict.copy() - # We neeed to call from_dict on the base class to get the models - dict_data = collection_dict['data'] - del collection_dict['data'] + # We need to call from_dict on the base class to get the models + dict_data = collection_dict.pop('data') + next_color_index = collection_dict.pop('next_color_index', None) collection = super().from_dict(collection_dict) # type: ModelCollection for model_data in dict_data: - collection.add_model(Model.from_dict(model_data)) + collection._append_internal(Model.from_dict(model_data), advance=False) if len(collection) != len(this_dict['data']): raise ValueError(f'Expected {len(collection)} models, got {len(this_dict["data"])}') + color_count = len(COLORS) + if color_count == 0: + collection._next_color_index = 0 + elif next_color_index is None: + collection._next_color_index = len(collection) % color_count + else: + collection._next_color_index = next_color_index % color_count + return collection + + def append(self, model: Model) -> None: # type: ignore[override] + self._append_internal(model, advance=True) + + def _append_internal(self, model: Model, advance: bool) -> None: + super().append(model) + if advance: + self._advance_color_index() + + def _advance_color_index(self) -> None: + if not COLORS: + self._next_color_index = 0 + return + if self._next_color_index is None: + self._next_color_index = len(self) % len(COLORS) + return + self._next_color_index = (self._next_color_index + 1) % len(COLORS) + + def _current_color(self) -> str: + if not COLORS: + raise ValueError('No colors defined for models.') + if self._next_color_index is None: + self._next_color_index = 0 + return COLORS[self._next_color_index] diff --git a/tests/model/test_model_collection.py b/tests/model/test_model_collection.py index c8e60d92..cc98534d 100644 --- a/tests/model/test_model_collection.py +++ b/tests/model/test_model_collection.py @@ -1,5 +1,6 @@ from easyscience import global_object +from easyreflectometry.model.model import COLORS from easyreflectometry.model.model import Model from easyreflectometry.model.model_collection import ModelCollection @@ -52,6 +53,44 @@ def test_add_model(self): assert collection[0].name == 'Model1' assert collection[1].name == 'Model2' + def test_add_model_color_cycle(self): + collection = ModelCollection(populate_if_none=False) + + collection.add_model() + assert collection[0].color == COLORS[0] + + collection.add_model() + assert collection[1].color == COLORS[1] + + collection.remove(0) + collection.add_model() + + assert collection[0].color == COLORS[1] + assert collection[1].color == COLORS[2] + + def test_add_model_color_wrap(self): + collection = ModelCollection(populate_if_none=False) + + for _ in range(len(COLORS)): + collection.add_model() + + collection.add_model() + + assert collection[-1].color == COLORS[0] + + def test_add_model_preserves_explicit_color(self): + collection = ModelCollection(populate_if_none=False) + collection.add_model() + expected_index = collection._next_color_index + + custom_color = '#ABCDEF' + custom_model = Model(name='Custom', color=custom_color) + + collection.add_model(custom_model) + + assert collection[-1].color == custom_color + assert collection._next_color_index == (expected_index + 1) % len(COLORS) + def test_delete_model(self): # When model_1 = Model(name='Model1') @@ -94,3 +133,27 @@ def test_dict_round_trip(self): q.as_dict(skip=['resolution_function', 'interface']) ) assert p[0]._resolution_function.smearing(5.5) == q[0]._resolution_function.smearing(5.5) + + def test_next_color_index_round_trip(self): + collection = ModelCollection(populate_if_none=False) + for _ in range(3): + collection.add_model() + + expected_index = collection._next_color_index + dict_repr = collection.as_dict() + global_object.map._clear() + + restored = ModelCollection.from_dict(dict_repr) + + assert restored._next_color_index == expected_index + + def test_legacy_from_dict_sets_color_index(self): + collection = ModelCollection() + legacy_dict = collection.as_dict() + legacy_dict.pop('next_color_index', None) + global_object.map._clear() + + restored = ModelCollection.from_dict(legacy_dict) + restored.add_model() + + assert [model.color for model in restored] == [COLORS[0], COLORS[1]]