Skip to content

Commit

Permalink
moved duplicate cell selection code to App
Browse files Browse the repository at this point in the history
  • Loading branch information
nickdelgrosso committed Oct 6, 2021
1 parent e425944 commit 74c6b46
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
5 changes: 2 additions & 3 deletions regexport/actions/save_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def register(self, model: AppState):
def submit(self, filename: Path, selected_regions_only: bool = False):
print('File saving...')

df = self.model.selected_cells if selected_regions_only else self.model.cells
types = {
'Image': 'category',
'BrainRegion': 'category',
Expand All @@ -28,10 +29,8 @@ def submit(self, filename: Path, selected_regions_only: bool = False):
}

types.update({col: 'uint16' for col in self.model.cells.columns if "Num Spots" in col})
df = self.model.cells.astype(types)
df = df.astype(types)
df: pd.DataFrame = df.drop(columns=['BGIdx'])
if selected_regions_only:
df = df.iloc[self.model.selected_cell_ids]

print(df.info())
print(filename)
Expand Down
3 changes: 3 additions & 0 deletions regexport/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class AppState(HasTraits):
column_to_plot = Unicode(default_value="BrainRegion")
colormap_options = List(Unicode(), default_value=[cmap for cmap in plt.colormaps() if not cmap.endswith('_r')])#['tab20c', 'viridis'])
selected_colormap = Unicode(default_value='tab20c')
selected_cells = Instance(pd.DataFrame, allow_none=True)

@observe('cells')
def _update_column_to_plot_options(self, change):
Expand All @@ -43,9 +44,11 @@ def _update_selected_cell_ids(self, change):
return
elif len(self.selected_region_ids) == 0:
self.selected_cell_ids = self.cells.index.values
self.selected_cells = self.cells
else:
is_parented = self.cells.groupby('BGIdx', as_index=False).BGIdx.transform(
lambda ids: is_parent(ids.values[0], selected_ids=self.selected_region_ids, tree=self.atlas.hierarchy) if ids.values[0] != 0 else False
)
only_parented = is_parented[is_parented.BGIdx].index.values
self.selected_cell_ids = only_parented
self.selected_cells = self.cells.iloc[only_parented]
8 changes: 4 additions & 4 deletions regexport/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def __post_init__(self):


@warn_if_slow()
def plot_cells(positions: np.ndarray, colors: np.ndarray, indices: Tuple[int], cmap: str = 'tab20c') -> PointCloud:
def plot_cells(positions: np.ndarray, colors: np.ndarray, cmap: str = 'tab20c') -> PointCloud:
return PointCloud(
coords=positions[indices, :],
colors=(selected_colors := convert_values_to_colors(colors, getattr(plt.cm, cmap))[indices])[:, :3],
coords=positions,
colors=(selected_colors := convert_values_to_colors(colors, getattr(plt.cm, cmap)))[:, :3],
alphas=selected_colors[:, 3:4],
)


def convert_values_to_colors(color_codes: np.ndarray, cmap: ListedColormap):
return cmap(color_codes / color_codes.max())[:, :4]
return cmap(color_codes / color_codes.max())[:, :4]
11 changes: 5 additions & 6 deletions regexport/views/plot_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,18 @@ def register(self, model: AppState):
lambda atlas: Path(str(atlas.structures[997]['mesh_filename'])) if atlas is not None else None
)
model.observe(self.link_cells_to_points, names=[
'selected_cell_ids', 'cells', 'selected_colormap', 'column_to_plot',
'selected_cells', 'selected_colormap', 'column_to_plot',
])

def link_cells_to_points(self, change):
model = self.model
if model.cells is None:
if model.selected_cells is None:
self.points = PointCloud()
return
color_col = model.cells[model.column_to_plot]
color_col = model.selected_cells[model.column_to_plot]
points = plot_cells(
positions=model.cells[['X', 'Y', 'Z']].values * 1000,
positions=model.selected_cells[['X', 'Y', 'Z']].values * 1000,
colors=color_col.cat.codes.values if color_col.dtype.name == 'category' else color_col.values,
indices=model.selected_cell_ids if model.selected_cell_ids is not None else (),
cmap=self.model.selected_colormap
)
self.points = points
Expand Down Expand Up @@ -103,4 +102,4 @@ def render(self, change=None):

self.plotter.show(actors, at=0)
self.plotter.addInset(self._atlas_mesh, pos=(.9, .9), size=0.1, c='w', draggable=True)
# note: look at from vedo.applications import SlicerPlotter for inspiration
# note: look at from vedo.applications import SlicerPlotter for inspiration

0 comments on commit 74c6b46

Please sign in to comment.