From 22592ffab04f615be79e50bd3d5dcc368ddb1d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fe=CC=81lix=20Che=CC=81nier?= Date: Sat, 2 Mar 2024 16:18:49 -0500 Subject: [PATCH] Player: added default_point_color option. --- kineticstoolkit/player.py | 209 +++++++++++++++++++++----------------- tests/test_player.py | 2 + 2 files changed, 118 insertions(+), 93 deletions(-) diff --git a/kineticstoolkit/player.py b/kineticstoolkit/player.py index 64fc748..09fc26e 100644 --- a/kineticstoolkit/player.py +++ b/kineticstoolkit/player.py @@ -48,7 +48,16 @@ import scipy.optimize as optim REPR_HTML_MAX_DURATION = 10 # Max duration for _repr_html -COLORS = ["r", "g", "b", "y", "c", "m", "w"] +PALETTE = { + "r": (1.0, 0.0, 0.0), + "g": (0.0, 1.0, 0.0), + "b": (0.3, 0.3, 1.0), + "y": (1.0, 1.0, 0.0), + "m": (1.0, 0.0, 1.0), + "c": (0.0, 1.0, 1.0), + "w": (1.0, 1.0, 1.0), +} + HELP_TEXT = """ ktk.Player help @@ -77,7 +86,27 @@ """ -@typecheck +def _parse_color( + value: str | tuple[float, float, float] | ArrayLike +) -> tuple[float, float, float]: + """Convert a color specification into a tuple[float, float, float].""" + if isinstance(value, str): + try: + return PALETTE[value] + except KeyError: + raise ValueError( + f"The specified color '{value}' is not recognized." + ) + array_value = np.array(value) + if len(array_value) != 3: + raise ValueError("Color must be a character or an (R, G, B) tuple.") + return ( + float(array_value[0]), + float(array_value[1]), + float(array_value[2]), + ) + + class Player: # FIXME! Update this docstring. """ @@ -173,6 +202,9 @@ class Player: _oriented_frames: TimeSeries _interconnections: dict[str, dict[str, Any]] _extended_interconnections: dict[str, dict[str, Any]] + _colors: set[tuple[float, float, float]] # A list of all point colors + _selected_points: list[str] # List of point names + _last_selected_point: str _current_index: int _current_time: float _playback_speed: float @@ -188,6 +220,7 @@ class Player: _translation: np.ndarray _target: np.ndarray _track: bool + _default_point_color: tuple[float, float, float] _point_size: float _interconnection_width: float _frame_size: float @@ -217,6 +250,11 @@ def __init__( target: tuple[float, float, float] | ArrayLike = (0.0, 0.0, 0.0), perspective: bool = True, track: bool = False, + default_point_color: str | tuple[float, float, float] | ArrayLike = ( + 0.8, + 0.8, + 0.8, + ), point_size: float = 4.0, interconnection_width: float = 1.5, frame_size: float = 0.1, @@ -225,8 +263,12 @@ def __init__( grid_width: float = 1.0, grid_subdivision_size: float = 1.0, grid_origin: tuple[float, float, float] | ArrayLike = (0.0, 0.0, 0.0), - grid_color: tuple[float, float, float] | ArrayLike = (0.3, 0.3, 0.3), - background_color: tuple[float, float, float] | ArrayLike = ( + grid_color: str | tuple[float, float, float] | ArrayLike = ( + 0.3, + 0.3, + 0.3, + ), + background_color: str | tuple[float, float, float] | ArrayLike = ( 0.0, 0.0, 0.0, @@ -257,6 +299,9 @@ def __init__( self._interconnections = interconnections # Just to put stuff for now self._extended_interconnections = interconnections # idem + self._colors = set() # idem + self._selected_points = [] + self._last_selected_point = "" # Assign standard properties self.current_index = current_index @@ -275,6 +320,7 @@ def __init__( self.translation = translation self.target = target self.track = track + self.default_point_color = default_point_color self.point_size = point_size self.interconnection_width = interconnection_width self.frame_size = frame_size @@ -287,8 +333,6 @@ def __init__( self.background_color = background_color self.text_info = "" - self._select_none() - self.last_selected_point = "" self._running = False # Init mouse navigation state @@ -347,7 +391,7 @@ def current_index(self, value: int): if not self._being_constructed: if self.track is True and self._oriented_points is not None: new_target = self._oriented_points.data[ - self.last_selected_point + self._last_selected_point ][self.current_index] if not np.isnan(np.sum(new_target)): self.target = new_target @@ -520,6 +564,18 @@ def track(self, value: bool): if not self._being_constructed: self._fast_refresh() + @property + def default_point_color(self): + """Read/write default_point_color.""" + return self._default_point_color + + @default_point_color.setter + def default_point_color(self, value): + """Set default_point_color value.""" + self._default_point_color = _parse_color(value) + if not self._being_constructed: + self._refresh() + @property def point_size(self) -> float: """Read/write point_size.""" @@ -632,18 +688,7 @@ def grid_color(self): @grid_color.setter def grid_color(self, value): """Set grid_color value.""" - self._set_grid_color(value) - - def _set_grid_color(self, value: tuple[float, float, float] | ArrayLike): - """Workaround for having runtime static type checking.""" - array_value = np.array(value) - if len(array_value) != 3: - raise ValueError("grid_color must be an (R, G, B) tuple.") - self._grid_color = ( - float(array_value[0]), - float(array_value[1]), - float(array_value[2]), - ) + self._grid_color = _parse_color(value) if not self._being_constructed: self._update_grid() self._refresh() @@ -656,20 +701,7 @@ def background_color(self): @background_color.setter def background_color(self, value): """Set background_color value.""" - self._set_background_color(value) - - def _set_background_color( - self, value: tuple[float, float, float] | ArrayLike - ): - """Workaround for having runtime static type checking.""" - array_value = np.array(value) - if len(array_value) != 3: - raise ValueError("background_color must be an (R, G, B) tuple.") - self._background_color = ( - float(array_value[0]), - float(array_value[1]), - float(array_value[2]), - ) + self._background_color = _parse_color(value) if not self._being_constructed: self._refresh() @@ -1094,12 +1126,13 @@ def _update_points_and_interconnections(self) -> None: points_data = dict() # Used to draw the points with different colors interconnection_points = dict() # Used to draw the interconnections - for color in COLORS: - points_data[color] = np.empty([n_points, 4]) - points_data[color][:] = np.nan - - points_data[color + "s"] = np.empty([n_points, 4]) - points_data[color + "s"][:] = np.nan + for color in self._colors: + # Reset unselected points + points_data[(color, False)] = np.empty([n_points, 4]) + points_data[(color, False)][:] = np.nan + # Reset selected points + points_data[(color, True)] = np.empty([n_points, 4]) + points_data[(color, True)][:] = np.nan if n_points > 0: for i_point, point in enumerate(points.data): @@ -1110,27 +1143,35 @@ def _update_points_and_interconnections(self) -> None: ): color = points.data_info[point]["Color"] else: - color = "w" + color = self.default_point_color these_coordinates = points.data[point][self.current_index] - points_data[color][i_point] = these_coordinates interconnection_points[point] = these_coordinates + # Assign to unselected(False) or selected(True) points_data + if point in self._selected_points: + points_data[(color, True)][i_point] = these_coordinates + else: + points_data[(color, False)][i_point] = these_coordinates + # Update the points plot - for color in COLORS: + for color in self._colors: # Unselected points - points_data[color] = self._project_to_camera(points_data[color]) - self._mpl_objects["PointPlots"][color].set_data( - points_data[color][:, 0], points_data[color][:, 1] + points_data[(color, False)] = self._project_to_camera( + points_data[(color, False)] + ) + self._mpl_objects["PointPlots"][(color, False)].set_data( + points_data[(color, False)][:, 0], + points_data[(color, False)][:, 1], ) # Selected points - points_data[color + "s"] = self._project_to_camera( - points_data[color + "s"] + points_data[(color, True)] = self._project_to_camera( + points_data[(color, True)] ) - self._mpl_objects["PointPlots"][color + "s"].set_data( - points_data[color + "s"][:, 0], - points_data[color + "s"][:, 1], + self._mpl_objects["PointPlots"][(color, True)].set_data( + points_data[(color, True)][:, 0], + points_data[(color, True)][:, 1], ) # Draw the interconnections @@ -1290,25 +1331,28 @@ def _refresh(self): linewidth=self.frame_width, )[0] + # ---------------------- # Create the point plots - colors = { - "r": [1, 0, 0], - "g": [0, 1, 0], - "b": [0.3, 0.3, 1], - "y": [1, 1, 0], - "m": [1, 0, 1], - "c": [0, 1, 1], - "w": [0.8, 0.8, 0.8], - } - - for color in COLORS: - self._mpl_objects["PointPlots"][color] = self._mpl_objects[ - "Axes" - ].plot( + # ---------------------- + # List all colors in contents + self._colors = set() + for key in self._contents.data: + try: + color = self._contents.data_info[key]["Color"] + except KeyError: # Default color + color = self._default_point_color + self._colors.add(color) + + # Create all required point plots + for color in self._colors: + # Unselected points + self._mpl_objects["PointPlots"][ + (color, False) + ] = self._mpl_objects["Axes"].plot( np.nan, np.nan, ".", - c=colors[color], + c=color, markersize=self._point_size, pickradius=1.1 * self._point_size, picker=True, @@ -1316,13 +1360,14 @@ def _refresh(self): 0 ] - self._mpl_objects["PointPlots"][color + "s"] = self._mpl_objects[ + # Selected points + self._mpl_objects["PointPlots"][(color, True)] = self._mpl_objects[ "Axes" ].plot( np.nan, np.nan, ".", - c=colors[color], + c=color, markersize=3 * self._point_size, )[ 0 @@ -1383,24 +1428,6 @@ def error_function(input): self._zoom = initial_zoom self._target = initial_target - # ------------------------------------ - # Helper functions - def _select_none(self) -> None: - """Deselect every points.""" - if self._oriented_points is not None: - for point in self._oriented_points.data: - try: - # Keep 1st character, remove the possible 's' - self._oriented_points.data_info[point]["Color"] = ( - self._oriented_points.data_info[point]["Color"][0] - ) - except KeyError: - self._oriented_points = ( - self._oriented_points.add_data_info( - point, "Color", "w" - ) - ) - # ------------------------------------ # Callbacks def _on_close(self, _) -> None: # pragma: no cover @@ -1439,14 +1466,10 @@ def _on_pick(self, event): # pragma: no cover self.text_info = selected_point # Mark selected - self._select_none() - self._oriented_points.data_info[selected_point]["Color"] = ( - self._oriented_points.data_info[selected_point]["Color"][0] - + "s" - ) + self._selected_points = [selected_point] # Set as new target - self.last_selected_point = selected_point + self._last_selected_point = selected_point self._set_new_target( self._oriented_points.data[selected_point][self.current_index] ) @@ -1547,9 +1570,9 @@ def _on_scroll(self, event): # pragma: no cover self._fast_refresh() def _on_mouse_press(self, event): # pragma: no cover - if len(self.last_selected_point) > 0: + if len(self._last_selected_point) > 0: self._set_new_target( - self._oriented_points.data[self.last_selected_point][ + self._oriented_points.data[self._last_selected_point][ self.current_index ] ) diff --git a/tests/test_player.py b/tests/test_player.py index 2c3d7af..52d8311 100755 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -296,6 +296,8 @@ def test_scripting(): p.perspective = True # %% Styling points and interconnections + p.default_point_color = "r" + p.default_point_color = [1, 0, 0] p.point_size = 8.0 p.interconnection_width = 5.0