Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure multiple callbacks do not bleed wrong plot state #1034

Merged
merged 18 commits into from
Jan 7, 2017
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions holoviews/plotting/bokeh/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..comms import JupyterCommJS


def attributes_js(attributes):
def attributes_js(attributes, handles):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to see the docstring updated with an example showing how handles is involved...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring still needs to be updated I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, still need to do that.

Generates JS code to look up attributes on JS objects from
an attributes specification dictionary.
Expand All @@ -28,7 +28,16 @@ def attributes_js(attributes):
obj_name = attrs[0]
attr_getters = ''.join(["['{attr}']".format(attr=attr)
for attr in attrs[1:]])
code += ''.join([data_assign, obj_name, attr_getters, ';\n'])
if obj_name not in ['cb_obj', 'cb_data']:
assign_str = '{assign}{{id: {obj_name}["id"], value: {obj_name}{attr_getters}}};\n'.format(
assign=data_assign, obj_name=obj_name, attr_getters=attr_getters
)
code += 'if (({obj_name} != undefined) && ({obj_name}["id"] == "{id}")) {{ {assign} }}'.format(
obj_name=obj_name, id=handles[obj_name].ref['id'], assign=assign_str
)
else:
assign_str = ''.join([data_assign, obj_name, attr_getters, ';\n'])
code += assign_str
return code


Expand Down Expand Up @@ -76,14 +85,20 @@ class Callback(object):
js_callback = """
function on_msg(msg){{
msg = JSON.parse(msg.content.data);
var comm = HoloViewsWidget.comms["{comms_target}"];
var comm_state = HoloViewsWidget.comm_state["{comms_target}"];
if ("comms_target" in msg) {{
comms_target = msg["comms_target"]
}} else {{
comms_target = "{comms_target}"
}}
var comm = HoloViewsWidget.comms[comms_target];
var comm_state = HoloViewsWidget.comm_state[comms_target];
if (comm_state.event) {{
comm.send(comm_state.event);
comm_state.blocked = true;
comm_state.timeout = Date.now()+{debounce};
}} else {{
comm_state.blocked = false;
}}
comm_state.timeout = Date.now();
comm_state.event = undefined;
if ((msg.msg_type == "Ready") && msg.content) {{
console.log("Python callback returned following output:", msg.content);
Expand All @@ -92,6 +107,7 @@ class Callback(object):
}}
}}

data['comms_target'] = "{comms_target}";
var argstring = JSON.stringify(data);
if ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel !== undefined)) {{
var comm_manager = Jupyter.notebook.kernel.comm_manager;
Expand Down Expand Up @@ -151,6 +167,7 @@ def __init__(self, plot, streams, source, **params):
self.plot = plot
self.streams = streams
self.comm = self._comm_type(plot, on_msg=self.on_msg)
self.stream_handles = defaultdict(list)
self.source = source


Expand All @@ -171,12 +188,23 @@ def initialize(self):


def on_msg(self, msg):
msg = json.loads(msg)
msg = self._process_msg(msg)
if any(v is None for v in msg.values()):
return
# For each stream check whether plot state is meant for it
# by checking that the IDs match the IDs of the stream's plot
# handles, dispatch only the part of the message meant for
# a particular stream
for stream in self.streams:
stream.update(trigger=False, **msg)
ids = self.stream_handles[stream]
sanitized_msg = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this bit sounds like a 'message filter' (as opposed to sanitization) and could be its own method (with docstring). Something like message_filter(msg, ids)...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sounds good.

for k, v in msg.items():
if isinstance(v, dict) and 'id' in v:
if v['id'] in ids:
sanitized_msg[k] = v['value']
else:
sanitized_msg[k] = v
processed_msg = self._process_msg(sanitized_msg)
if not processed_msg:
continue
stream.update(trigger=False, **processed_msg)
Stream.trigger(self.streams)


Expand All @@ -195,23 +223,38 @@ def set_customjs(self, handle):
self_callback = self.js_callback.format(comms_target=self.comm.target,
timeout=self.timeout,
debounce=self.debounce)
attributes = attributes_js(self.attributes)
code = 'var data = {};\n' + attributes + self.code + self_callback

handles = {}
subplots = list(self.plot.subplots.values())[::-1] if self.plot.subplots else []
plots = [self.plot] + subplots
for plot in plots:
handles.update({k: v for k, v in plot.handles.items()
if k in self.handles})

attributes = attributes_js(self.attributes, handles)
code = 'var data = {};\n' + attributes + self.code + self_callback

# Gather the ids of the plotting handles attached to this callback
# This allows checking that a stream is not given the state
# of a plotting handle it wasn't attached to
stream_handle_ids = defaultdict(list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to turn this bit of code into a method (get_handle_ids?) and turn the comment into a docstring.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

for stream in self.streams:
for h in self.handles:
if h in handles:
handle_id = handles[h].ref['id']
stream_handle_ids[stream].append(handle_id)

# Set callback
if id(handle.callback) in self._callbacks:
cb = self._callbacks[id(handle.callback)]
if isinstance(cb, type(self)):
cb.streams += self.streams
for k, v in stream_handle_ids.items():
cb.stream_handles[k] += v
else:
handle.callback.code += code
else:
self.stream_handles.update(stream_handle_ids)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the contents of self.stream_handles can keep growing between set_customjs calls with no way to reset/clear it? Is this more state that hangs around when a visualization is removed (e.g a notebook cell is deleted?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it'll stick around but compared to all the plotting state I'm not worried about a few IDs.

js_callback = CustomJS(args=handles, code=code)
self._callbacks[id(js_callback)] = self
handle.callback = js_callback
Expand Down Expand Up @@ -249,8 +292,12 @@ class RangeXYCallback(Callback):
handles = ['x_range', 'y_range']

def _process_msg(self, msg):
return {'x_range': (msg['x0'], msg['x1']),
'y_range': (msg['y0'], msg['y1'])}
data = {}
if 'x0' in msg and 'x1' in msg:
data['x_range'] = (msg['x0'], msg['x1'])
if 'y0' in msg and 'y1' in msg:
data['y_range'] = (msg['y0'], msg['y1'])
return data


class RangeXCallback(Callback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the new bits of code below seems to have the form:

if predicate(msg):
   return dictionary_from_msg(msg)
else:
  return {}

If you agree this is a general pattern, maybe we can just have an applicable predicate method with the baseclass checking the predicate value. E.g

class RangeXCallback(Callback):
      handles = ['x_range']

     def applicable(msg):
          return 'x0' in msg and 'x1' in msg
  
      def _process_msg(self, msg):
         return {'x_range': (msg['x0'], msg['x1'])}

And of course applicable can return True by default.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are multiple predicates, such as in RangeXY? I'd prefer not to complicate this for now, although I do agree something like this is worth considering.

Copy link
Contributor

@jlstevens jlstevens Jan 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RangeXY is really the union of RangeX and RangeY. We might want to consider generalizing this idea of a union so we can build things like RangeXY out of the component pieces.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you agree with this suggestion, maybe it should be made into a new issue (feature request)? I don't think it would be hard to implement later.

Expand All @@ -261,7 +308,10 @@ class RangeXCallback(Callback):
handles = ['x_range']

def _process_msg(self, msg):
return {'x_range': (msg['x0'], msg['x1'])}
if 'x0' in msg and 'x1' in msg:
return {'x_range': (msg['x0'], msg['x1'])}
else:
return {}


class RangeYCallback(Callback):
Expand All @@ -272,7 +322,10 @@ class RangeYCallback(Callback):
handles = ['y_range']

def _process_msg(self, msg):
return {'y_range': (msg['y0'], msg['y1'])}
if 'y0' in msg and 'y1' in msg:
return {'y_range': (msg['y0'], msg['y1'])}
else:
return {}


class BoundsCallback(Callback):
Expand All @@ -285,7 +338,10 @@ class BoundsCallback(Callback):
handles = ['box_select']

def _process_msg(self, msg):
return {'bounds': (msg['x0'], msg['y0'], msg['x1'], msg['y1'])}
if all(c in msg for c in ['x0', 'y0', 'x1', 'y1']):
return {'bounds': (msg['x0'], msg['y0'], msg['x1'], msg['y1'])}
else:
return {}


class Selection1DCallback(Callback):
Expand All @@ -295,7 +351,10 @@ class Selection1DCallback(Callback):
handles = ['source']

def _process_msg(self, msg):
return {'index': [int(v) for v in msg['index']]}
if 'index' in msg:
return {'index': [int(v) for v in msg['index']]}
else:
return {}


callbacks = Stream._callbacks['bokeh']
Expand Down
22 changes: 18 additions & 4 deletions holoviews/plotting/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,16 @@ def _handle_msg(self, msg):
if stdout:
stdout = '\n\t'+'\n\t'.join(stdout)
error = '\n'.join([stdout, error])
msg = {'msg_type': "Error", 'traceback': error}
reply = {'msg_type': "Error", 'traceback': error}
else:
stdout = '\n\t'+'\n\t'.join(stdout) if stdout else ''
msg = {'msg_type': "Ready", 'content': stdout}
self.comm.send(json.dumps(msg))
reply = {'msg_type': "Ready", 'content': stdout}

# Returning the comms_target in an ACK message ensures that
# the correct comms handle is unblocked
if 'comms_target' in msg:
reply['comms_target'] = msg.pop('comms_target', None)
self.send(json.dumps(reply))


class JupyterComm(Comm):
Expand Down Expand Up @@ -154,7 +159,16 @@ def init(self):

@classmethod
def decode(cls, msg):
return msg['content']['data']
"""
Decodes messages following Jupyter messaging protocol.
If JSON decoding fails data is assumed to be a regular string.
"""
data = msg['content']['data']
try:
data = json.loads(data)
except ValueError:
pass
return data


def send(self, data):
Expand Down
31 changes: 26 additions & 5 deletions tests/testcomms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from nose.tools import *

from holoviews.element.comparison import ComparisonTestCase
Expand All @@ -17,13 +18,33 @@ def test_decode(self):
msg = 'Test'
self.assertEqual(Comm.decode(msg), msg)

def test_on_msg(self):
def test_handle_message_error_reply(self):
def raise_error(msg):
if msg == 'Error':
raise Exception()
raise Exception('Test')
def assert_error(msg):
decoded = json.loads(msg)
self.assertEqual(decoded['msg_type'], "Error")
self.assertTrue(decoded['traceback'].endswith('Exception: Test'))
comm = Comm(None, target='Test', on_msg=raise_error)
with self.assertRaises(Exception):
comm._handle_msg('Error')
comm.send = assert_error
comm._handle_msg({})

def test_handle_message_ready_reply(self):
def assert_ready(msg):
self.assertEqual(json.loads(msg), {'msg_type': "Ready", 'content': ''})
comm = Comm(None, target='Test')
comm.send = assert_ready
comm._handle_msg({})

def test_handle_message_ready_reply_with_comms_target(self):
def assert_ready(msg):
decoded = json.loads(msg)
self.assertEqual(decoded, {'msg_type': "Ready", 'content': '',
'comms_target': 'Testing target'})
comm = Comm(None, target='Test')
comm.send = assert_ready
comm._handle_msg({'comms_target': 'Testing target'})



class TestJupyterComm(ComparisonTestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/testplotinstantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_stream_callback(self):
dmap = DynamicMap(lambda x, y: Points([(x, y)]), kdims=[], streams=[PositionXY()])
plot = bokeh_renderer.get_plot(dmap)
bokeh_renderer(plot)
plot.callbacks[0].on_msg('{"x": 10, "y": -10}')
plot.callbacks[0].on_msg({"x": 10, "y": -10})
data = plot.handles['source'].data
self.assertEqual(data['x'], np.array([10]))
self.assertEqual(data['y'], np.array([-10]))
Expand Down