diff --git a/jupyter_ydoc/ynotebook.py b/jupyter_ydoc/ynotebook.py index 3a678f1..8099497 100644 --- a/jupyter_ydoc/ynotebook.py +++ b/jupyter_ydoc/ynotebook.py @@ -3,7 +3,7 @@ import copy from functools import partial -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 from pycrdt import Array, Awareness, Doc, Map, Text @@ -102,8 +102,11 @@ def get_cell(self, index: int) -> Dict[str, Any]: :return: A cell. :rtype: Dict[str, Any] """ + return self._cell_to_py(self._ycells[index]) + + def _cell_to_py(self, ycell: Map) -> Dict[str, Any]: meta = self._ymeta.to_py() - cell = self._ycells[index].to_py() + cell = ycell.to_py() cell.pop("execution_state", None) cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: @@ -234,7 +237,7 @@ def set(self, value: Dict) -> None: nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"} nb = copy.deepcopy(nb_without_cells) cast_all(nb, int, float) # Yjs expects numbers to be floating numbers - cells = value["cells"] or [ + new_cells = value["cells"] or [ { "cell_type": "code", "execution_count": None, @@ -245,26 +248,69 @@ def set(self, value: Dict) -> None: "id": str(uuid4()), } ] + old_ycells_by_id = {ycell["id"]: ycell for ycell in self._ycells} with self._ydoc.transaction(): - # clear document - self._ymeta.clear() - self._ycells.clear() + new_cell_list: List[dict] = [] + retained_cells = set() + + # Determine cells to be retained + for new_cell in new_cells: + cell_id = new_cell.get("id") + if cell_id and (old_ycell := old_ycells_by_id.get(cell_id)): + old_cell = self._cell_to_py(old_ycell) + if old_cell == new_cell: + new_cell_list.append(old_cell) + retained_cells.add(cell_id) + continue + # New or changed cell + new_cell_list.append(new_cell) + + # First delete all non-retained cells + if not retained_cells: + # fast path if no cells were retained + self._ycells.clear() + else: + index = 0 + for old_ycell in list(self._ycells): + if old_ycell["id"] not in retained_cells: + self._ycells.pop(index) + else: + index += 1 + + # Now add new cells + index = 0 + for new_cell in new_cell_list: + if len(self._ycells) > index: + if self._ycells[index]["id"] == new_cell.get("id"): + # retained cell + index += 1 + continue + self._ycells.insert(index, self.create_ycell(new_cell)) + index += 1 + for key in [ k for k in self._ystate.keys() if k not in ("dirty", "path", "document_id") ]: del self._ystate[key] - # initialize document - self._ycells.extend([self.create_ycell(cell) for cell in cells]) - self._ymeta["nbformat"] = nb.get("nbformat", NBFORMAT_MAJOR_VERSION) - self._ymeta["nbformat_minor"] = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION) + nbformat_major = nb.get("nbformat", NBFORMAT_MAJOR_VERSION) + nbformat_minor = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION) + + if self._ymeta.get("nbformat") != nbformat_major: + self._ymeta["nbformat"] = nbformat_major + + if self._ymeta.get("nbformat_minor") != nbformat_minor: + self._ymeta["nbformat_minor"] = nbformat_minor + old_y_metadata = self._ymeta.get("metadata") + old_metadata = old_y_metadata.to_py() if old_y_metadata else {} metadata = nb.get("metadata", {}) - metadata.setdefault("language_info", {"name": ""}) - metadata.setdefault("kernelspec", {"name": "", "display_name": ""}) - self._ymeta["metadata"] = Map(metadata) + if metadata != old_metadata: + metadata.setdefault("language_info", {"name": ""}) + metadata.setdefault("kernelspec", {"name": "", "display_name": ""}) + self._ymeta["metadata"] = Map(metadata) def observe(self, callback: Callable[[str, Any], None]) -> None: """ diff --git a/jupyter_ydoc/yunicode.py b/jupyter_ydoc/yunicode.py index 50d790c..d6f3914 100644 --- a/jupyter_ydoc/yunicode.py +++ b/jupyter_ydoc/yunicode.py @@ -63,6 +63,10 @@ def set(self, value: str) -> None: :param value: The content of the document. :type value: str """ + if self.get() == value: + # no-op if the values are already the same, + # to avoid side-effects such as cursor jumping to the top + return with self._ydoc.transaction(): # clear document self._ysource.clear() diff --git a/tests/test_ynotebook.py b/tests/test_ynotebook.py new file mode 100644 index 0000000..282133e --- /dev/null +++ b/tests/test_ynotebook.py @@ -0,0 +1,104 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from pycrdt import Map + +from jupyter_ydoc import YNotebook + + +def make_code_cell(source: str): + return { + "cell_type": "code", + "source": source, + "metadata": {}, + "outputs": [], + "execution_count": None, + } + + +class AnyInstanceOf: + def __init__(self, cls): + self.cls = cls + + def __eq__(self, other): + return isinstance(other, self.cls) + + +def test_set_preserves_cells_when_unchanged(): + nb = YNotebook() + nb.set({"cells": [make_code_cell("print('a')\n"), make_code_cell("print('b')\n")]}) + + changes = [] + + def record_changes(topic, event): + changes.append((topic, event)) + + nb.observe(record_changes) + + model = nb.get() + + # Call set with identical structure + nb.set(model) + + # No changes should be observed at all + assert changes == [] + + +def test_set_preserves_cells_with_insert_and_remove(): + nb = YNotebook() + nb.set( + { + "cells": [ + make_code_cell("print('a')\n"), # original 0 + make_code_cell("print('b')\n"), # original 1 (will remove) + make_code_cell("print('c')\n"), # original 2 + ] + } + ) + + # Capture textual content for sanity check + cell0_source_text = str(nb.ycells[0]["source"]) + cell2_source_text = str(nb.ycells[2]["source"]) + + # Get the model as Python object + model = nb.get() + + # Remove the middle cell and insert a new one between the retained cells + cells = model["cells"] + assert len(cells) == 3 + + # The cell ids are needed for retention logic; keep first and last + first = cells[0] + last = cells[2] + + # New inserted cell + inserted = make_code_cell("print('x')\n") + model["cells"] = [first, inserted, last] + + changes = [] + + def record_changes(topic, event): + changes.append((topic, event)) + + nb.observe(record_changes) + nb.set(model) + + assert nb.cell_number == 3 + + # Content of the first and last cells should remain the same + assert str(nb.ycells[0]["source"]) == cell0_source_text + assert str(nb.ycells[2]["source"]) == cell2_source_text + + # The middle cell should have a different source now + assert str(nb.ycells[1]["source"]) == "print('x')\n" + + # We should have one cell event + cell_events = [e for t, e in changes if t == "cells"] + assert len(cell_events) == 1 + event_transactions = cell_events[0] + assert len(event_transactions) == 1 + assert event_transactions[0].delta == [ + {"retain": 1}, + {"delete": 1}, + {"insert": [AnyInstanceOf(Map)]}, + ] diff --git a/tests/test_yunicode.py b/tests/test_yunicode.py new file mode 100644 index 0000000..096436e --- /dev/null +++ b/tests/test_yunicode.py @@ -0,0 +1,24 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from jupyter_ydoc import YUnicode + + +def test_set_no_op_if_unchaged(): + text = YUnicode() + text.set("test content") + + changes = [] + + def record_changes(topic, event): + changes.append((topic, event)) + + text.observe(record_changes) + + model = text.get() + + # Call set with identical text + text.set(model) + + # No changes should be observed at all + assert changes == []