Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions jupyter_ydoc/ynotebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# The default minor version of the notebook format.
NBFORMAT_MINOR_VERSION = 5

_CELL_KEY_TYPE_MAP = {"metadata": Map, "source": Text, "outputs": Array}


class YNotebook(YBaseDoc):
"""
Expand Down Expand Up @@ -249,7 +251,7 @@ def set(self, value: dict) -> None:
"id": str(uuid4()),
}
]
old_ycells_by_id = {ycell["id"]: ycell for ycell in self._ycells}
old_ycells_by_id: dict[str, Map] = {ycell["id"]: ycell for ycell in self._ycells}

with self._ydoc.transaction():
new_cell_list: list[dict] = []
Expand All @@ -260,7 +262,11 @@ def set(self, value: dict) -> None:
cell_id = new_cell.get("id")
if cell_id and (old_ycell := old_ycells_by_id.get(cell_id)):
old_cell = self._cell_to_py(old_ycell)
if old_cell == new_cell:
updated_granularly = self._update_cell(
old_cell=old_cell, new_cell=new_cell, old_ycell=old_ycell
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just pass old_ycell, since old_cell can be created from it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about doing that, but we are still using old_cell in line 270 below

)

if updated_granularly:
new_cell_list.append(old_cell)
retained_cells.add(cell_id)
continue
Expand Down Expand Up @@ -324,3 +330,57 @@ def observe(self, callback: Callable[[str, Any], None]) -> None:
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state"))
self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta"))
self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells"))

def _update_cell(self, old_cell: dict, new_cell: dict, old_ycell: Map) -> bool:
if old_cell == new_cell:
return True
# attempt to update cell granularly
old_keys = set(old_cell.keys())
new_keys = set(new_cell.keys())

shared_keys = old_keys & new_keys
removed_keys = old_keys - new_keys
added_keys = new_keys - old_keys

for key in shared_keys:
if old_cell[key] != new_cell[key]:
value = new_cell[key]
if key == "output" and value:
# outputs require complex handling - some have Text type nested;
# for now skip creating them; clearing all outputs is fine
return False

if key in _CELL_KEY_TYPE_MAP:
kind = _CELL_KEY_TYPE_MAP[key]

if not isinstance(old_ycell[key], kind):
# if our assumptions about types do not hold, fall back to hard update
return False

if kind == Text:
old: Text = old_ycell[key]
old.clear()
old += value
elif kind == Array:
old: Array = old_ycell[key]
old.clear()
old.extend(value)
elif kind == Map:
old: Map = old_ycell[key]
old.clear()
old.update(value)
else:
old_ycell[key] = new_cell[key]

for key in removed_keys:
del old_ycell[key]

for key in added_keys:
if key in _CELL_KEY_TYPE_MAP:
# we hard-reload cells when keys that require nested types get added
# to allow the frontend to connect observers; this could be changed
# in the future, once frontends learn how to observe all changes
return False
else:
old_ycell[key] = new_cell[key]
return True
67 changes: 66 additions & 1 deletion tests/test_ynotebook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

from pycrdt import Map

from pycrdt import ArrayEvent, Map, MapEvent, TextEvent
from pytest import mark

from jupyter_ydoc import YNotebook

Expand Down Expand Up @@ -114,3 +116,66 @@ def record_changes(topic, event):
{"delete": 1},
{"insert": [AnyInstanceOf(Map)]},
]


@mark.parametrize(
"modifications, expected_events",
[
# modifications of single attributes
([["source", "'b'"]], {TextEvent}),
([["outputs", []]], {ArrayEvent}),
([["execution_count", 2]], {MapEvent}),
([["metadata", {"tags": []}]], {MapEvent}),
([["new_key", "test"]], {MapEvent}),
# multi-attribute modifications
([["source", "10"], ["execution_count", 10]], {TextEvent, MapEvent}),
],
)
def test_modify_single_cell(modifications, expected_events):
nb = YNotebook()
nb.set(
{
"cells": [
{
"id": "8800f7d8-6cad-42ef-a339-a9c185ffdd54",
"cell_type": "code",
"source": "'a'",
"metadata": {"tags": ["test-tag"]},
"outputs": [{"name": "stdout", "output_type": "stream", "text": ["a\n"]}],
"execution_count": 1,
},
]
}
)

# Get the model as Python object
model = nb.get()

# Make changes
for modification in modifications:
key, new_value = modification
model["cells"][0][key] = new_value

changes = []

def record_changes(topic, event):
changes.append((topic, event))

nb.observe(record_changes)
nb.set(model)

for modification in modifications:
key, new_value = modification
after = nb.ycells[0][key]
after_py = after.to_py() if hasattr(after, "to_py") else after
assert after_py == new_value

# there should be only one change
assert len(changes) == 1
cell_events = [e for t, e in changes if t == "cells"]
# and it should be a cell change
assert len(cell_events) == 1
# but it should be a change to cell data, not a change to the cell list
events = cell_events[0]
assert len(events) == len(expected_events)
assert {type(e) for e in events} == expected_events
Loading