diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 230cec18..7a195d8c 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -27,18 +27,21 @@ """ from collections.abc import Sequence +import logging import os import numpy as np from numpy.typing import NDArray import stamina from stamina import BoundAsyncRetryingCaller +from stamina.instrumentation import RetryDetails, set_on_retry_hooks import openai from pydantic_ai import Embedder as _PydanticAIEmbedder from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType from pydantic_ai.embeddings.settings import EmbeddingSettings +from pydantic_ai.exceptions import ModelAPIError from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -60,15 +63,42 @@ openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError, + ModelAPIError, ) DEFAULT_CHAT_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on( _TRANSIENT_ERRORS ) -DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=4, timeout=30).on( +DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on( _TRANSIENT_ERRORS ) +_logger = logging.getLogger("stamina") + +_CALLABLE_LABELS: dict[str, str] = { + "request": "chat", + "embed_documents": "embedding", +} + + +def _on_retry(details: RetryDetails) -> None: + kind = _CALLABLE_LABELS.get(details.name, details.name) + caused = details.caused_by + exc_summary = repr(caused)[:200] + _logger.warning( + "stamina: retrying %s request (attempt %d, waited %.1fs so far, " + "waiting %.1fs): %s", + kind, + details.retry_num, + details.waited_so_far, + details.wait_for, + exc_summary, + ) + + +set_on_retry_hooks([_on_retry]) + + # --------------------------------------------------------------------------- # Chat model adapter # --------------------------------------------------------------------------- diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index fcc5e667..dd6814c1 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -238,7 +238,10 @@ async def add_messages_streaming( Args: messages: An async iterable of messages to ingest. - batch_size: Number of messages per commit batch. + batch_size: Target number of text chunks per commit batch. + Messages are never split across batches, so the actual + chunk count may exceed ``batch_size`` if a single message + has more chunks than that. on_batch_committed: Optional callback invoked after each batch is committed, receiving the batch's ``AddMessagesResult``. @@ -255,23 +258,28 @@ def _accumulate(result: AddMessagesResult) -> None: total.messages_added += result.messages_added total.semrefs_added += result.semrefs_added total.chunks_added += result.chunks_added + total.messages_skipped += result.messages_skipped if on_batch_committed: on_batch_committed(result) pending_commit: asyncio.Task[AddMessagesResult] | None = None + pending_skipped: int = 0 async def _drain_commit() -> None: - nonlocal pending_commit + nonlocal pending_commit, pending_skipped if pending_commit is not None: - _accumulate(await pending_commit) + result = await pending_commit + result.messages_skipped += pending_skipped + _accumulate(result) pending_commit = None + pending_skipped = 0 - async def _submit_batch(filtered: list[TMessage]) -> None: - nonlocal pending_commit - if not filtered: + async def _submit_batch(filtered: list[TMessage], skipped: int) -> None: + nonlocal pending_commit, pending_skipped + if not filtered and not skipped: return - if should_extract: + if filtered and should_extract: next_extraction = asyncio.create_task( self._extract_knowledge_for_batch(filtered) ) @@ -281,6 +289,13 @@ async def _submit_batch(filtered: list[TMessage]) -> None: # Wait for previous commit to finish (frees the DB connection) await _drain_commit() + if not filtered: + # Nothing to commit, just report skipped + total.messages_skipped += skipped + if on_batch_committed: + on_batch_committed(AddMessagesResult(messages_skipped=skipped)) + return + # Await extraction result for this batch extraction = await next_extraction if next_extraction is not None else None @@ -289,19 +304,29 @@ async def _submit_batch(filtered: list[TMessage]) -> None: pending_commit = asyncio.create_task( self._commit_batch_streaming(storage, filtered, extraction) ) + pending_skipped = skipped try: batch: list[TMessage] = [] + batch_chunks = 0 async for msg in messages: + msg_chunks = len(msg.text_chunks) + if batch and batch_chunks + msg_chunks > batch_size: + filtered, skipped = await self._filter_ingested(storage, batch) + await _submit_batch(filtered, skipped) + batch = [] + batch_chunks = 0 batch.append(msg) - if len(batch) >= batch_size: - filtered = await self._filter_ingested(storage, batch) - await _submit_batch(filtered) + batch_chunks += msg_chunks + if batch_chunks >= batch_size: + filtered, skipped = await self._filter_ingested(storage, batch) + await _submit_batch(filtered, skipped) batch = [] + batch_chunks = 0 if batch: - filtered = await self._filter_ingested(storage, batch) - await _submit_batch(filtered) + filtered, skipped = await self._filter_ingested(storage, batch) + await _submit_batch(filtered, skipped) await _drain_commit() except BaseException: @@ -317,23 +342,26 @@ async def _filter_ingested( self, storage: IStorageProvider[TMessage], batch: list[TMessage], - ) -> list[TMessage]: + ) -> tuple[list[TMessage], int]: """Filter out messages whose source_id has already been ingested. - Safe to call while a pending_commit task exists: is_source_ingested - is a synchronous SELECT on SQLite's single connection, so it won't - interleave with the commit task's cursor operations in asyncio's - cooperative model. If the storage provider becomes truly async - (e.g. aiosqlite), this assumption needs revisiting. + Returns (filtered_messages, skipped_count). + + Uses a single batch query instead of per-message lookups. """ + source_ids = [m.source_id for m in batch if m.source_id is not None] + if source_ids: + ingested = await storage.are_sources_ingested(source_ids) + else: + ingested = set[str]() filtered: list[TMessage] = [] + skipped = 0 for msg in batch: - if msg.source_id is not None and await storage.is_source_ingested( - msg.source_id - ): + if msg.source_id is not None and msg.source_id in ingested: + skipped += 1 continue filtered.append(msg) - return filtered + return filtered, skipped async def _extract_knowledge_for_batch( self, diff --git a/src/typeagent/knowpro/interfaces_core.py b/src/typeagent/knowpro/interfaces_core.py index 87ef7329..10d1765c 100644 --- a/src/typeagent/knowpro/interfaces_core.py +++ b/src/typeagent/knowpro/interfaces_core.py @@ -93,6 +93,7 @@ class AddMessagesResult: messages_added: int = 0 chunks_added: int = 0 semrefs_added: int = 0 + messages_skipped: int = 0 # Messages are referenced by their sequential ordinal numbers. diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index e1ed18ee..9f17574d 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -181,6 +181,10 @@ async def is_source_ingested(self, source_id: str) -> bool: """Check if a source has already been ingested.""" ... + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + ... + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source.""" ... diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index bb6231a2..e697fe01 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -159,6 +159,10 @@ async def is_source_ingested(self, source_id: str) -> bool: """ return source_id in self._ingested_sources + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + return self._ingested_sources & set(source_ids) + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source. diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index cafbac0b..a8ba9c06 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -564,6 +564,26 @@ async def is_source_ingested(self, source_id: str) -> bool: row = cursor.fetchone() return row is not None and row[0] == STATUS_INGESTED + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + if not source_ids: + return set() + cursor = self.db.cursor() + result: set[str] = set() + # Chunk to stay within SQLite's SQLITE_MAX_VARIABLE_NUMBER + # (999 on older builds, 32766 on 3.32.0+). + chunk_size = 500 + for i in range(0, len(source_ids), chunk_size): + chunk = source_ids[i : i + chunk_size] + placeholders = ",".join("?" for _ in chunk) + cursor.execute( + f"SELECT source_id FROM IngestedSources" + f" WHERE source_id IN ({placeholders}) AND status = ?", + [*chunk, STATUS_INGESTED], + ) + result.update(row[0] for row in cursor.fetchall()) + return result + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source. diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py index 5707f71b..8a8f8db3 100644 --- a/tests/test_add_messages_streaming.py +++ b/tests/test_add_messages_streaming.py @@ -14,7 +14,7 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro import knowledge_schema as kplib from typeagent.knowpro.convsettings import ConversationSettings -from typeagent.knowpro.interfaces_core import IKnowledgeExtractor +from typeagent.knowpro.interfaces_core import AddMessagesResult, IKnowledgeExtractor from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( Transcript, @@ -478,3 +478,343 @@ async def test_streaming_extraction_with_empty_text_chunks() -> None: assert extractor.call_count == 1 await storage.close() + + +# --------------------------------------------------------------------------- +# Multi-chunk messages and chunk-based batching +# --------------------------------------------------------------------------- + + +def _make_multi_chunk_message( + chunks: list[str], + speaker: str = "Alice", + source_id: str | None = None, +) -> TranscriptMessage: + return TranscriptMessage( + text_chunks=chunks, + metadata=TranscriptMessageMeta(speaker=speaker), + tags=["test"], + source_id=source_id, + ) + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_extraction() -> None: + """Each chunk in a multi-chunk message triggers a separate extraction call.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), + _make_message("single chunk", source_id="s-1"), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert result.chunks_added == 4 # 3 + 1 + # 4 extraction calls: one per chunk + assert extractor.call_count == 4 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batch_size_counts_chunks() -> None: + """batch_size counts chunks, not messages — a 3-chunk message fills batch_size=3.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message(["a", "b", "c"], source_id="s-0"), # 3 chunks + _make_message("d", source_id="s-1"), # 1 chunk + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 2 + # First message (3 chunks) fills batch_size=3, second message goes to batch 2 + assert batch_results == [1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_large_message_exceeds_batch_size() -> None: + """A single message with more chunks than batch_size becomes its own batch.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message( + [f"chunk-{i}" for i in range(5)], source_id="s-big" + ), + _make_message("small", source_id="s-small"), + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 2 + # 5-chunk msg in batch 1, then 1-chunk msg in batch 2 + assert batch_results == [1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_mixed_chunk_sizes_batching() -> None: + """Messages of varying chunk counts are batched by cumulative chunk count.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_message("a", source_id="s-0"), # 1 chunk, total=1 + _make_multi_chunk_message( + ["b1", "b2"], source_id="s-1" + ), # 2 chunks, total=3 → flush + _make_message("c", source_id="s-2"), # 1 chunk, total=1 + _make_message("d", source_id="s-3"), # 1 chunk, total=2 + _make_message("e", source_id="s-4"), # 1 chunk, total=3 → flush + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 5 + assert result.chunks_added == 6 + # Batch 1: msgs 0+1 (3 chunks), Batch 2: msgs 2+3+4 (3 chunks) + assert batch_results == [2, 3] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_failure_ordinals() -> None: + """Extraction failures in multi-chunk messages record correct ordinals.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Fail on call index 1 (chunk 1 of first message) and 3 (chunk 0 of second message) + extractor = ControlledExtractor(fail_on={1, 3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message( + ["c0", "c1", "c2"], source_id="s-0" + ), # calls 0,1,2 + _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), # calls 3,4 + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert extractor.call_count == 5 + assert _failure_count(storage) == 2 + + failures = await storage.get_chunk_failures() + failure_locs = sorted((f.message_ordinal, f.chunk_ordinal) for f in failures) + # call 1 → msg 0, chunk 1; call 3 → msg 1, chunk 0 + assert failure_locs == [(0, 1), (1, 0)] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None: + """Exception during extraction of multi-chunk batch preserves committed batches.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Batch 1: 3-chunk msg (calls 0,1,2). Batch 2: 2-chunk msg (calls 3,4) — raise on 3 + extractor = ControlledExtractor(raise_on={3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message(["a", "b", "c"], source_id="s-0"), # batch 1 + _make_multi_chunk_message(["d", "e"], source_id="s-1"), # batch 2 + ] + + with pytest.raises(ExceptionGroup): + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + # Batch 1 committed (1 message, 3 chunks), batch 2 rolled back + assert await transcript.messages.size() == 1 + assert _ingested_count(storage) == 1 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_skip_and_ingest_mixed() -> None: + """Multi-chunk messages are skipped or ingested as a whole based on source_id.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + # Pre-mark s-1 as ingested + async with storage: + await storage.mark_source_ingested("s-1") + + msgs = [ + _make_multi_chunk_message(["a", "b"], source_id="s-0"), # ingested + _make_multi_chunk_message(["c", "d", "e"], source_id="s-1"), # skipped + _make_message("f", source_id="s-2"), # ingested + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert result.messages_skipped == 1 + assert result.chunks_added == 3 # 2 + 1 (not 5) + # Only 3 extraction calls (2 chunks from s-0 + 1 chunk from s-2) + assert extractor.call_count == 3 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batch_size_1_separates_all() -> None: + """batch_size=1 commits every single-chunk message individually.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(4)] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=1, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 4 + assert batch_results == [1, 1, 1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_callback_reports_skipped_multi_chunk() -> None: + """on_batch_committed reports skipped count for batches with multi-chunk messages.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + # Pre-mark s-1 as ingested + async with storage: + await storage.mark_source_ingested("s-1") + + msgs = [ + _make_multi_chunk_message(["a", "b"], source_id="s-0"), # 2 chunks + _make_multi_chunk_message(["c", "d"], source_id="s-1"), # 2 chunks, skipped + _make_message("e", source_id="s-2"), # 1 chunk → total = 5 chunks in batch + ] + callback_results: list[tuple[int, int]] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=10, # all fit in one batch + on_batch_committed=lambda r: callback_results.append( + (r.messages_added, r.messages_skipped) + ), + ) + + assert result.messages_added == 2 + assert result.messages_skipped == 1 + assert callback_results == [(2, 1)] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_preflush_avoids_oversized_batch() -> None: + """Adding a message that would exceed batch_size flushes first. + + With batch_size=10 and four 3-chunk messages, batches should be + [msg0,msg1,msg2] (9 chunks) and [msg3] (3 chunks) — never a single + batch of 12 chunks. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message( + [f"m{i}c{j}" for j in range(3)], source_id=f"s-{i}" + ) + for i in range(4) + ] + batch_chunks: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=10, + on_batch_committed=lambda r: batch_chunks.append(r.chunks_added), + ) + + assert result.messages_added == 4 + assert result.chunks_added == 12 + # Batch 1: 3 msgs × 3 chunks = 9, Batch 2: 1 msg × 3 chunks = 3 + assert batch_chunks == [9, 3] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_all_skipped_batch_after_real_batch() -> None: + """A batch of all-duplicates reports skipped correctly. + + First call ingests messages s-0..s-2. Second call re-submits the same + source_ids — they should all be filtered by _filter_ingested, exercising + the all-skipped + pending_commit path. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)] + + # First call — ingest originals + result1 = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + ) + assert result1.messages_added == 3 + assert result1.messages_skipped == 0 + + # Second call — all duplicates + dupes = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)] + batch_results: list[AddMessagesResult] = [] + result2 = await transcript.add_messages_streaming( + _async_iter(dupes), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r), + ) + + assert result2.messages_added == 0 + assert result2.messages_skipped == 3 + + # One callback for the all-skipped batch + assert len(batch_results) == 1 + assert batch_results[0].messages_added == 0 + assert batch_results[0].messages_skipped == 3 + + await storage.close() diff --git a/tools/ingest_email.py b/tools/ingest_email.py index 74a259ea..68784f3a 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -17,25 +17,17 @@ python tools/query.py --database email.db --query "What was discussed?" """ -""" -TODO - -- Collect knowledge outside db transaction to reduce lock time -""" - import argparse import asyncio +from collections.abc import AsyncIterator from datetime import datetime from pathlib import Path import sys import time -import traceback from typing import Iterable from dotenv import load_dotenv -import openai - from typeagent.aitools import utils from typeagent.emails.email_import import ( decode_encoded_words, @@ -45,6 +37,8 @@ from typeagent.emails.email_memory import EmailMemory from typeagent.emails.email_message import EmailMessage from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import AddMessagesResult +from typeagent.knowpro.interfaces_storage import IStorageProvider from typeagent.storage.utils import create_storage_provider @@ -134,6 +128,35 @@ def create_arg_parser() -> argparse.ArgumentParser: ), ) + # Concurrency / batching + parser.add_argument( + "--concurrency", + type=int, + default=None, + metavar="N", + help=( + "Number of concurrent LLM extraction requests. " + "Default: 4 (from ConversationSettings)." + ), + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + metavar="N", + help="Number of chunks per commit batch. Default: 100.", + ) + parser.add_argument( + "--max-chunks", + type=int, + default=20, + metavar="N", + help=( + "Maximum number of text chunks to keep per email. " + "Extra chunks are silently dropped. Default: 20." + ), + ) + return parser @@ -149,8 +172,17 @@ def _validate_args(args: argparse.Namespace) -> None: if args.limit is not None and args.limit <= 0: errors.append("--limit must be a positive integer.") - # --offset without --limit is allowed (skip first N, ingest the rest) - # --limit without --offset is allowed (ingest at most N) + # --concurrency must be positive when given + if args.concurrency is not None and args.concurrency <= 0: + errors.append("--concurrency must be a positive integer.") + + # --batch-size must be positive + if args.batch_size <= 0: + errors.append("--batch-size must be a positive integer.") + + # --max-chunks must be positive when given + if args.max_chunks is not None and args.max_chunks <= 0: + errors.append("--max-chunks must be a positive integer.") # --start-date must be before --stop-date when both are given if args.start_date and args.stop_date: @@ -239,7 +271,7 @@ def _iter_emails( sliced_total = len(email_files) for i, email_file in enumerate(email_files): label = f"[{i + 1}/{sliced_total}] {email_file}" - yield str(email_file), email_file, label + yield str(email_file.resolve()), email_file, label def _print_email_verbose(email: EmailMessage) -> None: @@ -267,6 +299,103 @@ def _print_email_verbose(email: EmailMessage) -> None: print(f" {preview}") +def _flush_skipped( + skip_first: str | None, + skip_last: str | None, + skip_count: int, +) -> None: + """Print a summary line for a contiguous range of skipped files.""" + if skip_count == 0: + return + assert skip_first is not None and skip_last is not None + if skip_count == 1: + print(f" Skipped {skip_first} (already ingested)") + elif skip_first == skip_last: + print(f" Skipped {skip_first} x{skip_count} (already ingested)") + else: + print( + f" Skipped {skip_first} .. {skip_last}" f" ({skip_count} already ingested)" + ) + + +async def _email_generator( + eml_paths: list[str], + verbose: bool, + start_date: datetime | None, + stop_date: datetime | None, + offset: int, + limit: int | None, + max_chunks: int | None, + counters: dict[str, int], + storage: IStorageProvider[EmailMessage], +) -> AsyncIterator[EmailMessage]: + """Async generator that parses and yields EmailMessage objects. + + Checks each file's source_id against already-ingested sources + *before* parsing, so batches contain only new messages. + + *counters* is mutated in place to track ``parsed``, ``skipped``, + and ``failed`` counts for the caller's summary. + """ + skip_first: str | None = None + skip_last: str | None = None + skip_count = 0 + + for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit): + # Pre-parse dedup: skip before opening the file. + # A second dedup pass happens in _filter_ingested() to catch + # sources committed by an earlier batch in the same run. + if await storage.is_source_ingested(source_id): + counters["skipped"] += 1 + basename = email_file.name + if skip_first is None: + skip_first = basename + skip_last = basename + skip_count += 1 + continue + + # Flush any pending skip summary before processing a new file + _flush_skipped(skip_first, skip_last, skip_count) + skip_first = None + skip_last = None + skip_count = 0 + + try: + email = import_email_from_file(str(email_file)) + except Exception as e: + counters["failed"] += 1 + print( + f"Error parsing {source_id}: {e!r:.150s}", + file=sys.stderr, + ) + continue + + # Apply date filter + if not email_matches_date_filter(email.timestamp, start_date, stop_date): + counters["date_skipped"] += 1 + if verbose: + print(f"{label} [Outside date range, skipping]") + continue + + if verbose: + print(label) + _print_email_verbose(email) + + # Truncate chunks if --max-chunks is set + if max_chunks is not None and len(email.text_chunks) > max_chunks: + if verbose: + print(f" Truncating {len(email.text_chunks)} chunks to {max_chunks}") + email.text_chunks = email.text_chunks[:max_chunks] + + # Set source_id so streaming API handles dedup and tracking + email.source_id = source_id + counters["parsed"] += 1 + yield email + + # Flush any remaining skip summary at end of iteration + _flush_skipped(skip_first, skip_last, skip_count) + + async def ingest_emails( eml_paths: list[str], database: str, @@ -275,6 +404,9 @@ async def ingest_emails( stop_date: datetime | None = None, offset: int = 0, limit: int | None = None, + concurrency: int | None = None, + batch_size: int = 100, + max_chunks: int | None = 20, ) -> None: """Ingest email files into a database.""" @@ -288,6 +420,11 @@ async def ingest_emails( print("Setting up conversation settings...") settings = ConversationSettings() + + # Override concurrency if specified + if concurrency is not None: + settings.semantic_ref_index_settings.concurrency = concurrency + settings.storage_provider = await create_storage_provider( settings.message_text_index_settings, settings.related_term_index_settings, @@ -301,97 +438,113 @@ async def ingest_emails( if verbose: print(f"Target database: {database}") - concurrency = settings.semantic_ref_index_settings.concurrency + effective_concurrency = settings.semantic_ref_index_settings.concurrency if verbose: - print(f"Concurrency: {concurrency}") + print(f"Concurrency: {effective_concurrency}") + print(f"Batch size: {batch_size} chunks") print("\nParsing and importing emails...") - success_count = 0 - failed_count = 0 - skipped_count = 0 start_time = time.time() + last_batch_time = start_time + + # Counters mutated by the generator and callback + counters: dict[str, int] = { + "parsed": 0, + "skipped": 0, + "date_skipped": 0, + "failed": 0, + "ingested": 0, + "chunks": 0, + "semrefs": 0, + "batches": 0, + } + + def on_batch_committed(result: AddMessagesResult) -> None: + nonlocal last_batch_time + counters["ingested"] += result.messages_added + counters["chunks"] += result.chunks_added + counters["semrefs"] += result.semrefs_added + counters["batches"] += 1 + now = time.time() + batch_secs = now - last_batch_time + last_batch_time = now + elapsed = now - start_time + per_chunk = batch_secs / result.chunks_added if result.chunks_added else 0 + print( + f" Batch {counters['batches']}: " + f"+{result.messages_added} messages, " + f"+{result.chunks_added} chunks, " + f"+{result.semrefs_added} semrefs | " + f"{batch_secs:.1f}s ({per_chunk:.2f}s/chunk) | " + f"{counters['ingested']} total ingested | " + f"{elapsed:.1f}s elapsed", + flush=True, + ) - semref_coll = settings.storage_provider.semantic_refs - storage_provider = settings.storage_provider - - for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit): - try: - if verbose: - print(label, end="", flush=True) - - # Parse the email only after confirming it hasn't been ingested - email = import_email_from_file(str(email_file)) - - # Apply date filter - if not email_matches_date_filter(email.timestamp, start_date, stop_date): - skipped_count += 1 - if verbose: - print(" [Outside date range, skipping]") - continue - - if verbose: - _print_email_verbose(email) - - # Ingest the email - try: - await email_memory.add_messages_with_indexing( - [email], source_ids=[source_id] - ) - success_count += 1 - except openai.AuthenticationError as e: - if verbose: - traceback.print_exc() - sys.exit(f"Authentication error: {e!r}") - - # Print progress periodically - if concurrency and (success_count + failed_count) % concurrency == 0: - elapsed = time.time() - start_time - semref_count = await semref_coll.size() - print( - f"\n{label} " - f"{success_count} imported | " - f"{failed_count} failed | " - f"{skipped_count} skipped | " - f"{semref_count} semrefs | " - f"{elapsed:.1f}s elapsed\n" - ) + message_stream = _email_generator( + eml_paths, + verbose, + start_date, + stop_date, + offset, + limit, + max_chunks, + counters, + settings.storage_provider, + ) - except Exception as e: - failed_count += 1 - print( - f"Error processing {source_id}: {e!r:.150s}", - file=sys.stderr, - ) - mod = e.__class__.__module__ - qual = e.__class__.__qualname__ - exc_name = qual if mod == "builtins" else f"{mod}.{qual}" - async with storage_provider: - await storage_provider.mark_source_ingested(source_id, exc_name) - if verbose: - traceback.print_exc(limit=10) + result: AddMessagesResult | None = None + interrupted = False + try: + result = await email_memory.add_messages_streaming( + message_stream, + batch_size=batch_size, + on_batch_committed=on_batch_committed, + ) + except (KeyboardInterrupt, asyncio.CancelledError): + interrupted = True # Final summary elapsed = time.time() - start_time - semref_count = await semref_coll.size() + if interrupted and counters["batches"] == 0: + print() + print("Interrupted before any batches were committed.") + return + + messages_ingested = ( + result.messages_added if result is not None else counters["ingested"] + ) + total_chunks = result.chunks_added if result is not None else counters["chunks"] + semrefs_added = result.semrefs_added if result is not None else counters["semrefs"] + overall_per_chunk = elapsed / total_chunks if total_chunks else 0 print() if verbose: - print(f"Successfully imported {success_count} email(s)") - if skipped_count: - print(f"Skipped {skipped_count} already-ingested email(s)") - if failed_count: - print(f"Failed to import {failed_count} email(s)") - print(f"Extracted {semref_count} semantic references") + if interrupted: + print("Ingestion interrupted by user (^C).") + print(f"Successfully ingested {messages_ingested} email(s)") + print(f"Ingested {total_chunks} chunk(s)") + if counters["skipped"]: + print(f"Skipped {counters['skipped']} already-ingested email(s)") + if counters["date_skipped"]: + print(f"Skipped {counters['date_skipped']} email(s) outside date range") + if counters["failed"]: + print(f"Failed to parse {counters['failed']} email(s)") + print(f"Extracted {semrefs_added} semantic references") print(f"Total time: {elapsed:.1f}s") + print(f"Overall time per chunk: {overall_per_chunk:.2f}s/chunk") else: print( - f"Imported {success_count} emails to {database} " - f"({semref_count} refs, {elapsed:.1f}s)" + f"Ingested {messages_ingested} emails to {database} " + f"({total_chunks} chunks, {semrefs_added} refs added, {elapsed:.1f}s, " + f"{overall_per_chunk:.2f}s/chunk)" ) - if skipped_count: - print(f"Skipped: {skipped_count} (already ingested)") - if failed_count: - print(f"Failed: {failed_count}") + if counters["skipped"]: + print(f"Skipped: {counters['skipped']} (already ingested)") + if counters["date_skipped"]: + print(f"Skipped: {counters['date_skipped']} (outside date range)") + if counters["failed"]: + print(f"Failed: {counters['failed']}") # Show usage information print() @@ -419,6 +572,9 @@ def main() -> None: stop_date=stop_date, offset=args.offset, limit=args.limit, + concurrency=args.concurrency, + batch_size=args.batch_size, + max_chunks=args.max_chunks, ) )