Skip to content

Commit

Permalink
Store commit sequence aggregrate was built from with aggregate for op…
Browse files Browse the repository at this point in the history
…timistic concurrency
  • Loading branch information
jjrdk committed Apr 22, 2024
1 parent 96ed171 commit 8d3ec42
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 72 deletions.
38 changes: 25 additions & 13 deletions aett_domain/src/aett/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,17 @@ class Aggregate(ABC, typing.Generic[T]):
of the event relies on multiple dispatch to call the correct apply method in the subclass.
"""

def __init__(self, stream_id: str):
def __init__(self, stream_id: str, commit_sequence: int):
"""
Initialize the aggregate
param stream_id: The id of the stream
param commit_sequence: The commit sequence number which the aggregate was built from
"""
self.uncommitted: typing.List[EventMessage] = []
self._id = stream_id
self._version = 0
self._commit_sequence = commit_sequence
self.uncommitted.clear()

@property
Expand All @@ -26,6 +33,10 @@ def id(self) -> str:
def version(self) -> int:
return self._version

@property
def commit_sequence(self):
return self._commit_sequence

@abstractmethod
def apply_memento(self, memento: T) -> None:
"""
Expand Down Expand Up @@ -172,12 +183,12 @@ def get(self, cls: typing.Type[TAggregate], stream_id: str, max_version: int = 2
min_version = 0
if snapshot is not None:
min_version = snapshot.stream_revision
commits = self._store.get(tenant_id=self._tenant_id,
stream_id=stream_id,
min_revision=min_version,
max_revision=max_version)

aggregate = cls(stream_id)
commits = list(self._store.get(tenant_id=self._tenant_id,
stream_id=stream_id,
min_revision=min_version,
max_revision=max_version))
commit_sequence = commits[-1].commit_sequence if len(commits) > 0 else 0
aggregate = cls(stream_id, commit_sequence)
if snapshot is not None:
aggregate.apply_memento(memento_type(**jsonpickle.decode(snapshot.payload)))
for commit in commits:
Expand All @@ -188,10 +199,11 @@ def get(self, cls: typing.Type[TAggregate], stream_id: str, max_version: int = 2

def get_to(self, cls: typing.Type[TAggregate], stream_id: str,
max_time: datetime = datetime.datetime.max) -> TAggregate:
commits = self._store.get_to(tenant_id=self._tenant_id,
stream_id=stream_id,
max_time=max_time)
aggregate = cls(stream_id)
commits = list(self._store.get_to(tenant_id=self._tenant_id,
stream_id=stream_id,
max_time=max_time))
commit_sequence = commits[-1].commit_sequence if len(commits) > 0 else 0
aggregate = cls(stream_id, commit_sequence)
for commit in commits:
for event in commit.events:
aggregate.raise_event(event.body)
Expand All @@ -214,7 +226,7 @@ def save(self, aggregate: TAggregate, headers: Dict[str, str] = None) -> None:
stream_id=aggregate.id,
stream_revision=aggregate.version,
commit_id=uuid.uuid4(),
commit_sequence=0,
commit_sequence=aggregate.commit_sequence + 1,
commit_stamp=datetime.datetime.now(datetime.UTC),
headers=dict(headers),
events=list(aggregate.uncommitted),
Expand Down Expand Up @@ -313,4 +325,4 @@ def __init__(self, message: str):

class NonConflictingCommitException(Exception):
def __init__(self, message: str):
super().__init__(message)
super().__init__(message)
2 changes: 1 addition & 1 deletion aett_domain/tests/features/steps/AggregateRepository.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, storage: {}):
def get(self, cls: typing.Type[TAggregate], identifier: str, version: int = MAX_INT) -> TestAggregate:
memento_type: typing.Union = inspect.signature(cls.apply_memento).parameters['memento'].annotation
m = self.storage.get(identifier)
agg = cls(identifier)
agg = cls(identifier, 0)
agg.apply_memento(m)
return agg

Expand Down
2 changes: 1 addition & 1 deletion aett_domain/tests/features/steps/Aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def step_impl(context):

@given("an aggregate")
def step_impl(context):
agg = TestAggregate('test')
agg = TestAggregate('test', 0)
context.aggregate = agg


Expand Down
4 changes: 2 additions & 2 deletions aett_domain/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class TestMemento(Memento):


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, stream_id: str):
def __init__(self, stream_id: str, commit_sequence: int):
self.value = 0
super().__init__(stream_id=stream_id)
super().__init__(stream_id=stream_id, commit_sequence=commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if memento is None:
Expand Down
15 changes: 1 addition & 14 deletions aett_dynamodb/src/aett/dynamodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,13 @@ def _item_to_commit(self, item: dict) -> Commit:

def commit(self, commit: Commit):
try:
query_response = self.table.query(
TableName=self._table_name,
IndexName="RevisionIndex",
ConsistentRead=True,
Limit=1,
ProjectionExpression='CommitSequence,StreamRevision',
KeyConditionExpression=(Key("TenantAndStream").eq(f'{commit.tenant_id}{commit.stream_id}')),
ScanIndexForward=False)
items = query_response['Items']
commit_sequence = int(items[0]['CommitSequence']) if len(items) > 0 else 0
last_revision = int(items[0]['StreamRevision']) if len(items) > 0 else 0
if 0 < last_revision != commit.stream_revision - len(commit.events):
self._raise_conflict(commit)
item = {
'TenantAndStream': f'{commit.tenant_id}{commit.stream_id}',
'TenantId': commit.tenant_id,
'StreamId': commit.stream_id,
'StreamRevision': commit.stream_revision,
'CommitId': str(commit.commit_id),
'CommitSequence': commit_sequence + 1,
'CommitSequence': commit.commit_sequence,
'CommitStamp': int(commit.commit_stamp.timestamp()),
'Headers': jsonpickle.encode(commit.headers, unpicklable=False),
'Events': jsonpickle.encode([e.to_json() for e in commit.events], unpicklable=False)
Expand Down
4 changes: 2 additions & 2 deletions aett_dynamodb/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def detect(self, uncommitted: TestEvent, committed: TestEvent) -> bool:


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, stream_id: str):
def __init__(self, stream_id: str, commit_sequence: int):
self.value = 0
super().__init__(stream_id=stream_id)
super().__init__(stream_id=stream_id, commit_sequence=commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if self.id != memento.id:
Expand Down
4 changes: 2 additions & 2 deletions aett_inmemory/src/aett/inmemory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def get_all_to(self, tenant_id: str, max_time: datetime.datetime = datetime.date
def commit(self, commit: Commit):
self._ensure_stream(commit.tenant_id, commit.stream_id)
existing = self._buckets[commit.tenant_id][commit.stream_id]
if len(existing) > 0 and existing[-1].stream_revision >= commit.stream_revision:
if len(existing) > 0 and existing[-1].commit_sequence >= commit.commit_sequence:
if existing[-1].commit_id == commit.commit_id:
raise DuplicateCommitException('Duplicate commit')
commits = [e for c in (c.events for c in existing if c.stream_revision >= commit.stream_revision) for e in
commits = [e for c in (c.events for c in existing if c.commit_sequence >= commit.commit_sequence) for e in
c]
if self._conflict_detector.conflicts_with(list(map(self._get_body, commit.events)),
list(map(self._get_body, commits))):
Expand Down
4 changes: 2 additions & 2 deletions aett_inmemory/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def detect(self, uncommitted: TestEvent, committed: TestEvent) -> bool:


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, stream_id):
def __init__(self, stream_id, commit_sequence):
self.value = 0
super().__init__(stream_id)
super().__init__(stream_id, commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if self.id != memento.id:
Expand Down
13 changes: 7 additions & 6 deletions aett_mongo/src/aett/mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _doc_to_commit(self, doc: dict) -> Commit:
stream_id=doc['StreamId'],
stream_revision=int(doc['StreamRevision']),
commit_id=UUID(doc['CommitId']),
commit_sequence=0,
commit_sequence=int(doc['CommitSequence']),
commit_stamp=datetime.datetime.fromtimestamp(int(doc['CommitStamp']), datetime.UTC),
headers=jsonpickle.decode(doc['Headers']),
events=[EventMessage.from_json(e, self._topic_map) for e in jsonpickle.decode(doc['Events'])],
Expand All @@ -72,6 +72,7 @@ def commit(self, commit: Commit):
'StreamId': commit.stream_id,
'StreamRevision': commit.stream_revision,
'CommitId': str(commit.commit_id),
'CommitSequence': commit.commit_sequence,
'CommitStamp': int(datetime.datetime.now(datetime.UTC).timestamp()),
'Headers': jsonpickle.encode(commit.headers, unpicklable=False),
'Events': jsonpickle.encode([e.to_json() for e in commit.events], unpicklable=False),
Expand All @@ -81,7 +82,7 @@ def commit(self, commit: Commit):
except Exception as e:
if isinstance(e, pymongo.errors.DuplicateKeyError):
if self._detect_duplicate(commit.commit_id, commit.tenant_id, commit.stream_id,
commit.stream_revision):
commit.commit_sequence):
raise Exception(
f"Commit {commit.commit_id} already exists in stream {commit.stream_id}")
else:
Expand All @@ -96,15 +97,15 @@ def commit(self, commit: Commit):
raise Exception(
f"Failed to commit event to stream {commit.stream_id} with status code {e.response['ResponseMetadata']['HTTPStatusCode']}")

def _detect_duplicate(self, commit_id: UUID, tenant_id: str, stream_id: str, stream_revision: int) -> bool:
def _detect_duplicate(self, commit_id: UUID, tenant_id: str, stream_id: str, commit_sequence: int) -> bool:
duplicate_check = self._collection.find_one(
{'TenantId': tenant_id, 'StreamId': stream_id, 'StreamRevision': stream_revision})
{'TenantId': tenant_id, 'StreamId': stream_id, 'CommitSequence': commit_sequence})
s = str(duplicate_check.get('CommitId'))
return s == str(commit_id)

def _detect_conflicts(self, commit: Commit) -> (bool, int):
filters = {"TenantId": commit.tenant_id, "StreamId": commit.stream_id,
"StreamRevision": {'$lte': commit.stream_revision}}
"CommitSequence": {'$lte': commit.commit_sequence}}
query_response: pymongo.cursor.Cursor = \
self._collection.find({'$and': [filters]}).sort('CheckpointToken',
direction=pymongo.ASCENDING)
Expand Down Expand Up @@ -186,7 +187,7 @@ def initialize(self):
commits_collection.create_index([("TenantId", pymongo.ASCENDING), ("StreamId", pymongo.ASCENDING),
("StreamRevision", pymongo.ASCENDING)], comment="GetFrom", unique=True)
commits_collection.create_index([("TenantId", pymongo.ASCENDING), ("StreamId", pymongo.ASCENDING),
("StreamRevision", pymongo.ASCENDING)], comment="LogicalKey", unique=True)
("CommitSequence", pymongo.ASCENDING)], comment="LogicalKey", unique=True)
commits_collection.create_index([("CommitStamp", pymongo.ASCENDING)], comment="CommitStamp", unique=False)
commits_collection.create_index([("TenantId", pymongo.ASCENDING), ("StreamId", pymongo.ASCENDING),
("CommitId", pymongo.ASCENDING)], comment="CommitId", unique=True)
Expand Down
4 changes: 2 additions & 2 deletions aett_mongo/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def detect(self, uncommitted: TestEvent, committed: TestEvent) -> bool:


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, stream_id):
def __init__(self, stream_id, commit_sequence):
self.value = 0
super().__init__(stream_id)
super().__init__(stream_id, commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if self.id != memento.id:
Expand Down
11 changes: 2 additions & 9 deletions aett_postgres/src/aett/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,13 @@ def _item_to_commit(self, item):

def commit(self, commit: Commit):
try:
commit_seq_cur: psycopg.Cursor = self._connection.cursor()
commit_seq_cur.execute(
f"""SELECT MAX(CommitSequence) FROM {self._table_name} WHERE TenantId = %s AND StreamId = %s;""",
(commit.tenant_id, commit.stream_id))
commit_sequence = commit_seq_cur.fetchone()
commit_sequence = 0 if commit_sequence[0] is None else int(commit_sequence[0])
commit_seq_cur.close()
cur = self._connection.cursor()
cur.execute(f"""INSERT
INTO {self._table_name}
( TenantId, StreamId, StreamIdOriginal, CommitId, CommitSequence, StreamRevision, Items, CommitStamp, Headers, Payload )
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING CheckpointNumber;""", (commit.tenant_id, commit.stream_id, commit.stream_id,
commit.commit_id, commit_sequence + 1, commit.stream_revision, len(commit.events),
commit.commit_id, commit.commit_sequence, commit.stream_revision, len(commit.events),
commit.commit_stamp,
jsonpickle.encode(commit.headers, unpicklable=False).encode('utf-8'),
jsonpickle.encode([e.to_json() for e in commit.events], unpicklable=False).encode(
Expand All @@ -101,7 +94,7 @@ def commit(self, commit: Commit):
stream_id=commit.stream_id,
stream_revision=commit.stream_revision,
commit_id=commit.commit_id,
commit_sequence=commit_sequence + 1,
commit_sequence=commit.commit_sequence,
commit_stamp=commit.commit_stamp,
headers=commit.headers,
events=commit.events,
Expand Down
4 changes: 2 additions & 2 deletions aett_postgres/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def detect(self, uncommitted: TestEvent, committed: TestEvent) -> bool:


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, event_stream: EventStream, memento: TestMemento = None):
def __init__(self, stream_id, commit_sequence):
self.value = 0
super().__init__(event_stream, memento)
super().__init__(stream_id=stream_id, commit_sequence=commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if self.id != memento.id:
Expand Down
16 changes: 2 additions & 14 deletions aett_s3/src/aett/s3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def _file_to_commit(self, key: str):
checkpoint_token=0)

def commit(self, commit: Commit):
commit_sequence = self._get_commit_sequence(commit) + 1
self.check_exists(commit_sequence=commit_sequence, commit=commit)
commit_key = f'{self._folder_name}/{commit.tenant_id}/{commit.stream_id}/{int(commit.commit_stamp.timestamp())}_{commit.commit_id}_{commit_sequence}_{commit.stream_revision}.json'
self.check_exists(commit_sequence=commit.commit_sequence, commit=commit)
commit_key = f'{self._folder_name}/{commit.tenant_id}/{commit.stream_id}/{int(commit.commit_stamp.timestamp())}_{commit.commit_id}_{commit.commit_sequence}_{commit.stream_revision}.json'
d = commit.__dict__
d['events'] = [e.to_json() for e in commit.events]
d['headers'] = {k: jsonpickle.encode(v, unpicklable=False) for k, v in commit.headers.items()}
Expand Down Expand Up @@ -138,17 +137,6 @@ def check_exists(self, commit_sequence: int, commit: Commit):
def _get_body(em: EventMessage):
return em.body

def _get_commit_sequence(self, commit: Commit):
response = self._resource.list_objects(
Delimiter='/',
Prefix=f'{self._folder_name}/{commit.tenant_id}/{commit.stream_id}/',
Bucket=self._s3_bucket)
if 'Contents' not in response:
return 0
keys = list(key for key in map(lambda r: r.get('Key'), response.get('Contents')))
keys.sort(reverse=True)
return int(keys[0].split('_')[-2])


class SnapshotStore(IAccessSnapshots):
def __init__(self, s3_config: S3Config, folder_name: str = 'snapshots'):
Expand Down
4 changes: 2 additions & 2 deletions aett_s3/tests/features/steps/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def detect(self, uncommitted: TestEvent, committed: TestEvent) -> bool:


class TestAggregate(Aggregate[TestMemento]):
def __init__(self, event_stream: EventStream, memento: TestMemento = None):
def __init__(self, stream_id, commit_sequence):
self.value = 0
super().__init__(event_stream, memento)
super().__init__(stream_id=stream_id, commit_sequence=commit_sequence)

def apply_memento(self, memento: TestMemento) -> None:
if self.id != memento.id:
Expand Down

0 comments on commit 8d3ec42

Please sign in to comment.