Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions src/easyreflectometry/model/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
interface=None,
unique_name: Optional[str] = None,
populate_if_none: bool = True,
next_color_index: Optional[int] = None,
**kwargs,
):
if not models:
Expand All @@ -33,17 +34,25 @@ def __init__(
# Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict
# Else collisions might occur in global_object.map
self.populate_if_none = False
self._next_color_index = next_color_index

super().__init__(name, interface, unique_name=unique_name, *models, **kwargs)
super().__init__(name, interface, *models, unique_name=unique_name, **kwargs)

color_count = len(COLORS)
if color_count == 0:
self._next_color_index = 0
elif self._next_color_index is None:
self._next_color_index = len(self) % color_count
else:
self._next_color_index %= color_count

def add_model(self, model: Optional[Model] = None):
"""Add a model to the collection.

:param model: Model to add.
"""
if model is None:
color = COLORS[len(self) % len(COLORS)]
model = Model(name='Model', interface=self.interface, color=color)
model = Model(name='Model', interface=self.interface, color=self._current_color())
self.append(model)

def duplicate_model(self, index: int):
Expand All @@ -59,6 +68,7 @@ def duplicate_model(self, index: int):
def as_dict(self, skip: List[str] | None = None) -> dict:
this_dict = super().as_dict(skip=skip)
this_dict['populate_if_none'] = self.populate_if_none
this_dict['next_color_index'] = self._next_color_index
return this_dict

@classmethod
Expand All @@ -69,16 +79,48 @@ def from_dict(cls, this_dict: dict) -> ModelCollection:
:param data: The dictionary for the collection
"""
collection_dict = this_dict.copy()
# We neeed to call from_dict on the base class to get the models
dict_data = collection_dict['data']
del collection_dict['data']
# We need to call from_dict on the base class to get the models
dict_data = collection_dict.pop('data')
next_color_index = collection_dict.pop('next_color_index', None)

collection = super().from_dict(collection_dict) # type: ModelCollection

for model_data in dict_data:
collection.add_model(Model.from_dict(model_data))
collection._append_internal(Model.from_dict(model_data), advance=False)

if len(collection) != len(this_dict['data']):
raise ValueError(f'Expected {len(collection)} models, got {len(this_dict["data"])}')

color_count = len(COLORS)
if color_count == 0:
collection._next_color_index = 0
elif next_color_index is None:
collection._next_color_index = len(collection) % color_count
else:
collection._next_color_index = next_color_index % color_count

return collection

def append(self, model: Model) -> None: # type: ignore[override]
self._append_internal(model, advance=True)

def _append_internal(self, model: Model, advance: bool) -> None:
super().append(model)
if advance:
self._advance_color_index()

def _advance_color_index(self) -> None:
if not COLORS:
self._next_color_index = 0
return
if self._next_color_index is None:
self._next_color_index = len(self) % len(COLORS)
return
self._next_color_index = (self._next_color_index + 1) % len(COLORS)

def _current_color(self) -> str:
if not COLORS:
raise ValueError('No colors defined for models.')
if self._next_color_index is None:
self._next_color_index = 0
return COLORS[self._next_color_index]
63 changes: 63 additions & 0 deletions tests/model/test_model_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from easyscience import global_object

from easyreflectometry.model.model import COLORS
from easyreflectometry.model.model import Model
from easyreflectometry.model.model_collection import ModelCollection

Expand Down Expand Up @@ -52,6 +53,44 @@ def test_add_model(self):
assert collection[0].name == 'Model1'
assert collection[1].name == 'Model2'

def test_add_model_color_cycle(self):
collection = ModelCollection(populate_if_none=False)

collection.add_model()
assert collection[0].color == COLORS[0]

collection.add_model()
assert collection[1].color == COLORS[1]

collection.remove(0)
collection.add_model()

assert collection[0].color == COLORS[1]
assert collection[1].color == COLORS[2]

def test_add_model_color_wrap(self):
collection = ModelCollection(populate_if_none=False)

for _ in range(len(COLORS)):
collection.add_model()

collection.add_model()

assert collection[-1].color == COLORS[0]

def test_add_model_preserves_explicit_color(self):
collection = ModelCollection(populate_if_none=False)
collection.add_model()
expected_index = collection._next_color_index

custom_color = '#ABCDEF'
custom_model = Model(name='Custom', color=custom_color)

collection.add_model(custom_model)

assert collection[-1].color == custom_color
assert collection._next_color_index == (expected_index + 1) % len(COLORS)

def test_delete_model(self):
# When
model_1 = Model(name='Model1')
Expand Down Expand Up @@ -94,3 +133,27 @@ def test_dict_round_trip(self):
q.as_dict(skip=['resolution_function', 'interface'])
)
assert p[0]._resolution_function.smearing(5.5) == q[0]._resolution_function.smearing(5.5)

def test_next_color_index_round_trip(self):
collection = ModelCollection(populate_if_none=False)
for _ in range(3):
collection.add_model()

expected_index = collection._next_color_index
dict_repr = collection.as_dict()
global_object.map._clear()

restored = ModelCollection.from_dict(dict_repr)

assert restored._next_color_index == expected_index

def test_legacy_from_dict_sets_color_index(self):
collection = ModelCollection()
legacy_dict = collection.as_dict()
legacy_dict.pop('next_color_index', None)
global_object.map._clear()

restored = ModelCollection.from_dict(legacy_dict)
restored.add_model()

assert [model.color for model in restored] == [COLORS[0], COLORS[1]]
Loading