Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/typeagent/knowpro/conversation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _accumulate(result: AddMessagesResult) -> None:
on_batch_committed(result)

pending_commit: asyncio.Task[AddMessagesResult] | None = None
pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None
pending_skipped: int = 0

async def _drain_commit() -> None:
Expand All @@ -275,7 +276,7 @@ async def _drain_commit() -> None:
pending_skipped = 0

async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
nonlocal pending_commit, pending_skipped
nonlocal pending_commit, pending_extraction, pending_skipped
if not filtered and not skipped:
return

Expand All @@ -285,6 +286,7 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
)
else:
next_extraction = None
pending_extraction = next_extraction

# Wait for previous commit to finish (frees the DB connection)
await _drain_commit()
Expand All @@ -298,6 +300,7 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:

# Await extraction result for this batch
extraction = await next_extraction if next_extraction is not None else None
pending_extraction = None

# Start commit (DB transaction) — runs concurrently with the
# *next* batch's LLM extraction once we yield back to the loop.
Expand Down Expand Up @@ -330,6 +333,10 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:

await _drain_commit()
except BaseException:
if pending_extraction is not None and not pending_extraction.done():
pending_extraction.cancel()
with contextlib.suppress(asyncio.CancelledError):
await pending_extraction
if pending_commit is not None and not pending_commit.done():
pending_commit.cancel()
with contextlib.suppress(asyncio.CancelledError):
Expand Down
169 changes: 169 additions & 0 deletions tests/test_add_messages_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Tests for add_messages_streaming."""

import asyncio
from collections.abc import AsyncIterator
import os
import tempfile
Expand Down Expand Up @@ -818,3 +819,171 @@ async def test_streaming_all_skipped_batch_after_real_batch() -> None:
assert batch_results[0].messages_skipped == 3

await storage.close()


# ---------------------------------------------------------------------------
# Coverage gap tests
# ---------------------------------------------------------------------------


class SlowExtractor:
    """Extractor whose calls can be made to hang, giving tests control of timing."""

    def __init__(self, block_from: int) -> None:
        # Number of extract() invocations observed so far.
        self.call_count = 0
        # Calls whose index is >= block_from will park until cancelled.
        self.block_from = block_from
        # Set as soon as the first blocking call starts waiting, so the
        # test can synchronize on "extraction is now in flight".
        self.blocked = asyncio.Event()
        # Flipped to True when a parked call receives CancelledError.
        self.cancelled = False

    async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]:
        call_index = self.call_count
        self.call_count += 1
        if call_index < self.block_from:
            # Fast path: calls before the threshold succeed immediately.
            return typechat.Success(_EMPTY_RESPONSE)
        # Signal the test that we are parked, then wait "forever" so the
        # caller is forced to cancel this task.
        self.blocked.set()
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        return typechat.Success(_EMPTY_RESPONSE)


@pytest.mark.asyncio
async def test_streaming_pending_extraction_cancelled_on_commit_failure() -> None:
    """pending_extraction is cancelled when a prior commit raises during _drain_commit.

    Timeline:
    1. Batch 0: extraction succeeds (calls 0-2, fast), commit task created
       (pending_commit = failing_commit)
    2. Batch 1: extraction task created (pending_extraction, calls 3+, slow),
       _drain_commit awaits batch 0's pending_commit which raises
    3. except block: pending_extraction (batch 1's) is still in-flight → cancelled
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        # Block extraction starting from call 3 (first call of batch 1)
        # so that pending_extraction is still running when the except fires
        extractor = SlowExtractor(block_from=3)
        transcript, storage = await _create_transcript(
            db_path, auto_extract=True, knowledge_extractor=extractor
        )

        # Replace the real DB commit with one that always raises, so the
        # error surfaces from _drain_commit while batch 1's extraction runs.
        async def failing_commit(*args, **kwargs):
            raise RuntimeError("Simulated commit failure")

        transcript._commit_batch_streaming = failing_commit  # type: ignore[assignment]

        # Six messages with batch_size=3 → exactly two batches.
        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]

        with pytest.raises(RuntimeError, match="Simulated commit failure"):
            await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3)

        # The except path must have cancelled the parked extraction task;
        # SlowExtractor records the CancelledError it received.
        assert extractor.cancelled

        await storage.close()


@pytest.mark.asyncio
async def test_streaming_pending_commit_cancelled_on_iterator_error() -> None:
    """pending_commit is cancelled when the message iterator raises.

    After batch 0 is submitted (pending_commit in flight), the async iterator
    raises on the next message. The except block must cancel the still-running
    pending_commit.
    """

    async def _error_after(
        items: list[TranscriptMessage], error_after: int
    ) -> AsyncIterator[TranscriptMessage]:
        # Yields items normally until index `error_after`, then raises
        # mid-stream instead of yielding that item.
        for i, item in enumerate(items):
            if i == error_after:
                # Yield to event loop so pending tasks start running
                await asyncio.sleep(0)
                raise ValueError("Iterator error")
            yield item

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        transcript, storage = await _create_transcript(db_path)

        commit_cancelled = False

        # Stand-in for the DB commit: parks for a long time so the iterator
        # error fires while this commit is still in flight, then records
        # whether it was cancelled.
        async def slow_commit(*args, **kwargs):
            nonlocal commit_cancelled
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                commit_cancelled = True
                raise
            return AddMessagesResult()

        transcript._commit_batch_streaming = slow_commit  # type: ignore[assignment]

        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]

        # batch_size=3 → batch 0 (msgs 0-2) is committed; the iterator raises
        # at item 4, while slow_commit is still sleeping.
        with pytest.raises(ValueError, match="Iterator error"):
            await transcript.add_messages_streaming(
                _error_after(msgs, error_after=4), batch_size=3
            )

        assert commit_cancelled

        await storage.close()


@pytest.mark.asyncio
async def test_streaming_empty_batch_after_filter() -> None:
    """Streaming with an empty iterator after a real batch returns zeros."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        transcript, storage = await _create_transcript(db_path)

        # First call ingests a single genuine message.
        first = await transcript.add_messages_streaming(
            _async_iter([_make_message("msg-0", source_id="s-0")])
        )
        assert first.messages_added == 1

        # Second call streams nothing, so _submit_batch never sees content
        # and the result counters stay at zero.
        second = await transcript.add_messages_streaming(_async_iter([]))
        assert second.messages_added == 0
        assert second.messages_skipped == 0

        await storage.close()


@pytest.mark.asyncio
async def test_streaming_extraction_returns_none_for_empty_chunks() -> None:
    """_extract_knowledge_for_batch returns None when no text_locations exist.

    Messages with empty text_chunks produce no TextLocations, so extraction
    should be skipped entirely.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        extractor = ControlledExtractor()
        transcript, storage = await _create_transcript(
            db_path, auto_extract=True, knowledge_extractor=extractor
        )

        # Two chunkless messages: valid metadata, but nothing to extract from.
        chunkless = [
            TranscriptMessage(
                text_chunks=[],
                metadata=TranscriptMessageMeta(speaker=speaker),
                tags=["test"],
                source_id=f"empty-{i}",
            )
            for i, speaker in enumerate(["Alice", "Bob"])
        ]
        result = await transcript.add_messages_streaming(_async_iter(chunkless))

        assert result.messages_added == 2
        assert result.chunks_added == 0
        # Extraction was never invoked — there were no chunks to feed it.
        assert extractor.call_count == 0

        await storage.close()
11 changes: 7 additions & 4 deletions tools/ingest_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ async def ingest_emails(
counters: dict[str, int] = {
"parsed": 0,
"skipped": 0,
"batch_skipped": 0,
"date_skipped": 0,
"failed": 0,
"ingested": 0,
Expand All @@ -462,6 +463,7 @@ async def ingest_emails(
def on_batch_committed(result: AddMessagesResult) -> None:
nonlocal last_batch_time
counters["ingested"] += result.messages_added
counters["batch_skipped"] += result.messages_skipped
counters["chunks"] += result.chunks_added
counters["semrefs"] += result.semrefs_added
counters["batches"] += 1
Expand Down Expand Up @@ -516,6 +518,7 @@ def on_batch_committed(result: AddMessagesResult) -> None:
)
total_chunks = result.chunks_added if result is not None else counters["chunks"]
semrefs_added = result.semrefs_added if result is not None else counters["semrefs"]
total_skipped = counters["skipped"] + counters["batch_skipped"]
overall_per_chunk = elapsed / total_chunks if total_chunks else 0

print()
Expand All @@ -524,8 +527,8 @@ def on_batch_committed(result: AddMessagesResult) -> None:
print("Ingestion interrupted by user (^C).")
print(f"Successfully ingested {messages_ingested} email(s)")
print(f"Ingested {total_chunks} chunk(s)")
if counters["skipped"]:
print(f"Skipped {counters['skipped']} already-ingested email(s)")
if total_skipped:
print(f"Skipped {total_skipped} already-ingested email(s)")
if counters["date_skipped"]:
print(f"Skipped {counters['date_skipped']} email(s) outside date range")
if counters["failed"]:
Expand All @@ -539,8 +542,8 @@ def on_batch_committed(result: AddMessagesResult) -> None:
f"({total_chunks} chunks, {semrefs_added} refs added, {elapsed:.1f}s, "
f"{overall_per_chunk:.2f}s/chunk)"
)
if counters["skipped"]:
print(f"Skipped: {counters['skipped']} (already ingested)")
if total_skipped:
print(f"Skipped: {total_skipped} (already ingested)")
if counters["date_skipped"]:
print(f"Skipped: {counters['date_skipped']} (outside date range)")
if counters["failed"]:
Expand Down
Loading