Skip to content

Commit

Permalink
Fire callbacks when rows in a collection are modified.
Browse files Browse the repository at this point in the history
  • Loading branch information
jamalex committed Dec 17, 2018
1 parent bf69225 commit e5aa1b0
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ You can also see [more examples in action in the smoke test runner](https://gith
# TODO

* Cloning pages hierarchically
* Debounce cache-saving?
* Support inline "user" and "page" links, and reminders, in markdown conversion
* Utilities to support updating/creating collection schemas
* Utilities to support updating/creating collection_view queries
Expand Down
2 changes: 1 addition & 1 deletion notion/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _convert_diff_to_changelist(self, difference, old_val, new_val):
content_changed = True
continue

# check whether the value changed matches on of our mapped fields/properties
# check whether the value changed matches one of our mapped fields/properties
fields = [(name, field) for name, field in mappers.items() if path.startswith(field.path)]
if fields:
changed_fields.add(fields[0])
Expand Down
6 changes: 5 additions & 1 deletion notion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def refresh_records(self, **kwargs):
"""
self._store.call_get_record_values(**kwargs)

def refresh_collection_rows(self, collection_id):
row_ids = self.search_pages_with_parent(collection_id)
self._store.set_collection_rows(collection_id, row_ids)

def post(self, endpoint, data):
"""
All API requests on Notion.so are done as POSTs (except the websocket communications).
Expand Down Expand Up @@ -169,7 +173,7 @@ def search_pages_with_parent(self, parent_id, search=""):

self._store.store_recordmap(response["recordMap"])

return [self.get_block(page_id) for page_id in response["results"]]
return response["results"]

def create_record(self, table, parent, **kwargs):

Expand Down
18 changes: 16 additions & 2 deletions notion/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,23 @@ def add_row(self):

return CollectionRowBlock(self._client, row_id)

def get_rows(self, search=""):
def get_rows(self):

return self._client.search_pages_with_parent(self.id, search=search)
return [self._client.get_block(row_id) for row_id in self._client._store.get_collection_rows(self.id)]

def _convert_diff_to_changelist(self, difference, old_val, new_val):

changes = []
remaining = []

for operation, path, values in difference:

if path == "rows":
changes.append((operation, path, values))
else:
remaining.append((operation, path, values))

return changes + super()._convert_diff_to_changelist(remaining, old_val, new_val)


class CollectionView(Record):
Expand Down
7 changes: 3 additions & 4 deletions notion/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,10 @@ def _refresh_updated_records(self, events):

collection_id = match.groups()[0]

collection = self.client.get_collection(collection_id)
self.client.refresh_collection_rows(collection_id)
row_ids = self.client._store.get_collection_rows(collection_id)

row_ids = [row.id for row in collection.get_rows()]

logger.debug("Something inside {} has changed; refreshing all {} rows inside it".format(collection, len(row_ids)))
logger.debug("Something inside collection {} has changed; refreshing all {} rows inside it".format(collection_id, len(row_ids)))

records_to_refresh["block"] += row_ids

Expand Down
27 changes: 24 additions & 3 deletions notion/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, client, cache_key=None):
self._cache_key = cache_key or str(int(datetime.datetime.now().timestamp() * 1000))
self._values = defaultdict(lambda: defaultdict(dict))
self._role = defaultdict(lambda: defaultdict(str))
self._collection_row_ids = {}
self._callbacks = defaultdict(lambda: defaultdict(list))
self._records_to_refresh = {}
self._pages_to_refresh = []
Expand Down Expand Up @@ -103,15 +104,35 @@ def remove_callbacks(self, table, id, callback_or_callback_id_prefix=""):
def _get_cache_path(self, attribute):
return str(Path(CACHE_DIR).joinpath("{}{}.json".format(self._cache_key, attribute)))

def _load_cache(self, attributes=("_values", "_role")):
def _load_cache(self, attributes=("_values", "_role", "_collection_row_ids")):
for attr in attributes:
try:
with open(self._get_cache_path(attr)) as f:
for k, v in json.load(f).items():
getattr(self, attr)[k].update(v)
if attr == "_collection_row_ids":
self._collection_row_ids.update(json.load(f))
else:
for k, v in json.load(f).items():
getattr(self, attr)[k].update(v)
except FileNotFoundError:
pass

def set_collection_rows(self, collection_id, row_ids):

if collection_id in self._collection_row_ids:
old_ids = set(self._collection_row_ids[collection_id])
new_ids = set(row_ids)
added = new_ids - old_ids
removed = old_ids - new_ids
for id in added:
self._trigger_callbacks("collection", collection_id, [("row_added", "rows", id)], old_ids, new_ids)
for id in removed:
self._trigger_callbacks("collection", collection_id, [("row_removed", "rows", id)], old_ids, new_ids)
self._collection_row_ids[collection_id] = row_ids
self._save_cache("_collection_row_ids")

def get_collection_rows(self, collection_id):
return self._collection_row_ids.get(collection_id, [])

def _save_cache(self, attribute):
with open(self._get_cache_path(attribute), "w") as f:
json.dump(getattr(self, attribute), f)
Expand Down

0 comments on commit e5aa1b0

Please sign in to comment.