diff --git a/jupyter_ydoc/ynotebook.py b/jupyter_ydoc/ynotebook.py index 9067740..b6886a7 100644 --- a/jupyter_ydoc/ynotebook.py +++ b/jupyter_ydoc/ynotebook.py @@ -17,6 +17,8 @@ # The default minor version of the notebook format. NBFORMAT_MINOR_VERSION = 5 +_CELL_KEY_TYPE_MAP = {"metadata": Map, "source": Text, "outputs": Array} + class YNotebook(YBaseDoc): """ @@ -249,7 +251,7 @@ def set(self, value: dict) -> None: "id": str(uuid4()), } ] - old_ycells_by_id = {ycell["id"]: ycell for ycell in self._ycells} + old_ycells_by_id: dict[str, Map] = {ycell["id"]: ycell for ycell in self._ycells} with self._ydoc.transaction(): new_cell_list: list[dict] = [] @@ -260,7 +262,11 @@ def set(self, value: dict) -> None: cell_id = new_cell.get("id") if cell_id and (old_ycell := old_ycells_by_id.get(cell_id)): old_cell = self._cell_to_py(old_ycell) - if old_cell == new_cell: + updated_granularly = self._update_cell( + old_cell=old_cell, new_cell=new_cell, old_ycell=old_ycell + ) + + if updated_granularly: new_cell_list.append(old_cell) retained_cells.add(cell_id) continue @@ -324,3 +330,57 @@ def observe(self, callback: Callable[[str, Any], None]) -> None: self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta")) self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells")) + + def _update_cell(self, old_cell: dict, new_cell: dict, old_ycell: Map) -> bool: + if old_cell == new_cell: + return True + # attempt to update cell granularly + old_keys = set(old_cell.keys()) + new_keys = set(new_cell.keys()) + + shared_keys = old_keys & new_keys + removed_keys = old_keys - new_keys + added_keys = new_keys - old_keys + + for key in shared_keys: + if old_cell[key] != new_cell[key]: + value = new_cell[key] + if key == "output" and value: + # outputs require complex handling - some have Text type nested; + # for now skip creating them; clearing all outputs is fine + return False + + if key in _CELL_KEY_TYPE_MAP: + kind = _CELL_KEY_TYPE_MAP[key] + + if not isinstance(old_ycell[key], kind): + # if our assumptions about types do not hold, fall back to hard update + return False + + if kind == Text: + old: Text = old_ycell[key] + old.clear() + old += value + elif kind == Array: + old: Array = old_ycell[key] + old.clear() + old.extend(value) + elif kind == Map: + old: Map = old_ycell[key] + old.clear() + old.update(value) + else: + old_ycell[key] = new_cell[key] + + for key in removed_keys: + del old_ycell[key] + + for key in added_keys: + if key in _CELL_KEY_TYPE_MAP: + # we hard-reload cells when keys that require nested types get added + # to allow the frontend to connect observers; this could be changed + # in the future, once frontends learn how to observe all changes + return False + else: + old_ycell[key] = new_cell[key] + return True diff --git a/tests/test_ynotebook.py b/tests/test_ynotebook.py index f7619a1..868e9c1 100644 --- a/tests/test_ynotebook.py +++ b/tests/test_ynotebook.py @@ -1,7 +1,9 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from pycrdt import Map + +from pycrdt import ArrayEvent, Map, MapEvent, TextEvent +from pytest import mark from jupyter_ydoc import YNotebook @@ -114,3 +116,66 @@ def record_changes(topic, event): {"delete": 1}, {"insert": [AnyInstanceOf(Map)]}, ] + + +@mark.parametrize( + "modifications, expected_events", + [ + # modifications of single attributes + ([["source", "'b'"]], {TextEvent}), + ([["outputs", []]], {ArrayEvent}), + ([["execution_count", 2]], {MapEvent}), + ([["metadata", {"tags": []}]], {MapEvent}), + ([["new_key", "test"]], {MapEvent}), + # multi-attribute modifications + ([["source", "10"], ["execution_count", 10]], {TextEvent, MapEvent}), + ], +) +def test_modify_single_cell(modifications, expected_events): + nb = YNotebook() + nb.set( + { + "cells": [ + { + "id": "8800f7d8-6cad-42ef-a339-a9c185ffdd54", + "cell_type": "code", + "source": "'a'", + "metadata": {"tags": ["test-tag"]}, + "outputs": [{"name": "stdout", "output_type": "stream", "text": ["a\n"]}], + "execution_count": 1, + }, + ] + } + ) + + # Get the model as Python object + model = nb.get() + + # Make changes + for modification in modifications: + key, new_value = modification + model["cells"][0][key] = new_value + + changes = [] + + def record_changes(topic, event): + changes.append((topic, event)) + + nb.observe(record_changes) + nb.set(model) + + for modification in modifications: + key, new_value = modification + after = nb.ycells[0][key] + after_py = after.to_py() if hasattr(after, "to_py") else after + assert after_py == new_value + + # there should be only one change + assert len(changes) == 1 + cell_events = [e for t, e in changes if t == "cells"] + # and it should be a cell change + assert len(cell_events) == 1 + # but it should be a change to cell data, not a change to the cell list + events = cell_events[0] + assert len(events) == len(expected_events) + assert {type(e) for e in events} == expected_events