Skip to content

Commit

Permalink
Player: added default_point_color option.
Browse files Browse the repository at this point in the history
  • Loading branch information
felixchenier committed Mar 2, 2024
1 parent d4cf489 commit 22592ff
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 93 deletions.
209 changes: 116 additions & 93 deletions kineticstoolkit/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,16 @@
import scipy.optimize as optim

REPR_HTML_MAX_DURATION = 10 # Max duration for _repr_html
COLORS = ["r", "g", "b", "y", "c", "m", "w"]
PALETTE = {
"r": (1.0, 0.0, 0.0),
"g": (0.0, 1.0, 0.0),
"b": (0.3, 0.3, 1.0),
"y": (1.0, 1.0, 0.0),
"m": (1.0, 0.0, 1.0),
"c": (0.0, 1.0, 1.0),
"w": (1.0, 1.0, 1.0),
}


HELP_TEXT = """
ktk.Player help
Expand Down Expand Up @@ -77,7 +86,27 @@
"""


@typecheck
def _parse_color(
value: str | tuple[float, float, float] | ArrayLike
) -> tuple[float, float, float]:
"""Convert a color specification into a tuple[float, float, float]."""
if isinstance(value, str):
try:
return PALETTE[value]
except KeyError:
raise ValueError(
f"The specified color '{value}' is not recognized."
)
array_value = np.array(value)
if len(array_value) != 3:
raise ValueError("Color must be a character or an (R, G, B) tuple.")
return (
float(array_value[0]),
float(array_value[1]),
float(array_value[2]),
)


class Player:
# FIXME! Update this docstring.
"""
Expand Down Expand Up @@ -173,6 +202,9 @@ class Player:
_oriented_frames: TimeSeries
_interconnections: dict[str, dict[str, Any]]
_extended_interconnections: dict[str, dict[str, Any]]
_colors: set[tuple[float, float, float]] # A list of all point colors
_selected_points: list[str] # List of point names
_last_selected_point: str
_current_index: int
_current_time: float
_playback_speed: float
Expand All @@ -188,6 +220,7 @@ class Player:
_translation: np.ndarray
_target: np.ndarray
_track: bool
_default_point_color: tuple[float, float, float]
_point_size: float
_interconnection_width: float
_frame_size: float
Expand Down Expand Up @@ -217,6 +250,11 @@ def __init__(
target: tuple[float, float, float] | ArrayLike = (0.0, 0.0, 0.0),
perspective: bool = True,
track: bool = False,
default_point_color: str | tuple[float, float, float] | ArrayLike = (
0.8,
0.8,
0.8,
),
point_size: float = 4.0,
interconnection_width: float = 1.5,
frame_size: float = 0.1,
Expand All @@ -225,8 +263,12 @@ def __init__(
grid_width: float = 1.0,
grid_subdivision_size: float = 1.0,
grid_origin: tuple[float, float, float] | ArrayLike = (0.0, 0.0, 0.0),
grid_color: tuple[float, float, float] | ArrayLike = (0.3, 0.3, 0.3),
background_color: tuple[float, float, float] | ArrayLike = (
grid_color: str | tuple[float, float, float] | ArrayLike = (
0.3,
0.3,
0.3,
),
background_color: str | tuple[float, float, float] | ArrayLike = (
0.0,
0.0,
0.0,
Expand Down Expand Up @@ -257,6 +299,9 @@ def __init__(

self._interconnections = interconnections # Just to put stuff for now
self._extended_interconnections = interconnections # idem
self._colors = set() # idem
self._selected_points = []
self._last_selected_point = ""

# Assign standard properties
self.current_index = current_index
Expand All @@ -275,6 +320,7 @@ def __init__(
self.translation = translation
self.target = target
self.track = track
self.default_point_color = default_point_color
self.point_size = point_size
self.interconnection_width = interconnection_width
self.frame_size = frame_size
Expand All @@ -287,8 +333,6 @@ def __init__(
self.background_color = background_color
self.text_info = ""

self._select_none()
self.last_selected_point = ""
self._running = False

# Init mouse navigation state
Expand Down Expand Up @@ -347,7 +391,7 @@ def current_index(self, value: int):
if not self._being_constructed:
if self.track is True and self._oriented_points is not None:
new_target = self._oriented_points.data[
self.last_selected_point
self._last_selected_point
][self.current_index]
if not np.isnan(np.sum(new_target)):
self.target = new_target
Expand Down Expand Up @@ -520,6 +564,18 @@ def track(self, value: bool):
if not self._being_constructed:
self._fast_refresh()

@property
def default_point_color(self):
"""Read/write default_point_color."""
return self._default_point_color

@default_point_color.setter
def default_point_color(self, value):
"""Set default_point_color value."""
self._default_point_color = _parse_color(value)
if not self._being_constructed:
self._refresh()

@property
def point_size(self) -> float:
"""Read/write point_size."""
Expand Down Expand Up @@ -632,18 +688,7 @@ def grid_color(self):
@grid_color.setter
def grid_color(self, value):
"""Set grid_color value."""
self._set_grid_color(value)

def _set_grid_color(self, value: tuple[float, float, float] | ArrayLike):
"""Workaround for having runtime static type checking."""
array_value = np.array(value)
if len(array_value) != 3:
raise ValueError("grid_color must be an (R, G, B) tuple.")
self._grid_color = (
float(array_value[0]),
float(array_value[1]),
float(array_value[2]),
)
self._grid_color = _parse_color(value)
if not self._being_constructed:
self._update_grid()
self._refresh()
Expand All @@ -656,20 +701,7 @@ def background_color(self):
@background_color.setter
def background_color(self, value):
"""Set background_color value."""
self._set_background_color(value)

def _set_background_color(
self, value: tuple[float, float, float] | ArrayLike
):
"""Workaround for having runtime static type checking."""
array_value = np.array(value)
if len(array_value) != 3:
raise ValueError("background_color must be an (R, G, B) tuple.")
self._background_color = (
float(array_value[0]),
float(array_value[1]),
float(array_value[2]),
)
self._background_color = _parse_color(value)
if not self._being_constructed:
self._refresh()

Expand Down Expand Up @@ -1094,12 +1126,13 @@ def _update_points_and_interconnections(self) -> None:
points_data = dict() # Used to draw the points with different colors
interconnection_points = dict() # Used to draw the interconnections

for color in COLORS:
points_data[color] = np.empty([n_points, 4])
points_data[color][:] = np.nan

points_data[color + "s"] = np.empty([n_points, 4])
points_data[color + "s"][:] = np.nan
for color in self._colors:
# Reset unselected points
points_data[(color, False)] = np.empty([n_points, 4])
points_data[(color, False)][:] = np.nan
# Reset selected points
points_data[(color, True)] = np.empty([n_points, 4])
points_data[(color, True)][:] = np.nan

if n_points > 0:
for i_point, point in enumerate(points.data):
Expand All @@ -1110,27 +1143,35 @@ def _update_points_and_interconnections(self) -> None:
):
color = points.data_info[point]["Color"]
else:
color = "w"
color = self.default_point_color

these_coordinates = points.data[point][self.current_index]
points_data[color][i_point] = these_coordinates
interconnection_points[point] = these_coordinates

# Assign to unselected(False) or selected(True) points_data
if point in self._selected_points:
points_data[(color, True)][i_point] = these_coordinates
else:
points_data[(color, False)][i_point] = these_coordinates

# Update the points plot
for color in COLORS:
for color in self._colors:
# Unselected points
points_data[color] = self._project_to_camera(points_data[color])
self._mpl_objects["PointPlots"][color].set_data(
points_data[color][:, 0], points_data[color][:, 1]
points_data[(color, False)] = self._project_to_camera(
points_data[(color, False)]
)
self._mpl_objects["PointPlots"][(color, False)].set_data(
points_data[(color, False)][:, 0],
points_data[(color, False)][:, 1],
)

# Selected points
points_data[color + "s"] = self._project_to_camera(
points_data[color + "s"]
points_data[(color, True)] = self._project_to_camera(
points_data[(color, True)]
)
self._mpl_objects["PointPlots"][color + "s"].set_data(
points_data[color + "s"][:, 0],
points_data[color + "s"][:, 1],
self._mpl_objects["PointPlots"][(color, True)].set_data(
points_data[(color, True)][:, 0],
points_data[(color, True)][:, 1],
)

# Draw the interconnections
Expand Down Expand Up @@ -1290,39 +1331,43 @@ def _refresh(self):
linewidth=self.frame_width,
)[0]

# ----------------------
# Create the point plots
colors = {
"r": [1, 0, 0],
"g": [0, 1, 0],
"b": [0.3, 0.3, 1],
"y": [1, 1, 0],
"m": [1, 0, 1],
"c": [0, 1, 1],
"w": [0.8, 0.8, 0.8],
}

for color in COLORS:
self._mpl_objects["PointPlots"][color] = self._mpl_objects[
"Axes"
].plot(
# ----------------------
# List all colors in contents
self._colors = set()
for key in self._contents.data:
try:
color = self._contents.data_info[key]["Color"]
except KeyError: # Default color
color = self._default_point_color
self._colors.add(color)

# Create all required point plots
for color in self._colors:
# Unselected points
self._mpl_objects["PointPlots"][
(color, False)
] = self._mpl_objects["Axes"].plot(
np.nan,
np.nan,
".",
c=colors[color],
c=color,
markersize=self._point_size,
pickradius=1.1 * self._point_size,
picker=True,
)[
0
]

self._mpl_objects["PointPlots"][color + "s"] = self._mpl_objects[
# Selected points
self._mpl_objects["PointPlots"][(color, True)] = self._mpl_objects[
"Axes"
].plot(
np.nan,
np.nan,
".",
c=colors[color],
c=color,
markersize=3 * self._point_size,
)[
0
Expand Down Expand Up @@ -1383,24 +1428,6 @@ def error_function(input):
self._zoom = initial_zoom
self._target = initial_target

# ------------------------------------
# Helper functions
def _select_none(self) -> None:
"""Deselect every points."""
if self._oriented_points is not None:
for point in self._oriented_points.data:
try:
# Keep 1st character, remove the possible 's'
self._oriented_points.data_info[point]["Color"] = (
self._oriented_points.data_info[point]["Color"][0]
)
except KeyError:
self._oriented_points = (
self._oriented_points.add_data_info(
point, "Color", "w"
)
)

# ------------------------------------
# Callbacks
def _on_close(self, _) -> None: # pragma: no cover
Expand Down Expand Up @@ -1439,14 +1466,10 @@ def _on_pick(self, event): # pragma: no cover
self.text_info = selected_point

# Mark selected
self._select_none()
self._oriented_points.data_info[selected_point]["Color"] = (
self._oriented_points.data_info[selected_point]["Color"][0]
+ "s"
)
self._selected_points = [selected_point]

# Set as new target
self.last_selected_point = selected_point
self._last_selected_point = selected_point
self._set_new_target(
self._oriented_points.data[selected_point][self.current_index]
)
Expand Down Expand Up @@ -1547,9 +1570,9 @@ def _on_scroll(self, event): # pragma: no cover
self._fast_refresh()

def _on_mouse_press(self, event): # pragma: no cover
if len(self.last_selected_point) > 0:
if len(self._last_selected_point) > 0:
self._set_new_target(
self._oriented_points.data[self.last_selected_point][
self._oriented_points.data[self._last_selected_point][
self.current_index
]
)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def test_scripting():
p.perspective = True

# %% Styling points and interconnections
p.default_point_color = "r"
p.default_point_color = [1, 0, 0]
p.point_size = 8.0
p.interconnection_width = 5.0

Expand Down

0 comments on commit 22592ff

Please sign in to comment.