-
Notifications
You must be signed in to change notification settings - Fork 71
Use streaming in ingest_email.py #265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
17ee6ad
649db72
386861e
1419455
33b357f
f42ff4d
43c4905
551d407
ed19170
ec4085d
5bc5764
016e6a3
875b7ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -238,7 +238,10 @@ async def add_messages_streaming( | |||||||
|
|
||||||||
| Args: | ||||||||
| messages: An async iterable of messages to ingest. | ||||||||
| batch_size: Number of messages per commit batch. | ||||||||
| batch_size: Target number of text chunks per commit batch. | ||||||||
| Messages are never split across batches, so the actual | ||||||||
| chunk count may exceed ``batch_size`` if a single message | ||||||||
| has more chunks than that. | ||||||||
|
gvanrossum marked this conversation as resolved.
|
||||||||
| on_batch_committed: Optional callback invoked after each batch is | ||||||||
| committed, receiving the batch's ``AddMessagesResult``. | ||||||||
|
|
||||||||
|
|
@@ -255,23 +258,28 @@ def _accumulate(result: AddMessagesResult) -> None: | |||||||
| total.messages_added += result.messages_added | ||||||||
| total.semrefs_added += result.semrefs_added | ||||||||
| total.chunks_added += result.chunks_added | ||||||||
| total.messages_skipped += result.messages_skipped | ||||||||
| if on_batch_committed: | ||||||||
| on_batch_committed(result) | ||||||||
|
|
||||||||
| pending_commit: asyncio.Task[AddMessagesResult] | None = None | ||||||||
| pending_skipped: int = 0 | ||||||||
|
|
||||||||
| async def _drain_commit() -> None: | ||||||||
| nonlocal pending_commit | ||||||||
| nonlocal pending_commit, pending_skipped | ||||||||
| if pending_commit is not None: | ||||||||
| _accumulate(await pending_commit) | ||||||||
| result = await pending_commit | ||||||||
| result.messages_skipped += pending_skipped | ||||||||
| _accumulate(result) | ||||||||
| pending_commit = None | ||||||||
| pending_skipped = 0 | ||||||||
|
|
||||||||
| async def _submit_batch(filtered: list[TMessage]) -> None: | ||||||||
| nonlocal pending_commit | ||||||||
| if not filtered: | ||||||||
| async def _submit_batch(filtered: list[TMessage], skipped: int) -> None: | ||||||||
| nonlocal pending_commit, pending_skipped | ||||||||
| if not filtered and not skipped: | ||||||||
| return | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dead guard. All three call sites guarantee
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @KRRT7 I am confused by your suggestions. This one, when applied, generates invalid syntax. Maybe you are using a tool to generate them and it's got an off-by-one error or is working from an outdated file version? It's the case for all suggestions in this batch.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're right. If we ever revisit this code we can delete it. Bernhard merged it before I got a chance to look at your feedback, in part confused by the misaligned suggestions. It's harmless so I see no need to fix it in a separate PR. |
||||||||
|
|
||||||||
| if should_extract: | ||||||||
| if filtered and should_extract: | ||||||||
| next_extraction = asyncio.create_task( | ||||||||
| self._extract_knowledge_for_batch(filtered) | ||||||||
| ) | ||||||||
|
|
@@ -281,6 +289,13 @@ async def _submit_batch(filtered: list[TMessage]) -> None: | |||||||
| # Wait for previous commit to finish (frees the DB connection) | ||||||||
| await _drain_commit() | ||||||||
|
|
||||||||
| if not filtered: | ||||||||
| # Nothing to commit, just report skipped | ||||||||
| total.messages_skipped += skipped | ||||||||
| if on_batch_committed: | ||||||||
| on_batch_committed(AddMessagesResult(messages_skipped=skipped)) | ||||||||
|
gvanrossum marked this conversation as resolved.
|
||||||||
| return | ||||||||
|
|
||||||||
| # Await extraction result for this batch | ||||||||
| extraction = await next_extraction if next_extraction is not None else None | ||||||||
|
|
||||||||
|
|
@@ -289,19 +304,29 @@ async def _submit_batch(filtered: list[TMessage]) -> None: | |||||||
| pending_commit = asyncio.create_task( | ||||||||
| self._commit_batch_streaming(storage, filtered, extraction) | ||||||||
| ) | ||||||||
| pending_skipped = skipped | ||||||||
|
|
||||||||
| try: | ||||||||
| batch: list[TMessage] = [] | ||||||||
| batch_chunks = 0 | ||||||||
| async for msg in messages: | ||||||||
| msg_chunks = len(msg.text_chunks) | ||||||||
| if batch and batch_chunks + msg_chunks > batch_size: | ||||||||
| filtered, skipped = await self._filter_ingested(storage, batch) | ||||||||
| await _submit_batch(filtered, skipped) | ||||||||
| batch = [] | ||||||||
| batch_chunks = 0 | ||||||||
| batch.append(msg) | ||||||||
| if len(batch) >= batch_size: | ||||||||
| filtered = await self._filter_ingested(storage, batch) | ||||||||
| await _submit_batch(filtered) | ||||||||
| batch_chunks += msg_chunks | ||||||||
| if batch_chunks >= batch_size: | ||||||||
| filtered, skipped = await self._filter_ingested(storage, batch) | ||||||||
| await _submit_batch(filtered, skipped) | ||||||||
| batch = [] | ||||||||
| batch_chunks = 0 | ||||||||
|
|
||||||||
| if batch: | ||||||||
| filtered = await self._filter_ingested(storage, batch) | ||||||||
| await _submit_batch(filtered) | ||||||||
| filtered, skipped = await self._filter_ingested(storage, batch) | ||||||||
| await _submit_batch(filtered, skipped) | ||||||||
|
|
||||||||
| await _drain_commit() | ||||||||
| except BaseException: | ||||||||
|
|
@@ -317,23 +342,26 @@ async def _filter_ingested( | |||||||
| self, | ||||||||
| storage: IStorageProvider[TMessage], | ||||||||
| batch: list[TMessage], | ||||||||
| ) -> list[TMessage]: | ||||||||
| ) -> tuple[list[TMessage], int]: | ||||||||
| """Filter out messages whose source_id has already been ingested. | ||||||||
|
|
||||||||
| Safe to call while a pending_commit task exists: is_source_ingested | ||||||||
| is a synchronous SELECT on SQLite's single connection, so it won't | ||||||||
| interleave with the commit task's cursor operations in asyncio's | ||||||||
| cooperative model. If the storage provider becomes truly async | ||||||||
| (e.g. aiosqlite), this assumption needs revisiting. | ||||||||
| Returns (filtered_messages, skipped_count). | ||||||||
|
|
||||||||
| Uses a single batch query instead of per-message lookups. | ||||||||
| """ | ||||||||
| source_ids = [m.source_id for m in batch if m.source_id is not None] | ||||||||
| if source_ids: | ||||||||
| ingested = await storage.are_sources_ingested(source_ids) | ||||||||
| else: | ||||||||
| ingested = set[str]() | ||||||||
| filtered: list[TMessage] = [] | ||||||||
| skipped = 0 | ||||||||
| for msg in batch: | ||||||||
| if msg.source_id is not None and await storage.is_source_ingested( | ||||||||
| msg.source_id | ||||||||
| ): | ||||||||
| if msg.source_id is not None and msg.source_id in ingested: | ||||||||
| skipped += 1 | ||||||||
| continue | ||||||||
| filtered.append(msg) | ||||||||
| return filtered | ||||||||
| return filtered, skipped | ||||||||
|
|
||||||||
| async def _extract_knowledge_for_batch( | ||||||||
| self, | ||||||||
|
|
||||||||
Uh oh!
There was an error while loading. Please reload this page.