diff --git a/aries_cloudagent/messaging/models/base_record.py b/aries_cloudagent/messaging/models/base_record.py
index 40b7c0ba04..b724b333cd 100644
--- a/aries_cloudagent/messaging/models/base_record.py
+++ b/aries_cloudagent/messaging/models/base_record.py
@@ -314,7 +314,10 @@ async def query(
         storage = session.inject(BaseStorage)
         tag_query = cls.prefix_tag_filter(tag_filter)
 
-        if limit is not None or offset is not None:
+        post_filter = post_filter_positive or post_filter_negative
+        paginated = limit is not None or offset is not None
+        if not post_filter and paginated:
+            # Only fetch paginated records if post-filter is not being applied
             rows = await storage.find_paginated_records(
                 type_filter=cls.RECORD_TYPE,
                 tag_query=tag_query,
@@ -328,23 +331,36 @@
             )
 
         result = []
+        num_results_post_filter = 0  # to apply pagination post-filter
+        num_records_to_match = (
+            (limit or DEFAULT_PAGE_SIZE) + (offset or 0) if paginated else sys.maxsize
+        )  # if pagination is not requested, set to sys.maxsize to process all records
         for record in rows:
             vals = json.loads(record.value)
-            if match_post_filter(
-                vals,
-                post_filter_positive,
-                positive=True,
-                alt=alt,
-            ) and match_post_filter(
-                vals,
-                post_filter_negative,
-                positive=False,
-                alt=alt,
-            ):
-                try:
+            try:
+                if not post_filter:  # pagination would already be applied if requested
                     result.append(cls.from_storage(record.id, vals))
-                except BaseModelError as err:
-                    raise BaseModelError(f"{err}, for record id {record.id}")
+                elif (
+                    (not paginated or num_results_post_filter < num_records_to_match)
+                    and match_post_filter(
+                        vals,
+                        post_filter_positive,
+                        positive=True,
+                        alt=alt,
+                    )
+                    and match_post_filter(
+                        vals,
+                        post_filter_negative,
+                        positive=False,
+                        alt=alt,
+                    )
+                ):
+                    if num_results_post_filter >= (offset or 0):
+                        # append post-filtered records after requested offset
+                        result.append(cls.from_storage(record.id, vals))
+                    num_results_post_filter += 1
+            except BaseModelError as err:
+                raise BaseModelError(f"{err}, for record id {record.id}")
         return result
 
     async def save(
diff --git a/aries_cloudagent/messaging/models/tests/test_base_record.py b/aries_cloudagent/messaging/models/tests/test_base_record.py
index 9a8fd8c3f8..eaa60aa61d 100644
--- a/aries_cloudagent/messaging/models/tests/test_base_record.py
+++ b/aries_cloudagent/messaging/models/tests/test_base_record.py
@@ -413,3 +413,37 @@ async def test_query_with_limit_and_offset(self):
         assert result[0]._id == record_id
         assert result[0].value == record_value
         assert result[0].a == "one"
+
+    async def test_query_with_limit_and_offset_and_post_filter(self):
+        session = InMemoryProfile.test_session()
+        mock_storage = mock.MagicMock(BaseStorage, autospec=True)
+        session.context.injector.bind_instance(BaseStorage, mock_storage)
+        record_id = "record_id"
+        a_record = ARecordImpl(ident=record_id, a="one", b="two", code="red")
+        record_value = a_record.record_value
+        record_value.update({"created_at": time_now(), "updated_at": time_now()})
+        tag_filter = {"code": "red"}
+        stored = StorageRecord(
+            ARecordImpl.RECORD_TYPE,
+            json.dumps(record_value),
+            {"code": "red"},
+            record_id,
+        )
+        mock_storage.find_all_records.return_value = [stored] * 15  # return 15 records
+
+        # Query with limit and offset
+        result = await ARecordImpl.query(
+            session,
+            tag_filter,
+            limit=10,
+            offset=5,
+            post_filter_positive={"a": "one"},
+        )
+        mock_storage.find_all_records.assert_awaited_once_with(
+            type_filter=ARecordImpl.RECORD_TYPE, tag_query=tag_filter
+        )
+        assert len(result) == 10
+        assert result and isinstance(result[0], ARecordImpl)
+        assert result[0]._id == record_id
+        assert result[0].value == record_value
+        assert result[0].a == "one"
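
Note (not part of the diff): the patched query() applies limit/offset on top of the post-filter by counting matching records, skipping the first `offset` matches, and stopping work once `limit + offset` matches have been seen. The standalone sketch below illustrates that counting logic under stated assumptions: the names `paginate_post_filtered` and `matches` and the page size of 100 are hypothetical, and the sketch breaks out of the loop early where the patch instead short-circuits the match check.

# Illustrative sketch only; names and defaults here are assumptions, not ACA-Py APIs.
import sys
from typing import Callable, Iterable, List, Optional, TypeVar

T = TypeVar("T")
DEFAULT_PAGE_SIZE = 100  # assumed default page size for this sketch


def paginate_post_filtered(
    rows: Iterable[T],
    matches: Callable[[T], bool],
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[T]:
    """Emulate limit/offset over rows that pass a post-filter."""
    paginated = limit is not None or offset is not None
    # Enough matches to cover both the offset and the requested page size.
    stop_after = (limit or DEFAULT_PAGE_SIZE) + (offset or 0) if paginated else sys.maxsize
    result: List[T] = []
    num_matched = 0
    for row in rows:
        if num_matched >= stop_after:
            break  # enough matches seen to fill the requested page
        if matches(row):
            if num_matched >= (offset or 0):
                result.append(row)  # keep only matches past the offset
            num_matched += 1
    return result


# Mirrors test_query_with_limit_and_offset_and_post_filter: 15 matching rows,
# limit=10, offset=5 -> 10 results.
assert len(paginate_post_filtered(range(15), lambda _: True, limit=10, offset=5)) == 10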