From e5aa1b0839937cf3f96bd7b932537a916880ec84 Mon Sep 17 00:00:00 2001 From: Jamie Alexandre Date: Sun, 16 Dec 2018 20:23:12 -0800 Subject: [PATCH] Fire callbacks when rows in a collection are modified. --- README.md | 1 + notion/block.py | 2 +- notion/client.py | 6 +++++- notion/collection.py | 18 ++++++++++++++++-- notion/monitor.py | 7 +++---- notion/store.py | 27 ++++++++++++++++++++++++--- 6 files changed, 50 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e161ee0..d2881e6 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,7 @@ You can also see [more examples in action in the smoke test runner](https://gith # TODO * Cloning pages hierarchically +* Debounce cache-saving? * Support inline "user" and "page" links, and reminders, in markdown conversion * Utilities to support updating/creating collection schemas * Utilities to support updating/creating collection_view queries diff --git a/notion/block.py b/notion/block.py index a18e889..1ab1727 100644 --- a/notion/block.py +++ b/notion/block.py @@ -243,7 +243,7 @@ def _convert_diff_to_changelist(self, difference, old_val, new_val): content_changed = True continue - # check whether the value changed matches on of our mapped fields/properties + # check whether the value changed matches one of our mapped fields/properties fields = [(name, field) for name, field in mappers.items() if path.startswith(field.path)] if fields: changed_fields.add(fields[0]) diff --git a/notion/client.py b/notion/client.py index f4df0a2..9a0ff63 100644 --- a/notion/client.py +++ b/notion/client.py @@ -111,6 +111,10 @@ def refresh_records(self, **kwargs): """ self._store.call_get_record_values(**kwargs) + def refresh_collection_rows(self, collection_id): + row_ids = self.search_pages_with_parent(collection_id) + self._store.set_collection_rows(collection_id, row_ids) + def post(self, endpoint, data): """ All API requests on Notion.so are done as POSTs (except the websocket communications). @@ -169,7 +173,7 @@ def search_pages_with_parent(self, parent_id, search=""): self._store.store_recordmap(response["recordMap"]) - return [self.get_block(page_id) for page_id in response["results"]] + return response["results"] def create_record(self, table, parent, **kwargs): diff --git a/notion/collection.py b/notion/collection.py index ff4cdce..a68bce9 100644 --- a/notion/collection.py +++ b/notion/collection.py @@ -52,9 +52,23 @@ def add_row(self): return CollectionRowBlock(self._client, row_id) - def get_rows(self, search=""): + def get_rows(self): - return self._client.search_pages_with_parent(self.id, search=search) + return [self._client.get_block(row_id) for row_id in self._client._store.get_collection_rows(self.id)] + + def _convert_diff_to_changelist(self, difference, old_val, new_val): + + changes = [] + remaining = [] + + for operation, path, values in difference: + + if path == "rows": + changes.append((operation, path, values)) + else: + remaining.append((operation, path, values)) + + return changes + super()._convert_diff_to_changelist(remaining, old_val, new_val) class CollectionView(Record): diff --git a/notion/monitor.py b/notion/monitor.py index 98875c1..deb9958 100644 --- a/notion/monitor.py +++ b/notion/monitor.py @@ -173,11 +173,10 @@ def _refresh_updated_records(self, events): collection_id = match.groups()[0] - collection = self.client.get_collection(collection_id) + self.client.refresh_collection_rows(collection_id) + row_ids = self.client._store.get_collection_rows(collection_id) - row_ids = [row.id for row in collection.get_rows()] - - logger.debug("Something inside {} has changed; refreshing all {} rows inside it".format(collection, len(row_ids))) + logger.debug("Something inside collection {} has changed; refreshing all {} rows inside it".format(collection_id, len(row_ids))) records_to_refresh["block"] += row_ids diff --git a/notion/store.py b/notion/store.py index 31cbdd8..72215bb 100644 --- a/notion/store.py +++ b/notion/store.py @@ -74,6 +74,7 @@ def __init__(self, client, cache_key=None): self._cache_key = cache_key or str(int(datetime.datetime.now().timestamp() * 1000)) self._values = defaultdict(lambda: defaultdict(dict)) self._role = defaultdict(lambda: defaultdict(str)) + self._collection_row_ids = {} self._callbacks = defaultdict(lambda: defaultdict(list)) self._records_to_refresh = {} self._pages_to_refresh = [] @@ -103,15 +104,35 @@ def remove_callbacks(self, table, id, callback_or_callback_id_prefix=""): def _get_cache_path(self, attribute): return str(Path(CACHE_DIR).joinpath("{}{}.json".format(self._cache_key, attribute))) - def _load_cache(self, attributes=("_values", "_role")): + def _load_cache(self, attributes=("_values", "_role", "_collection_row_ids")): for attr in attributes: try: with open(self._get_cache_path(attr)) as f: - for k, v in json.load(f).items(): - getattr(self, attr)[k].update(v) + if attr == "_collection_row_ids": + self._collection_row_ids.update(json.load(f)) + else: + for k, v in json.load(f).items(): + getattr(self, attr)[k].update(v) except FileNotFoundError: pass + def set_collection_rows(self, collection_id, row_ids): + + if collection_id in self._collection_row_ids: + old_ids = set(self._collection_row_ids[collection_id]) + new_ids = set(row_ids) + added = new_ids - old_ids + removed = old_ids - new_ids + for id in added: + self._trigger_callbacks("collection", collection_id, [("row_added", "rows", id)], old_ids, new_ids) + for id in removed: + self._trigger_callbacks("collection", collection_id, [("row_removed", "rows", id)], old_ids, new_ids) + self._collection_row_ids[collection_id] = row_ids + self._save_cache("_collection_row_ids") + + def get_collection_rows(self, collection_id): + return self._collection_row_ids.get(collection_id, []) + def _save_cache(self, attribute): with open(self._get_cache_path(attribute), "w") as f: json.dump(getattr(self, attribute), f)