From 74c6b4652503165ea2ab4c368468835745bf7906 Mon Sep 17 00:00:00 2001 From: "Nicholas A. Del Grosso" Date: Wed, 6 Oct 2021 17:38:47 +0200 Subject: [PATCH] moved duplicate cell selection code to App --- regexport/actions/save_cells.py | 5 ++--- regexport/model.py | 3 +++ regexport/utils/plotting.py | 8 ++++---- regexport/views/plot_3d.py | 11 +++++------ 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/regexport/actions/save_cells.py b/regexport/actions/save_cells.py index 586c208..d19ce87 100644 --- a/regexport/actions/save_cells.py +++ b/regexport/actions/save_cells.py @@ -18,6 +18,7 @@ def register(self, model: AppState): def submit(self, filename: Path, selected_regions_only: bool = False): print('File saving...') + df = self.model.selected_cells if selected_regions_only else self.model.cells types = { 'Image': 'category', 'BrainRegion': 'category', @@ -28,10 +29,8 @@ def submit(self, filename: Path, selected_regions_only: bool = False): } types.update({col: 'uint16' for col in self.model.cells.columns if "Num Spots" in col}) - df = self.model.cells.astype(types) + df = df.astype(types) df: pd.DataFrame = df.drop(columns=['BGIdx']) - if selected_regions_only: - df = df.iloc[self.model.selected_cell_ids] print(df.info()) print(filename) diff --git a/regexport/model.py b/regexport/model.py index 301e12f..e075bb2 100644 --- a/regexport/model.py +++ b/regexport/model.py @@ -23,6 +23,7 @@ class AppState(HasTraits): column_to_plot = Unicode(default_value="BrainRegion") colormap_options = List(Unicode(), default_value=[cmap for cmap in plt.colormaps() if not cmap.endswith('_r')])#['tab20c', 'viridis']) selected_colormap = Unicode(default_value='tab20c') + selected_cells = Instance(pd.DataFrame, allow_none=True) @observe('cells') def _update_column_to_plot_options(self, change): @@ -43,9 +44,11 @@ def _update_selected_cell_ids(self, change): return elif len(self.selected_region_ids) == 0: self.selected_cell_ids = self.cells.index.values + self.selected_cells = self.cells else: is_parented = self.cells.groupby('BGIdx', as_index=False).BGIdx.transform( lambda ids: is_parent(ids.values[0], selected_ids=self.selected_region_ids, tree=self.atlas.hierarchy) if ids.values[0] != 0 else False ) only_parented = is_parented[is_parented.BGIdx].index.values self.selected_cell_ids = only_parented + self.selected_cells = self.cells.iloc[only_parented] diff --git a/regexport/utils/plotting.py b/regexport/utils/plotting.py index 767a865..21b5c25 100644 --- a/regexport/utils/plotting.py +++ b/regexport/utils/plotting.py @@ -21,13 +21,13 @@ def __post_init__(self): @warn_if_slow() -def plot_cells(positions: np.ndarray, colors: np.ndarray, indices: Tuple[int], cmap: str = 'tab20c') -> PointCloud: +def plot_cells(positions: np.ndarray, colors: np.ndarray, cmap: str = 'tab20c') -> PointCloud: return PointCloud( - coords=positions[indices, :], - colors=(selected_colors := convert_values_to_colors(colors, getattr(plt.cm, cmap))[indices])[:, :3], + coords=positions, + colors=(selected_colors := convert_values_to_colors(colors, getattr(plt.cm, cmap)))[:, :3], alphas=selected_colors[:, 3:4], ) def convert_values_to_colors(color_codes: np.ndarray, cmap: ListedColormap): - return cmap(color_codes / color_codes.max())[:, :4] \ No newline at end of file + return cmap(color_codes / color_codes.max())[:, :4] diff --git a/regexport/views/plot_3d.py b/regexport/views/plot_3d.py index e7f18b0..d2b87ac 100644 --- a/regexport/views/plot_3d.py +++ b/regexport/views/plot_3d.py @@ -27,19 +27,18 @@ def register(self, model: AppState): lambda atlas: Path(str(atlas.structures[997]['mesh_filename'])) if atlas is not None else None ) model.observe(self.link_cells_to_points, names=[ - 'selected_cell_ids', 'cells', 'selected_colormap', 'column_to_plot', + 'selected_cells', 'selected_colormap', 'column_to_plot', ]) def link_cells_to_points(self, change): model = self.model - if model.cells is None: + if model.selected_cells is None: self.points = PointCloud() return - color_col = model.cells[model.column_to_plot] + color_col = model.selected_cells[model.column_to_plot] points = plot_cells( - positions=model.cells[['X', 'Y', 'Z']].values * 1000, + positions=model.selected_cells[['X', 'Y', 'Z']].values * 1000, colors=color_col.cat.codes.values if color_col.dtype.name == 'category' else color_col.values, - indices=model.selected_cell_ids if model.selected_cell_ids is not None else (), cmap=self.model.selected_colormap ) self.points = points @@ -103,4 +102,4 @@ def render(self, change=None): self.plotter.show(actors, at=0) self.plotter.addInset(self._atlas_mesh, pos=(.9, .9), size=0.1, c='w', draggable=True) - # note: look at from vedo.applications import SlicerPlotter for inspiration \ No newline at end of file + # note: look at from vedo.applications import SlicerPlotter for inspiration