Skip to content

Commit

Permalink
Move global origin registry to document (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jul 25, 2024
1 parent 25ae05d commit be2c7c4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
2 changes: 2 additions & 0 deletions python/pycrdt/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BaseDoc:
_txn: Transaction | None
_Model: Any
_subscriptions: list[Subscription]
_origins: dict[int, Any]

def __init__(
self,
Expand All @@ -40,6 +41,7 @@ def __init__(
self._txn = None
self._Model = Model
self._subscriptions = []
self._origins = {}


class BaseType(ABC):
Expand Down
16 changes: 7 additions & 9 deletions python/pycrdt/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
from ._doc import Doc


origins: dict[int, Any] = {}


class Transaction:
_doc: Doc
_txn: _Transaction | None
Expand All @@ -25,7 +22,7 @@ def __init__(self, doc: Doc, _txn: _Transaction | None = None, *, origin: Any =
self._origin = None
else:
self._origin = hash_origin(origin)
origins[self._origin] = origin
doc._origins[self._origin] = origin

def __enter__(self) -> Transaction:
self._nb += 1
Expand All @@ -47,9 +44,12 @@ def __exit__(
# only drop the transaction when exiting root context manager
# since nested transactions reuse the root transaction
if self._nb == 0:
# dropping the transaction will commit, no need to do it
# self._txn.commit()
assert self._txn is not None
if not isinstance(self, ReadTransaction):
self._txn.commit()
origin_hash = self._txn.origin()
if origin_hash is not None:
del self._doc._origins[origin_hash]
self._txn.drop()
self._txn = None
self._doc._txn = None
Expand All @@ -63,9 +63,7 @@ def origin(self) -> Any:
if origin_hash is None:
return None

origin = origins[origin_hash]
del origins[origin_hash]
return origin
return self._doc._origins[origin_hash]


class ReadTransaction(Transaction):
Expand Down
24 changes: 19 additions & 5 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def callback(event):


def test_origin():
doc = Doc()
doc["text"] = text = Text()
doc0 = Doc()
doc0["text"] = text = Text()

class Origin:
pass
Expand All @@ -57,7 +57,7 @@ def callback(event, txn):

text.observe(callback)

with doc.transaction(origin=origin0) as txn:
with doc0.transaction(origin=origin0) as txn:
text += "Hello"

assert origin1 is origin0
Expand All @@ -68,13 +68,27 @@ def callback(event, txn):
assert str(excinfo.value) == "No current transaction"

with pytest.raises(TypeError) as excinfo:
doc.transaction(origin={})
doc0.transaction(origin={})

assert str(excinfo.value) == "Origin must be hashable"

with doc.transaction() as txn:
with doc0.transaction() as txn:
assert txn.origin is None

doc1 = Doc()
with doc0.transaction(origin=origin0) as txn0:
with doc1.transaction(origin=origin0) as txn1:
assert txn0.origin == origin0
assert txn1.origin == origin0
assert len(doc0._origins) == 1
assert list(doc0._origins.values())[0] == origin0
assert doc0._origins == doc1._origins
assert len(doc0._origins) == 1
assert list(doc0._origins.values())[0] == origin0
assert len(doc1._origins) == 0
assert len(doc0._origins) == 0
assert len(doc1._origins) == 0


def test_observe_callback_params():
doc = Doc()
Expand Down

0 comments on commit be2c7c4

Please sign in to comment.