Skip to content

Commit

Permalink
refactored napari layer styles into dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Feb 20, 2024
1 parent 1e74f56 commit b9d0ad9
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 50 deletions.
76 changes: 76 additions & 0 deletions movement/napari/layer_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Dataclasses containing layer styles for napari."""

from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd
from napari.utils.colormaps import ensure_colormap

DEFAULT_COLORMAP = "turbo"


@dataclass
class LayerStyle:
"""Base class for napari layer styles."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"

def as_kwargs(self) -> dict:
"""Return the style properties as a dictionary of kwargs."""
return self.__dict__


@dataclass
class PointsStyle(LayerStyle):
"""Style properties for a napari Points layer."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"
symbol: str = "disc"
size: int = 10
edge_width: int = 0
face_color: Optional[str] = None
face_color_cycle: Optional[list[tuple]] = None
face_colormap: str = DEFAULT_COLORMAP
text: dict = field(default_factory=lambda: {"visible": False})

@staticmethod
def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap,
including the endpoints."""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]

def set_color_by(self, prop: str, cmap: str) -> None:
"""Set the face_color to a column in the properties DataFrame."""
self.face_color = prop
self.text["string"] = prop
n_colors = len(self.properties[prop].unique())
self.face_color_cycle = self._sample_colormap(n_colors, cmap)


@dataclass
class TracksStyle(LayerStyle):
"""Style properties for a napari Tracks layer."""

name: str
properties: pd.DataFrame
tail_width: int = 5
tail_length: int = 60
head_length: int = 0
color_by: str = "track_id"
colormap: str = DEFAULT_COLORMAP
visible: bool = True
blending: str = "translucent"

def set_color_by(self, prop: str, cmap: str) -> None:
"""Set the color_by to a column in the properties DataFrame."""
self.color_by = prop
self.colormap = cmap
82 changes: 32 additions & 50 deletions movement/napari/loader_widgets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import logging
from pathlib import Path

import numpy as np
from napari.utils.colormaps import ensure_colormap
import pandas as pd
from napari.viewer import Viewer
from pandas.api.types import CategoricalDtype
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
Expand All @@ -18,16 +16,21 @@

from movement.io import load_poses
from movement.napari.convert import ds_to_napari_tracks
from movement.napari.layer_styles import PointsStyle, TracksStyle

logger = logging.getLogger(__name__)


def sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap,
including the endpoints."""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]
def columns_to_categorical(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
"""Convert columns in a DataFrame to ordered categorical data type. The
categories are the unique values in the column, ordered by appearance."""
new_df = df.copy()
for col in cols:
cat_dtype = pd.api.types.CategoricalDtype(
categories=df[col].unique().tolist(), ordered=True
)
new_df[col] = df[col].astype(cat_dtype).cat.codes
return new_df


class FileLoader(QWidget):
Expand Down Expand Up @@ -120,49 +123,28 @@ def load_file(self, file_path):

def add_layers(self):
"""Add the predicted pose tracks and keypoints to the napari viewer."""

common_kwargs = {"visible": True, "blending": "translucent"}
n_individuals = len(self.props["individual"].unique())
color_by = "individual" if n_individuals > 1 else "keypoint"
n_colors = len(self.props[color_by].unique())

# kwargs for the napari Points layer
points_kwargs = {
**common_kwargs,
"name": f"Keypoints - {self.file_name}",
"properties": self.props,
"symbol": "disc",
"size": 10,
"edge_width": 0,
"face_color": color_by,
"face_color_cycle": sample_colormap(n_colors, "turbo"),
"face_colormap": "turbo",
"text": {"string": color_by, "visible": False},
}

# Modify properties for the napari Tracks layer
tracks_props = self.props.copy()

# Style properties for the napari Points layer
points_style = PointsStyle(
name=f"Keypoints - {self.file_name}",
properties=self.props,
)
points_style.set_color_by(prop=color_by, cmap="turbo")

# Track properties must be numeric, so convert str to categorical codes
for col in ["individual", "keypoint"]:
cat_dtype = CategoricalDtype(
categories=tracks_props[col].unique(), ordered=True
)
tracks_props[col] = tracks_props[col].astype(cat_dtype).cat.codes
tracks_props = columns_to_categorical(
self.props, ["individual", "keypoint"]
)

# kwargs for the napari Tracks layer
tracks_kwargs = {
**common_kwargs,
"name": f"Tracks - {self.file_name}",
"properties": tracks_props,
"tail_width": 5,
"tail_length": 60,
"head_length": 0,
"color_by": color_by,
"colormap": "turbo",
}

# Add the napari Tracks layer to the viewer
self.viewer.add_tracks(self.data, **tracks_kwargs)

# Add the napari Points layer to the viewer
self.viewer.add_points(self.data[:, 1:], **points_kwargs)
tracks_style = TracksStyle(
name=f"Tracks - {self.file_name}",
properties=tracks_props,
)
tracks_style.set_color_by(prop=color_by, cmap="turbo")

# Add the new layers to the napari viewer
self.viewer.add_tracks(self.data, **tracks_style.as_kwargs())
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())

0 comments on commit b9d0ad9

Please sign in to comment.