Skip to content
This repository has been archived by the owner on Nov 26, 2023. It is now read-only.

Commit

Permalink
web widget features
Browse files Browse the repository at this point in the history
  • Loading branch information
kushalkolar committed Apr 9, 2021
1 parent 633503b commit e2d0e8f
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 15 deletions.
118 changes: 107 additions & 11 deletions mesmerize/plotting/web_widgets/datapoint_tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from bokeh.plotting import figure, Figure
from bokeh.models.glyphs import Image, MultiLine
from bokeh.models import HoverTool, ColumnDataSource, TapTool, Slider, TextInput, Select, BoxAnnotation
from bokeh.models import HoverTool, ColumnDataSource, TapTool, Slider, TextInput, Select, \
BoxAnnotation, Patches
from bokeh.models.mappers import LogColorMapper
from bokeh.layouts import gridplot, column, row
import os
Expand All @@ -13,6 +14,7 @@
from mesmerize.plotting.utils import auto_colormap, map_labels_to_colors
from mesmerize.plotting.web_widgets.core import BokehCallbackSignal, WebPlot
import logging
import pickle


logger = logging.getLogger()
Expand Down Expand Up @@ -64,6 +66,8 @@ def __init__(
# self.parent_document: Document = parent_document
self.project_path: Path = Path(project_path)

self.frame: np.ndarray = np.empty(0)

if image_figure_params is None:
image_figure_params = dict()

Expand All @@ -84,6 +88,21 @@ def __init__(
level="image"
)

self.roi_patches_glyph: Patches = self.image_figure.patches(
xs="xs",
ys="ys",
# color="colors",
color="#ffffff",
alpha=0.0,
line_width=2,
line_alpha=1.0,
source={
"xs": [[]],
"ys": [[]],
# "colors": ["#ffffff"],
},
)

self.image_figure.grid.grid_line_width = 0

self.curve_figure: Figure = None
Expand All @@ -107,6 +126,8 @@ def __init__(
if self.tooltip_columns is not None:
self.tooltips = [(col, f'@{col}') for col in tooltip_columns]

# self.datatable:

self.dataframe: pd.DataFrame = None
self.sample_id: str = None
self.img_uuid: UUID = None
Expand All @@ -123,13 +144,37 @@ def __init__(
self.curve_plot_bands_selector = Select(title="Bands based on:", value='', options=[''])
self.curve_plot_bands_selector.on_change('value', self.sig_plot_options_changed.trigger)

############################################################
# TEMPORARY
############################################################

# self.button_remove_selection = Button(label="Remove current selection")
# self.button_remove_selection.on_click(self.remove_sample)
############################################################
############################################################
############################################################

self.sig_plot_options_changed.connect(self.set_curve)

self.frame_slider = Slider(start=0, end=1000, value=1, step=10, title="Frame index:")
self.frame_slider.on_change('value', self.sig_frame_changed.trigger)
self.sig_frame_changed.connect(self._set_current_frame)

self.label_filesize: TextInput = TextInput(value='', title='Filesize (GB):')
self.label_sample_id: TextInput = TextInput(value='', title="SampleID:")

# def remove_sample(self):
# self.parent.dataframe = self.parent.dataframe[
# self.parent.dataframe['SampleID'] != self.sample_id
# ]
#
# sid = self.parent.dataframe['SampleID'].unique()[0]
#
# self.set_sample(
# self.parent.dataframe[self.parent.dataframe['SampleID'] == sid]
# )
#
# self.parent.update_glyph()

def _check_sample(self, dataframe: pd.DataFrame):
if len(dataframe['SampleID'].unique()) > 1:
Expand Down Expand Up @@ -159,25 +204,36 @@ def set_sample(self, dataframe: pd.DataFrame):

self.set_curve()

self.label_sample_id.update(value=self.sample_id)

def _set_video(self, vid_path: Union[Path, str]):
self.tif = tifffile.TiffFile(vid_path)

self.current_frame = 0
frame = self.tif.asarray(key=self.current_frame)
self.frame = self.tif.asarray(key=self.current_frame)

# this is basically used for vmin mvax
self.color_mapper = LogColorMapper(
palette=auto_colormap(256, 'gnuplot2', output='bokeh'),
low=np.nanmin(frame),
high=np.nanmax(frame)
low=np.nanmin(self.frame),
high=np.nanmax(self.frame)
)

self.image_glyph.data_source.data['image'] = [frame]
self.image_glyph.data_source.data['image'] = [self.frame]
self.image_glyph.glyph.color_mapper = self.color_mapper

# shows the file size in gigabytes
self.label_filesize.update(value=str(os.path.getsize(vid_path) / 1024 / 1024 / 1024))

def _get_roi_coors(self, r: pd.Series):
roi_type = r['roi_type']

if roi_type == 'ManualROI':
pos = r['roi_graphics_object_state']['pos']
points = r['roi_graphics_object_state']['points']

return points + np.array(pos)

# @WebPlot.signal_blocker
def _update_plot_options(self):
# categorical_columns = get_categorical_columns(self.dataframe)
Expand Down Expand Up @@ -215,11 +271,20 @@ def _update_plot_options(self):

def _set_current_frame(self, i: int):
self.current_frame = i
frame = self.tif.asarray(key=self.current_frame)
frame = self.tif.asarray(key=self.current_frame, maxworkers=20)

self.image_glyph.data_source.data['image'] = [frame]

def set_curve(self):
def _get_trimmed_dataframe(self) -> pd.DataFrame:
"""
Get dataframe for tooltips, JSON serializable.
"""
return self.dataframe.drop(
columns=[c for c in self.dataframe.columns if c not in self.tooltip_columns]
).copy(deep=True)

@WebPlot.signal_blocker
def set_curve(self, *args):
logger.debug('updating curve')
logger.debug(self.dataframe)
data_column = self.curve_data_selector.value
Expand All @@ -228,9 +293,7 @@ def set_curve(self):

self.frame_slider.update(start=0, end=ys[0].size - 1, value=0)

df = self.dataframe.drop(
columns=[c for c in self.dataframe.columns if c not in self.tooltip_columns]
).copy(deep=True)
df = self._get_trimmed_dataframe()

colors_column = self.curve_color_selector.value
ncolors = df[colors_column].unique().size
Expand Down Expand Up @@ -283,6 +346,36 @@ def set_curve(self):
source=src
)

# TODO: ROIs
# # set the ROIs
# p = pickle.load(
# open(
# os.path.join(
# self.parent.transmission.get_proj_path(),
# self.dataframe['ImgInfoPath'].iloc[0]
# ),
# 'rb'
# )
# )
#
# roi_coors = self.dataframe['ROI_State'].apply(self._get_roi_coors).values
#
# xs = [a[:, 0].tolist() for a in roi_coors]
# ys = [a[:, 1].tolist() for a in roi_coors]
#
# self.roi_patches_glyph.data_source.data['xs'] = xs
# self.roi_patches_glyph.data_source.data['ys'] = ys
# colors_list = self.curve_glyph.data_source.data['colors']
# if len(xs) != len(colors_list):
# colors_list = ["#ffffff"] * len(xs)
# else:
# self.roi_patches_glyph.data_source.data['colors'] = colors_list

self.image_glyph.glyph.dw = self.frame.shape[0]
self.image_glyph.glyph.dh = self.frame.shape[1]
self.image_glyph.glyph.x = 0
self.image_glyph.glyph.y = 0

# add the new curve plot to the doc root
self.doc.add_root(self.curve_figure)

Expand All @@ -299,7 +392,10 @@ def set_dashboard(self, figures: List[Figure]):
self.curve_color_selector,
self.curve_plot_bands_selector
),
self.label_filesize,
row(
self.label_sample_id,
self.label_filesize,
),
self.frame_slider
)
)
38 changes: 34 additions & 4 deletions mesmerize/plotting/web_widgets/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,16 @@ def __init__(
self.groupby_column = groupby_column

if source_columns is None:
source_columns = []
self.source_columns = []
else:
self.source_columns = source_columns

# ColumnDataSource is what bokeh uses for plotting
# it's similar to dataframes but doesn't accept
# some datatypes like dicts and arrays within dataframe "cells"
self.source: ColumnDataSource = ColumnDataSource(
self.dataframe.drop(
columns=[c for c in self.dataframe.columns if c not in source_columns]
columns=[c for c in self.dataframe.columns if c not in self.source_columns]
)
)

Expand Down Expand Up @@ -93,7 +95,9 @@ def __init__(
)

if glyph_opts is None:
glyph_opts = dict()
self.glyph_opts = dict()
else:
self.glyph_opts = glyph_opts

# jitter along the x axis for the swarm scatter
x_vals = jitter(self.groupby_column, width=0.6, range=self.figure.x_range)
Expand All @@ -105,13 +109,39 @@ def __init__(
source=self.source, # this is the ColumnDataSource created from the dataframe
**{
**_default_glyph_opts,
**glyph_opts
**self.glyph_opts
}
)

self.project_path = project_path
self.source_columns = source_columns

# def update_glyph(self):
# self.source: ColumnDataSource = ColumnDataSource(
# self.dataframe.drop(
# columns=[c for c in self.dataframe.columns if c not in self.source_columns]
# )
# )
#
# self.glyph.data_source = self.source

# self.glyph = self.figure.circle(
# x=jitter(self.groupby_column, width=0.6, range=self.figure.x_range),
# y=self.data_column, # the user specified data column
# source=self.source, # this is the ColumnDataSource created from the dataframe
# **{
# **_default_glyph_opts,
# **self.glyph_opts
# }
# )
#
# self.glyph.data_source.selected.on_change('indices', self.sig_point_selected.trigger)

# self.glyph.data_source = self.source
#
# self.glyph.data_source.data['x'] = jitter(self.groupby_column, width=0.6, range=self.figure.x_range)
# self.glyph.data_source.data['y'] = self.data_column,

def start_app(self, doc):
"""
Call this from ``bokeh.io.show()`` within a notebook to show the plot
Expand Down

0 comments on commit e2d0e8f

Please sign in to comment.