6 changes: 2 additions & 4 deletions src/typeagent/knowpro/conversation_base.py
@@ -4,8 +4,8 @@
 """Base class for conversations with incremental indexing support."""
 
 import asyncio
-import contextlib
 from collections.abc import AsyncIterable, Callable, Sequence
+import contextlib
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from typing import Generic, Self, TypeVar
@@ -282,9 +282,7 @@ async def _submit_batch(filtered: list[TMessage]) -> None:
         await _drain_commit()
 
         # Await extraction result for this batch
-        extraction = (
-            await next_extraction if next_extraction is not None else None
-        )
+        extraction = await next_extraction if next_extraction is not None else None
 
         # Start commit (DB transaction) — runs concurrently with the
         # *next* batch's LLM extraction once we yield back to the loop.
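The comment in this hunk describes the pipelining that makes the one-liner worth keeping: the commit for batch N overlaps the LLM extraction for batch N+1. Below is a minimal sketch of that overlap, with hypothetical `_extract` and `_commit` stand-ins rather than the real `_drain_commit`/extraction machinery:

```python
import asyncio


async def _extract(batch: list[str]) -> list[str]:
    # Stand-in for the per-batch LLM knowledge extraction (hypothetical).
    await asyncio.sleep(0.1)
    return [f"knowledge:{item}" for item in batch]


async def _commit(results: list[str]) -> None:
    # Stand-in for the DB transaction that persists one batch (hypothetical).
    await asyncio.sleep(0.1)


async def pipeline(batches: list[list[str]]) -> None:
    commit_task: asyncio.Task[None] | None = None
    next_extraction: asyncio.Task[list[str]] | None = None
    for i, batch in enumerate(batches):
        if next_extraction is None:  # first batch: nothing in flight yet
            next_extraction = asyncio.create_task(_extract(batch))
        # Await extraction result for this batch; the previous batch's
        # commit (if still running) proceeds underneath this await.
        extraction = await next_extraction
        # Kick off the next batch's extraction before committing, so the
        # DB transaction overlaps the next LLM call.
        next_extraction = (
            asyncio.create_task(_extract(batches[i + 1]))
            if i + 1 < len(batches)
            else None
        )
        # Drain the previous commit, then start this batch's commit.
        if commit_task is not None:
            await commit_task
        commit_task = asyncio.create_task(_commit(extraction))
    if commit_task is not None:
        await commit_task


asyncio.run(pipeline([["a", "b"], ["c", "d"], ["e"]]))
```

With roughly equal extraction and commit times, the sketch hides nearly all commit latency behind the next extraction, which is the point of draining the commit lazily instead of awaiting it inline.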
1 change: 0 additions & 1 deletion src/typeagent/knowpro/knowledge.py
@@ -7,7 +7,6 @@
 
 from typechat import Result
 
-from . import convknowledge
 from . import knowledge_schema as kplib
 from .interfaces import IKnowledgeExtractor
 
58 changes: 32 additions & 26 deletions src/typeagent/storage/memory/semrefindex.py
@@ -76,9 +76,7 @@ async def add_batch_to_semantic_ref_index[
                 (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value)
             )
     if bulk_items:
-        await add_knowledge_batch_to_semantic_ref_index(
-            conversation, bulk_items
-        )
+        await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items)
 
 
 async def add_batch_to_semantic_ref_index_from_list[
@@ -103,9 +101,7 @@ async def add_batch_to_semantic_ref_index_from_list[
                 f"Message ordinal {tl.message_ordinal} out of range "
                 f"for list starting at {start_ordinal}"
             )
-        text_batch.append(
-            messages[list_index].text_chunks[tl.chunk_ordinal].strip()
-        )
+        text_batch.append(messages[list_index].text_chunks[tl.chunk_ordinal].strip())
 
     knowledge_results = await extract_knowledge_from_text_batch(
         knowledge_extractor,
@@ -123,9 +119,7 @@ async def add_batch_to_semantic_ref_index_from_list[
                 (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value)
             )
     if bulk_items:
-        await add_knowledge_batch_to_semantic_ref_index(
-            conversation, bulk_items
-        )
+        await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items)
 
 
 async def add_term_to_index(
@@ -360,11 +354,13 @@ def _collect_knowledge_refs_and_terms(
     for entity in knowledge.entities:
         if not validate_entity(entity):
             continue
-        refs.append(SemanticRef(
-            semantic_ref_ordinal=ordinal,
-            range=text_range,
-            knowledge=entity,
-        ))
+        refs.append(
+            SemanticRef(
+                semantic_ref_ordinal=ordinal,
+                range=text_range,
+                knowledge=entity,
+            )
+        )
         terms.append((entity.name, ordinal))
         for type_name in entity.type:
             terms.append((type_name, ordinal))
@@ -377,11 +373,13 @@
         ordinal += 1
 
     for action in list(knowledge.actions) + list(knowledge.inverse_actions):
-        refs.append(SemanticRef(
-            semantic_ref_ordinal=ordinal,
-            range=text_range,
-            knowledge=action,
-        ))
+        refs.append(
+            SemanticRef(
+                semantic_ref_ordinal=ordinal,
+                range=text_range,
+                knowledge=action,
+            )
+        )
         terms.append((" ".join(action.verbs), ordinal))
         if action.subject_entity_name != "none":
             terms.append((action.subject_entity_name, ordinal))
@@ -404,11 +402,13 @@
         ordinal += 1
 
     for topic_text in knowledge.topics:
-        refs.append(SemanticRef(
-            semantic_ref_ordinal=ordinal,
-            range=text_range,
-            knowledge=Topic(text=topic_text),
-        ))
+        refs.append(
+            SemanticRef(
+                semantic_ref_ordinal=ordinal,
+                range=text_range,
+                knowledge=Topic(text=topic_text),
+            )
+        )
         terms.append((topic_text, ordinal))
         ordinal += 1
 
@@ -431,7 +431,10 @@ async def add_knowledge_to_semantic_ref_index(
 
     base_ordinal = await semantic_refs.size()
     refs, terms = _collect_knowledge_refs_and_terms(
-        base_ordinal, message_ordinal, chunk_ordinal, knowledge,
+        base_ordinal,
+        message_ordinal,
+        chunk_ordinal,
+        knowledge,
     )
 
     if refs:
@@ -460,7 +463,10 @@ async def add_knowledge_batch_to_semantic_ref_index(
 
     for msg_ord, chunk_ord, knowledge in items:
        refs, terms = _collect_knowledge_refs_and_terms(
-            base_ordinal + len(all_refs), msg_ord, chunk_ord, knowledge,
+            base_ordinal + len(all_refs),
+            msg_ord,
+            chunk_ord,
+            knowledge,
         )
         all_refs.extend(refs)
         all_terms.extend(terms)
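Both write paths above funnel through `_collect_knowledge_refs_and_terms`, which assigns ordinals relative to the current index size so nothing has to be written until the whole batch is collected; `base_ordinal + len(all_refs)` keeps ordinals contiguous across chunks. A small self-contained sketch of that bookkeeping, using a simplified `Ref` stand-in for the real `SemanticRef` and index types:

```python
from dataclasses import dataclass


@dataclass
class Ref:
    ordinal: int
    payload: str


def collect(
    base_ordinal: int, payloads: list[str]
) -> tuple[list[Ref], list[tuple[str, int]]]:
    """Assign ordinals and gather refs/terms without touching storage."""
    refs: list[Ref] = []
    terms: list[tuple[str, int]] = []
    ordinal = base_ordinal
    for payload in payloads:
        refs.append(Ref(ordinal=ordinal, payload=payload))
        terms.append((payload, ordinal))
        ordinal += 1
    return refs, terms


all_refs: list[Ref] = []
all_terms: list[tuple[str, int]] = []
base_ordinal = 100  # in the real code: await semantic_refs.size()
for payloads in [["apple", "pear"], ["cat"]]:
    # base_ordinal + len(all_refs) keeps ordinals contiguous across chunks
    # even though nothing has been written to the index yet.
    refs, terms = collect(base_ordinal + len(all_refs), payloads)
    all_refs.extend(refs)
    all_terms.extend(terms)
assert [r.ordinal for r in all_refs] == [100, 101, 102]
# The storage layer then sees one bulk extend plus one batched add-terms call.
```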
4 changes: 1 addition & 3 deletions tests/test_add_messages_streaming.py
@@ -384,9 +384,7 @@ async def test_streaming_exception_in_later_batch_preserves_earlier() -> None:
 
     msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]
     with pytest.raises(ExceptionGroup) as exc_info:
-        await transcript.add_messages_streaming(
-            _async_iter(msgs), batch_size=3
-        )
+        await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3)
 
     assert any(
         isinstance(e, RuntimeError) and "Systemic failure" in str(e)
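The assertion pattern in this test, `pytest.raises(ExceptionGroup)` plus a scan of the group's leaves, is the standard way on Python 3.11+ to pin one specific failure out of concurrent work. A stand-alone sketch, with a hypothetical `boom()` in place of `add_messages_streaming`:

```python
import pytest


def boom() -> None:
    # Hypothetical stand-in for the failing streaming call; the real test
    # drives transcript.add_messages_streaming over a failing batch.
    raise ExceptionGroup("batch failures", [RuntimeError("Systemic failure")])


def test_group_contains_runtime_error() -> None:
    with pytest.raises(ExceptionGroup) as exc_info:
        boom()
    # exc_info.value is the group; .exceptions holds its leaf exceptions.
    assert any(
        isinstance(e, RuntimeError) and "Systemic failure" in str(e)
        for e in exc_info.value.exceptions
    )
```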
47 changes: 33 additions & 14 deletions tools/benchmark_semref_writes.py
@@ -43,14 +43,16 @@
     TranscriptMessageMeta,
 )
 
 
 # ---------------------------------------------------------------------------
 # Inlined pre-optimization write path (one append + add_term per item)
 # ---------------------------------------------------------------------------
 
 
 async def _individual_add_knowledge(
-    conversation, message_ordinal, chunk_ordinal, knowledge,
+    conversation,
+    message_ordinal,
+    chunk_ordinal,
+    knowledge,
 ):
     """Reproduces the pre-optimization per-item write logic."""
     verify_has_semantic_ref_index(conversation)
@@ -95,7 +97,9 @@ async def _individual_add_knowledge(
         if action.object_entity_name != "none":
             await semantic_ref_index.add_term(action.object_entity_name, ordinal)
         if action.indirect_object_entity_name != "none":
-            await semantic_ref_index.add_term(action.indirect_object_entity_name, ordinal)
+            await semantic_ref_index.add_term(
+                action.indirect_object_entity_name, ordinal
+            )
         if action.params:
             for param in action.params:
                 if isinstance(param, str):
@@ -135,8 +139,7 @@ def synthetic_knowledge(chunk_index: int) -> kplib.KnowledgeResponse:
                 name=f"entity_{chunk_index}_{j}",
                 type=[f"type_{j}", f"category_{chunk_index % 5}"],
                 facets=[
-                    kplib.Facet(name=f"facet_{j}", value=f"value_{j}")
-                    for j in range(2)
+                    kplib.Facet(name=f"facet_{j}", value=f"value_{j}") for j in range(2)
                 ],
             )
             for j in range(3)
@@ -237,15 +240,21 @@ async def main() -> None:
         description="Benchmark semref index write strategies.",
     )
     parser.add_argument(
-        "--chunks", type=int, default=50,
+        "--chunks",
+        type=int,
+        default=50,
         help="Number of knowledge chunks to write per run (default: 50).",
     )
     parser.add_argument(
-        "--rounds", type=int, default=10,
+        "--rounds",
+        type=int,
+        default=10,
         help="Number of timed rounds (default: 10).",
     )
     parser.add_argument(
-        "--warmup", type=int, default=2,
+        "--warmup",
+        type=int,
+        default=2,
         help="Number of untimed warmup rounds (default: 2).",
     )
     args = parser.parse_args()
@@ -262,21 +271,31 @@ async def main() -> None:
     print(f"Total semrefs per run: ~{refs_per_chunk * args.chunks}")
 
     individual = await run_benchmark(
-        "Individual writes", bench_individual,
-        args.chunks, args.rounds, args.warmup,
+        "Individual writes",
+        bench_individual,
+        args.chunks,
+        args.rounds,
+        args.warmup,
     )
     print_report(
         "Individual writes (per-entity append + add_term)",
-        individual, args.rounds, args.warmup,
+        individual,
+        args.rounds,
+        args.warmup,
     )
 
     batched = await run_benchmark(
-        "Batched writes", bench_batched,
-        args.chunks, args.rounds, args.warmup,
+        "Batched writes",
+        bench_batched,
+        args.chunks,
+        args.rounds,
+        args.warmup,
     )
     print_report(
         "Batched writes (bulk extend + add_terms_batch)",
-        batched, args.rounds, args.warmup,
+        batched,
+        args.rounds,
+        args.warmup,
    )
 
     speedup = statistics.fmean(individual) / statistics.fmean(batched)
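`run_benchmark` and `print_report` belong to this tool; the shape is the usual micro-benchmark harness: a few untimed warmup rounds, then timed rounds averaged with `statistics.fmean` to compute the speedup. A stripped-down sketch of that harness (the real `run_benchmark` also takes a label and a chunk count; the sleeps below are stand-ins for the two write paths):

```python
import asyncio
import statistics
import time
from collections.abc import Awaitable, Callable


async def run_rounds(
    work: Callable[[], Awaitable[None]], rounds: int, warmup: int
) -> list[float]:
    """Run work warmup+rounds times; keep per-round seconds for timed rounds only."""
    timings: list[float] = []
    for i in range(warmup + rounds):
        start = time.perf_counter()
        await work()
        if i >= warmup:  # warmup rounds are discarded
            timings.append(time.perf_counter() - start)
    return timings


async def main() -> None:
    async def individual_path() -> None:  # stand-in for per-item writes
        await asyncio.sleep(0.02)

    async def batched_path() -> None:  # stand-in for bulk writes
        await asyncio.sleep(0.005)

    individual = await run_rounds(individual_path, rounds=5, warmup=2)
    batched = await run_rounds(batched_path, rounds=5, warmup=2)
    speedup = statistics.fmean(individual) / statistics.fmean(batched)
    print(f"Batched speedup: {speedup:.2f}x")


asyncio.run(main())
```

Warmup rounds absorb one-time costs (cache fills, connection setup) so the timed mean reflects steady-state throughput.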