Skip to content

Commit b32099d

Browse files
committed
memory: HybridRetriever for BM25 + dense RRF retrieval
1 parent 9c4a6f4 commit b32099d

2 files changed

Lines changed: 377 additions & 0 deletions

File tree

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
/**
 * @file HybridRetriever.ts
 * @description Hybrid BM25 + dense retrieval for memory-domain traces.
 * Dense side uses {@link MemoryStore} (preserves 6-signal cognitive
 * scoring). Sparse side uses a per-instance {@link BM25Index}. RRF
 * merges by rank. Optional {@link RerankerService} runs over the
 * merged pool.
 *
 * ## What this does
 *
 * Given a query, runs dense retrieval through `MemoryStore.query`
 * (cognitive-scored traces) and sparse retrieval through an owned
 * `BM25Index` (keyword-matched trace content). Fuses the two ranked
 * lists via Reciprocal Rank Fusion, optionally reranks the merged
 * pool with a neural cross-encoder, and returns a standard
 * `CognitiveRetrievalResult` so downstream consumers (prompt
 * assembly, bench adapters) don't change shape.
 *
 * ## Why a separate class, not a `CognitiveMemoryManager` option
 *
 * Keeps the existing manager retrieval path untouched (same reason
 * as SessionRetriever in Step 2). MVP ships as opt-in.
 *
 * ## Rerank integration is mandatory-wired from the bench
 *
 * Per the Step 2 post-mortem (rerank-skip was the root cause of
 * that step's RED verdict), Step 3 threads rerank from day 1 when
 * the bench is configured with `--rerank cohere`. Callers outside
 * the bench can pass `undefined` for `rerankerService` to skip
 * rerank explicitly.
 *
 * ## Sparse-only documents are skipped in MVP
 *
 * A document that appears in `bm25.search` results but NOT in the
 * dense over-fetch pool is skipped. Rationale: at the default
 * over-fetch=3 and K=10 (30 dense candidates), a doc ranked top-30
 * on sparse is very likely in dense's top-30 on any coherent query.
 * Measured impact is expected to be negligible. If Tier A surfaces
 * a meaningful drop rate, the fix is to add a
 * `memoryStore.getTrace(id)` hydration path.
 *
 * @module agentos/memory/retrieval/hybrid/HybridRetriever
 */
44+
45+
import { BM25Index, type BM25Config } from '../../../rag/search/BM25Index.js';
46+
import { reciprocalRankFusion, type RankedDoc } from './reciprocalRankFusion.js';
47+
import type { MemoryStore } from '../store/MemoryStore.js';
48+
import type { RerankerService } from '../../../rag/reranking/RerankerService.js';
49+
import type {
50+
CognitiveRetrievalResult,
51+
MemoryScope,
52+
ScoredMemoryTrace,
53+
} from '../../core/types.js';
54+
import type { PADState } from '../../core/config.js';
55+
56+
/**
 * Options for constructing a {@link HybridRetriever}.
 */
export interface HybridRetrieverOptions {
  /** Dense-side store; `retrieve` calls its `query` method on every request. */
  memoryStore: MemoryStore;
  /** BM25 config (k1, b, optional tokenizer). Defaults match BM25Index. */
  bm25Config?: BM25Config;
  /**
   * Optional neural reranker. When provided, the merged pool is
   * reranked before truncation. Passing the same reranker the
   * baseline uses is the matched-ablation path.
   */
  rerankerService?: RerankerService;
  /** Default dense weight in RRF. @default 0.7 */
  defaultDenseWeight?: number;
  /** Default sparse weight in RRF. @default 0.3 */
  defaultSparseWeight?: number;
  /** Default RRF constant. @default 60 */
  defaultRrfK?: number;
}
76+
77+
/**
 * Per-call options for {@link HybridRetriever.retrieve}.
 */
export interface HybridRetrieveOptions {
  /** Final truncation after merge + rerank. @default 10 */
  recallTopK?: number;
  /** Over-fetch multiplier for each side before merge. @default 3 */
  overFetchMultiplier?: number;
  /** Dense weight for this call's RRF merge; falls back to the retriever's default dense weight. */
  denseWeight?: number;
  /** Sparse weight for this call's RRF merge; falls back to the retriever's default sparse weight. */
  sparseWeight?: number;
  /** RRF constant for this call; falls back to the retriever's default RRF constant. */
  rrfK?: number;
}
89+
90+
/**
 * Hybrid BM25 + dense retriever.
 *
 * Dense candidates come from `MemoryStore.query`, sparse candidates from
 * the owned {@link BM25Index}. The two ranked lists are fused with
 * `reciprocalRankFusion`, optionally reranked, and truncated to the
 * requested size.
 *
 * @example
 * ```ts
 * const hybrid = new HybridRetriever({ memoryStore, rerankerService });
 * // At ingest:
 * hybrid.bm25.addDocument(trace.id, trace.content, { tag: 'bench-session:s-1' });
 * // At query time:
 * const result = await hybrid.retrieve(
 *   'What did the user say about their mortgage?',
 *   { valence: 0, arousal: 0, dominance: 0 },
 *   { scope: 'user', scopeId: 'u1' },
 *   { recallTopK: 10 },
 * );
 * ```
 */
export class HybridRetriever {
  /** Sparse index owned by this instance; callers add documents to it at ingest. */
  readonly bm25: BM25Index;

  private readonly memoryStore: MemoryStore;
  private readonly rerankerService?: RerankerService;
  private readonly defaultDenseWeight: number;
  private readonly defaultSparseWeight: number;
  private readonly defaultRrfK: number;

  /**
   * @param opts Store, optional BM25 config / reranker, and RRF defaults
   *   (dense 0.7, sparse 0.3, k 60 when omitted).
   */
  constructor(opts: HybridRetrieverOptions) {
    this.memoryStore = opts.memoryStore;
    this.bm25 = new BM25Index(opts.bm25Config);
    this.rerankerService = opts.rerankerService;
    this.defaultDenseWeight = opts.defaultDenseWeight ?? 0.7;
    this.defaultSparseWeight = opts.defaultSparseWeight ?? 0.3;
    this.defaultRrfK = opts.defaultRrfK ?? 60;
  }

  /**
   * Run dense + sparse retrieval, fuse by rank, optionally rerank, truncate.
   *
   * @param query Query text, sent to both `MemoryStore.query` and `bm25.search`.
   * @param mood Forwarded unchanged to `MemoryStore.query`.
   * @param scope Passed to the store as the single entry of `scopes`.
   * @param options Per-call overrides; see {@link HybridRetrieveOptions}.
   * @returns Result whose `partiallyRetrieved` is always empty. `diagnostics.escalations`
   *   contains `hybrid-retriever:sparse-empty` when the sparse side had no hits and
   *   `hybrid-retriever:rerank-failed` when the reranker threw; the key is absent otherwise.
   */
  async retrieve(
    query: string,
    mood: PADState,
    scope: { scope: MemoryScope; scopeId: string },
    options: HybridRetrieveOptions = {},
  ): Promise<CognitiveRetrievalResult> {
    const startTime = Date.now();

    // Normalise counts: a negative recallTopK would otherwise reach
    // `slice(0, -n)` and drop tail items instead of returning nothing.
    const recallTopK = HybridRetriever.toCount(options.recallTopK ?? 10, 10);
    const requestedMultiplier = options.overFetchMultiplier ?? 3;
    // A multiplier below 1 would under-fetch relative to recallTopK.
    const overFetchMultiplier = Number.isNaN(requestedMultiplier)
      ? 3
      : Math.max(1, requestedMultiplier);
    const overFetchTopK = recallTopK === 0 ? 0 : Math.ceil(recallTopK * overFetchMultiplier);

    // Dense side: use MemoryStore.query so we keep the 6-signal
    // cognitive scoring (strength, recency, etc.) — matches baseline.
    const { scored: denseScored, timings: denseTimings } = await this.memoryStore.query(
      query,
      mood,
      { topK: overFetchTopK, scopes: [scope] },
    );

    // Sparse side: BM25 over the per-instance index.
    const sparseResults = this.bm25.search(query, overFetchTopK);

    const escalations: string[] = [];
    let pool: ScoredMemoryTrace[];
    let candidatesScanned: number;
    if (sparseResults.length === 0) {
      // Empty BM25 index or zero sparse hits => dense-only pool, flagged.
      escalations.push('hybrid-retriever:sparse-empty');
      pool = denseScored;
      candidatesScanned = denseScored.length;
    } else {
      pool = this.fuseByRank(denseScored, sparseResults, {
        denseWeight: options.denseWeight ?? this.defaultDenseWeight,
        sparseWeight: options.sparseWeight ?? this.defaultSparseWeight,
        k: options.rrfK ?? this.defaultRrfK,
      });
      candidatesScanned = denseScored.length + sparseResults.length;
    }

    // Rerank runs on BOTH paths so a configured reranker is never
    // silently skipped when the sparse side comes back empty.
    const reranked = await this.rerankPool(query, pool);
    if (reranked.failed) {
      escalations.push('hybrid-retriever:rerank-failed');
    }

    return this.buildResult(reranked.traces.slice(0, recallTopK), {
      ...(escalations.length > 0 ? { escalations } : {}),
      candidatesScanned,
      vectorSearchMs: denseTimings.vectorSearchMs,
      scoringMs: denseTimings.scoringMs,
      totalMs: Date.now() - startTime,
    });
  }

  /** Coerce a caller-supplied count to a non-negative integer; NaN yields `fallback`. */
  private static toCount(value: number, fallback: number): number {
    return Number.isNaN(value) ? fallback : Math.max(0, Math.floor(value));
  }

  /**
   * Fuse the dense and sparse lists with RRF and resolve each merged id
   * back to its dense trace. Ids absent from the dense list (sparse-only
   * docs) are skipped in the MVP.
   */
  private fuseByRank(
    dense: ScoredMemoryTrace[],
    sparse: ReadonlyArray<{ id: string }>,
    weights: { denseWeight: number; sparseWeight: number; k: number },
  ): ScoredMemoryTrace[] {
    // RRF consumes 1-based ranks.
    const denseRanked: RankedDoc[] = dense.map((t, i) => ({ id: t.id, rank: i + 1 }));
    const sparseRanked: RankedDoc[] = sparse.map((r, i) => ({ id: r.id, rank: i + 1 }));
    const merged = reciprocalRankFusion(denseRanked, sparseRanked, weights);

    const denseById = new Map<string, ScoredMemoryTrace>();
    for (const trace of dense) {
      denseById.set(trace.id, trace);
    }

    const hydrated: ScoredMemoryTrace[] = [];
    for (const m of merged) {
      const trace = denseById.get(m.id);
      if (trace) {
        hydrated.push(trace);
      }
    }
    return hydrated;
  }

  /**
   * Blend reranker scores into the pool (0.7 existing score + 0.3 neural,
   * the same blend as baseline) and sort by the blended score, descending.
   *
   * Traces are shallow-copied before their score is changed so objects
   * returned by the store are never mutated; otherwise repeated queries
   * against a store that hands back the same objects would compound the blend.
   * NOTE(review): the shallow copy assumes ScoredMemoryTrace is plain data — confirm.
   * NOTE(review): sorting by blended score discards the incoming (RRF) order
   * whenever the reranker succeeds — confirm this matches the intended design.
   *
   * @returns `failed: true` with the pool untouched when the reranker throws;
   *   the pool as-is when no reranker is configured or the pool is empty.
   */
  private async rerankPool(
    query: string,
    pool: ScoredMemoryTrace[],
  ): Promise<{ traces: ScoredMemoryTrace[]; failed: boolean }> {
    if (!this.rerankerService || pool.length === 0) {
      return { traces: pool, failed: false };
    }
    try {
      const rerankerOutput = await this.rerankerService.rerank(
        {
          query,
          documents: pool.map((t) => ({
            id: t.id,
            content: t.content,
            originalScore: t.retrievalScore,
          })),
        },
        { topN: pool.length },
      );
      const neuralScores = new Map<string, number>();
      for (const r of rerankerOutput.results) {
        neuralScores.set(r.id, r.relevanceScore);
      }
      const blended = pool.map((trace) => {
        const neural = neuralScores.get(trace.id);
        // Traces the reranker did not score keep their original score.
        return neural === undefined
          ? trace
          : { ...trace, retrievalScore: 0.7 * trace.retrievalScore + 0.3 * neural };
      });
      blended.sort((a, b) => b.retrievalScore - a.retrievalScore);
      return { traces: blended, failed: false };
    } catch {
      // Reranker errors are non-critical: keep the incoming ordering,
      // but report the failure so it shows up in diagnostics.
      return { traces: pool, failed: true };
    }
  }

  /** Assemble the CognitiveRetrievalResult shape. */
  private buildResult(
    retrieved: ScoredMemoryTrace[],
    d: {
      escalations?: string[];
      candidatesScanned: number;
      vectorSearchMs: number;
      scoringMs: number;
      totalMs: number;
    },
  ): CognitiveRetrievalResult {
    return {
      retrieved,
      partiallyRetrieved: [],
      diagnostics: {
        candidatesScanned: d.candidatesScanned,
        vectorSearchTimeMs: d.vectorSearchMs,
        scoringTimeMs: d.scoringMs,
        totalTimeMs: d.totalMs,
        // Only emit the key when there is something to report.
        ...(d.escalations ? { escalations: d.escalations } : {}),
      },
    };
  }
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import { describe, it, expect } from 'vitest';
2+
import { HybridRetriever } from '../HybridRetriever.js';
3+
import type {
4+
ScoredMemoryTrace,
5+
CognitiveRetrievalOptions,
6+
MemoryScope,
7+
} from '../../../core/types.js';
8+
import type { PADState } from '../../../core/config.js';
9+
import type { MemoryStore } from '../../store/MemoryStore.js';
10+
import type { RerankerService } from '../../../../rag/reranking/RerankerService.js';
11+
12+
/**
 * Build a fully populated episodic, user-scoped trace fixture.
 *
 * @param id Trace id; also used to derive the default content.
 * @param score Written to both `retrievalScore` and `scoreBreakdown.similarityScore`.
 * @param content Trace text; defaults to `content-<id>`.
 * @returns A fresh object on every call; all other fields are fixed neutral values.
 */
function mkTrace(id: string, score: number, content = `content-${id}`): ScoredMemoryTrace {
  return {
    id,
    type: 'episodic',
    scope: 'user',
    scopeId: 'u1',
    content,
    entities: [],
    tags: [],
    provenance: { sourceType: 'user_statement', sourceTimestamp: 0, confidence: 1, verificationCount: 0 },
    emotionalContext: { valence: 0, arousal: 0, dominance: 0, intensity: 0, gmiMood: '' },
    encodingStrength: 0.5,
    stability: 0.5,
    retrievalCount: 0,
    lastAccessedAt: 0,
    accessCount: 0,
    reinforcementInterval: 0,
    associatedTraceIds: [],
    createdAt: 0,
    updatedAt: 0,
    isActive: true,
    retrievalScore: score,
    scoreBreakdown: {
      strengthScore: 0,
      similarityScore: score,
      recencyScore: 0,
      emotionalCongruenceScore: 0,
      graphActivationScore: 0,
      importanceScore: 0,
    },
  };
}
33+
34+
/**
 * In-memory stand-in for the dense store used by these tests.
 * `query` ignores the query text and mood and returns the first
 * `topK` (default 10) preloaded traces with fixed 1 ms timings.
 */
class FakeMemoryStore {
  constructor(private readonly traces: ScoredMemoryTrace[] = []) {}

  /** Returns the SAME trace objects that were passed in (only the array is copied). */
  async query(_q: string, _mood: PADState, opts: CognitiveRetrievalOptions) {
    const limit = opts.topK ?? 10;
    return {
      scored: this.traces.slice(0, limit),
      partial: [],
      timings: { vectorSearchMs: 1, scoringMs: 1 },
    };
  }
}
44+
45+
/**
 * Reranker stub: records that it was invoked and returns the input
 * documents in reverse order with scores 1, 0.9, 0.8, … by output position.
 */
class FakeReranker {
  /** Set to true on the first `rerank` call. */
  public called = false;

  async rerank(input: {
    query: string;
    documents: Array<{ id: string; content: string; originalScore?: number }>;
  }) {
    this.called = true;
    const reversed = [...input.documents].reverse();
    return {
      results: reversed.map((d, i) => ({
        id: d.id,
        relevanceScore: 1 - i * 0.1,
        originalScore: d.originalScore,
      })),
      model: 'fake-rerank',
      usage: { searchUnits: 1 },
    };
  }
}
60+
61+
// Shared fixtures: a neutral PAD mood and the scope used by every mkTrace fixture.
const neutralMood: PADState = { valence: 0, arousal: 0, dominance: 0 };
const scope = { scope: 'user' as MemoryScope, scopeId: 'u1' };

describe('HybridRetriever', () => {
  it('happy path: RRF merges dense + sparse, returns results', async () => {
    const memoryStore = new FakeMemoryStore([
      mkTrace('t1', 0.9),
      mkTrace('t2', 0.8),
      mkTrace('t3', 0.7),
    ]);
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
    });
    r.bm25.addDocument('t2', 'alpha beta gamma');
    r.bm25.addDocument('t4', 'alpha delta');
    const result = await r.retrieve('alpha', neutralMood, scope, { recallTopK: 10 });
    expect(result.retrieved.length).toBeGreaterThan(0);
    expect(result.retrieved.some((t) => t.id === 't2')).toBe(true);
  });

  it('sparse-only docs (not in dense pool) are skipped in MVP', async () => {
    const memoryStore = new FakeMemoryStore([mkTrace('t1', 0.9)]);
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
    });
    r.bm25.addDocument('t99', 'alpha');
    const result = await r.retrieve('alpha', neutralMood, scope, { recallTopK: 10 });
    expect(result.retrieved.some((t) => t.id === 't99')).toBe(false);
    // The only dense trace must still come back.
    expect(result.retrieved.map((t) => t.id)).toEqual(['t1']);
  });

  it('empty BM25 index: degrades to dense-only with escalation diagnostic', async () => {
    const memoryStore = new FakeMemoryStore([mkTrace('t1', 0.9), mkTrace('t2', 0.8)]);
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
    });
    const result = await r.retrieve('q', neutralMood, scope, { recallTopK: 10 });
    expect(result.retrieved.length).toBe(2);
    // No reranker configured, so dense order is preserved.
    expect(result.retrieved.map((t) => t.id)).toEqual(['t1', 't2']);
    expect(result.diagnostics.escalations).toContain('hybrid-retriever:sparse-empty');
  });

  it('rerank applied when rerankerService present', async () => {
    const memoryStore = new FakeMemoryStore([mkTrace('t1', 0.9), mkTrace('t2', 0.8)]);
    const reranker = new FakeReranker();
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
      rerankerService: reranker as unknown as RerankerService,
    });
    r.bm25.addDocument('t1', 'alpha');
    r.bm25.addDocument('t2', 'alpha beta');
    await r.retrieve('alpha', neutralMood, scope, { recallTopK: 10 });
    expect(reranker.called).toBe(true);
  });

  it('rerank skipped when no rerankerService', async () => {
    const memoryStore = new FakeMemoryStore([mkTrace('t1', 0.9)]);
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
    });
    r.bm25.addDocument('t1', 'alpha');
    const result = await r.retrieve('alpha', neutralMood, scope, { recallTopK: 10 });
    expect(result.retrieved.length).toBeGreaterThan(0);
    // Without a reranker the dense score must come through unblended.
    expect(result.retrieved.map((t) => t.retrievalScore)).toEqual([0.9]);
  });

  it('truncates to recallTopK after merge', async () => {
    const traces = Array.from({ length: 20 }, (_, i) => mkTrace(`t${i}`, 1 - i * 0.01));
    const memoryStore = new FakeMemoryStore(traces);
    const r = new HybridRetriever({
      memoryStore: memoryStore as unknown as MemoryStore,
    });
    for (const t of traces) r.bm25.addDocument(t.id, t.content);
    const result = await r.retrieve('content', neutralMood, scope, { recallTopK: 5 });
    expect(result.retrieved.length).toBe(5);
  });
});

0 commit comments

Comments
 (0)