Skip to content

Commit ca44097

Browse files
committed
memory: HybridRetriever split-on-ambiguous rerank refinement (monotonic, additive)
1 parent be9585e commit ca44097

3 files changed

Lines changed: 238 additions & 0 deletions

File tree

src/memory/core/types.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,17 @@ export interface CognitiveRetrievalResult {
274274
llmLatencyMs: number;
275275
notes?: string[];
276276
};
277+
/**
278+
* Step-6: when `HybridRetriever` runs with `splitAmbiguousThreshold`
279+
* set, the bottom fraction of traces by first-pass rerank score
280+
* are split at sentence boundaries and rescored. Replacements are
281+
* recorded here for post-hoc analysis.
282+
*/
283+
splitOnAmbiguous?: {
284+
threshold: number;
285+
candidateCount: number;
286+
replacedIds: string[];
287+
};
277288
};
278289
}
279290

src/memory/retrieval/hybrid/HybridRetriever.ts

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ export interface HybridRetrieverOptions {
7777
* to the raw query without aborting retrieval.
7878
*/
7979
hydeRetriever?: HydeRetriever;
80+
/**
81+
* Step-6: enable split-on-ambiguous rerank refinement. When set to a
82+
* value in (0, 1], the bottom fraction of traces by first-pass
83+
* rerank score are split at sentence boundaries, rescored with a
84+
* second rerank call (same query), and replaced by their better
85+
* half ONLY IF the better half outscores the original. Monotonic.
86+
*
87+
* Default: undefined (no split, Step 3 behavior preserved).
88+
*/
89+
splitAmbiguousThreshold?: number;
8090
/** Default dense weight in RRF. @default 0.7 */
8191
defaultDenseWeight?: number;
8292
/** Default sparse weight in RRF. @default 0.3 */
@@ -121,6 +131,7 @@ export class HybridRetriever {
121131
private readonly memoryStore: MemoryStore;
122132
private readonly rerankerService?: RerankerService;
123133
private readonly hydeRetriever?: HydeRetriever;
134+
private readonly splitAmbiguousThreshold?: number;
124135
private readonly defaultDenseWeight: number;
125136
private readonly defaultSparseWeight: number;
126137
private readonly defaultRrfK: number;
@@ -130,6 +141,7 @@ export class HybridRetriever {
130141
this.bm25 = new BM25Index(opts.bm25Config);
131142
this.rerankerService = opts.rerankerService;
132143
this.hydeRetriever = opts.hydeRetriever;
144+
this.splitAmbiguousThreshold = opts.splitAmbiguousThreshold;
133145
this.defaultDenseWeight = opts.defaultDenseWeight ?? 0.7;
134146
this.defaultSparseWeight = opts.defaultSparseWeight ?? 0.3;
135147
this.defaultRrfK = opts.defaultRrfK ?? 60;
@@ -214,6 +226,7 @@ export class HybridRetriever {
214226
}
215227

216228
// Optional rerank: same 0.7 cognitive + 0.3 neural blend as baseline.
229+
let splitDiagnostic: { threshold: number; candidateCount: number; replacedIds: string[] } | undefined;
217230
if (this.rerankerService && hydrated.length > 0) {
218231
try {
219232
const rerankerOutput = await this.rerankerService.rerank(
@@ -236,6 +249,21 @@ export class HybridRetriever {
236249
trace.retrievalScore = 0.7 * trace.retrievalScore + 0.3 * neural;
237250
}
238251
}
252+
253+
// Step-6: split-on-ambiguous refinement.
254+
if (
255+
this.splitAmbiguousThreshold !== undefined &&
256+
this.splitAmbiguousThreshold > 0 &&
257+
hydrated.length > 0
258+
) {
259+
splitDiagnostic = await this.refineAmbiguous(
260+
hydrated,
261+
neuralScores,
262+
query,
263+
this.splitAmbiguousThreshold,
264+
);
265+
}
266+
239267
hydrated.sort((a, b) => b.retrievalScore - a.retrievalScore);
240268
} catch {
241269
// Reranker errors are non-critical: keep RRF ordering.
@@ -250,6 +278,7 @@ export class HybridRetriever {
250278
scoringMs: denseTimings.scoringMs,
251279
totalMs: Date.now() - startTime,
252280
hypothesis: hypothesisDiagnostic,
281+
splitOnAmbiguous: splitDiagnostic,
253282
});
254283
}
255284

@@ -263,6 +292,7 @@ export class HybridRetriever {
263292
scoringMs: number;
264293
totalMs: number;
265294
hypothesis?: string;
295+
splitOnAmbiguous?: { threshold: number; candidateCount: number; replacedIds: string[] };
266296
},
267297
): CognitiveRetrievalResult {
268298
return {
@@ -275,7 +305,111 @@ export class HybridRetriever {
275305
totalTimeMs: d.totalMs,
276306
...(d.escalations ? { escalations: d.escalations } : {}),
277307
...(d.hypothesis ? { hyde: { hypothesis: d.hypothesis } } : {}),
308+
...(d.splitOnAmbiguous ? { splitOnAmbiguous: d.splitOnAmbiguous } : {}),
278309
},
279310
};
280311
}
312+
313+
/**
314+
* Step-6: split bottom-fraction traces by neural score, rescore the
315+
* halves, replace a trace's content with its better half IFF the
316+
* better half's neural score outranks the original's. Monotonic.
317+
*
318+
* Modifies `hydrated` in place: `trace.content` and `trace.retrievalScore`
319+
* are updated for replaced traces. Returns a diagnostic summary.
320+
*/
321+
private async refineAmbiguous(
322+
hydrated: ScoredMemoryTrace[],
323+
neuralScores: Map<string, number>,
324+
query: string,
325+
threshold: number,
326+
): Promise<{ threshold: number; candidateCount: number; replacedIds: string[] }> {
327+
const replacedIds: string[] = [];
328+
329+
const sortedByNeural = hydrated
330+
.map((t) => ({ trace: t, neural: neuralScores.get(t.id) ?? 0 }))
331+
.sort((a, b) => a.neural - b.neural);
332+
const candidateCount = Math.ceil(hydrated.length * threshold);
333+
const candidates = sortedByNeural.slice(0, candidateCount);
334+
335+
type Split = { traceId: string; halfAId: string; halfBId: string; halfA: string; halfB: string; originalNeural: number };
336+
const splits: Split[] = [];
337+
for (const { trace, neural } of candidates) {
338+
const halves = this.splitAtMidpointSentence(trace.content);
339+
if (!halves) continue;
340+
splits.push({
341+
traceId: trace.id,
342+
halfAId: `${trace.id}::a`,
343+
halfBId: `${trace.id}::b`,
344+
halfA: halves[0],
345+
halfB: halves[1],
346+
originalNeural: neural,
347+
});
348+
}
349+
350+
if (splits.length === 0) {
351+
return { threshold, candidateCount, replacedIds };
352+
}
353+
354+
const halfDocs = splits.flatMap((s) => [
355+
{ id: s.halfAId, content: s.halfA },
356+
{ id: s.halfBId, content: s.halfB },
357+
]);
358+
let halfScores: Map<string, number>;
359+
try {
360+
const halfOut = await this.rerankerService!.rerank(
361+
{ query, documents: halfDocs },
362+
{ topN: halfDocs.length },
363+
);
364+
halfScores = new Map(halfOut.results.map((r) => [r.id, r.relevanceScore]));
365+
} catch {
366+
return { threshold, candidateCount, replacedIds };
367+
}
368+
369+
const traceById = new Map(hydrated.map((t) => [t.id, t]));
370+
for (const s of splits) {
371+
const a = halfScores.get(s.halfAId) ?? -Infinity;
372+
const b = halfScores.get(s.halfBId) ?? -Infinity;
373+
const winningScore = Math.max(a, b);
374+
if (winningScore <= s.originalNeural) continue;
375+
const winningText = a >= b ? s.halfA : s.halfB;
376+
const trace = traceById.get(s.traceId);
377+
if (!trace) continue;
378+
trace.content = winningText;
379+
trace.retrievalScore += 0.3 * (winningScore - s.originalNeural);
380+
replacedIds.push(s.traceId);
381+
}
382+
383+
return { threshold, candidateCount, replacedIds };
384+
}
385+
386+
/**
387+
* Split a string at the sentence boundary nearest its midpoint.
388+
* Returns [firstHalf, secondHalf] or null if the string is too short
389+
* or no valid boundary is found.
390+
*/
391+
private splitAtMidpointSentence(text: string): [string, string] | null {
392+
if (text.length < 50) return null;
393+
const mid = Math.floor(text.length / 2);
394+
const window = Math.floor(text.length * 0.4);
395+
const lo = Math.max(0, mid - window);
396+
const hi = Math.min(text.length, mid + window);
397+
for (let offset = 0; offset <= window; offset++) {
398+
for (const sign of [-1, 1] as const) {
399+
const i = mid + sign * offset;
400+
if (i < lo || i > hi) continue;
401+
if (
402+
i > 0 &&
403+
i < text.length - 1 &&
404+
/[.!?]/.test(text[i]) &&
405+
/\s/.test(text[i + 1])
406+
) {
407+
return [text.slice(0, i + 1).trim(), text.slice(i + 1).trim()];
408+
}
409+
}
410+
}
411+
const spaceIdx = text.indexOf(' ', mid);
412+
if (spaceIdx === -1 || spaceIdx >= text.length - 1) return null;
413+
return [text.slice(0, spaceIdx).trim(), text.slice(spaceIdx + 1).trim()];
414+
}
281415
}

src/memory/retrieval/hybrid/__tests__/HybridRetriever.spec.ts

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,96 @@ describe('HybridRetriever + HyDE', () => {
211211
expect(reranker.lastQuery).toBe('what is the user up to?');
212212
});
213213
});
214+
215+
class TableDrivenReranker {
216+
public lastQuery: string | undefined;
217+
public callCount = 0;
218+
public calls: Array<{ query: string; docIds: string[] }> = [];
219+
constructor(private readonly scoreTable: Record<string, number>) {}
220+
async rerank(input: { query: string; documents: Array<{ id: string; content: string; originalScore?: number }> }) {
221+
this.lastQuery = input.query;
222+
this.callCount += 1;
223+
this.calls.push({ query: input.query, docIds: input.documents.map((d) => d.id) });
224+
return {
225+
results: input.documents.map((d) => ({
226+
id: d.id,
227+
relevanceScore: this.scoreTable[d.id] ?? 0,
228+
originalScore: d.originalScore,
229+
})),
230+
model: 'table-rerank',
231+
usage: { searchUnits: 1 },
232+
};
233+
}
234+
}
235+
236+
describe('HybridRetriever + split-on-ambiguous', () => {
237+
it('identifies bottom-N% by rerank score at threshold=0.3', async () => {
238+
const traces = Array.from({ length: 10 }, (_, i) =>
239+
mkTrace(`t${i}`, 0.9 - i * 0.01, `probe sentence one. probe sentence two. probe sentence three ${'x'.repeat(40)}.`),
240+
);
241+
const memoryStore = new FakeMemoryStore(traces);
242+
const neural: Record<string, number> = {};
243+
traces.forEach((t, i) => { neural[t.id] = 1 - i * 0.1; });
244+
traces.forEach((t) => {
245+
neural[`${t.id}::a`] = -1;
246+
neural[`${t.id}::b`] = -1;
247+
});
248+
const reranker = new TableDrivenReranker(neural);
249+
const r = new HybridRetriever({
250+
memoryStore: memoryStore as unknown as MemoryStore,
251+
rerankerService: reranker as unknown as RerankerService,
252+
splitAmbiguousThreshold: 0.3,
253+
});
254+
for (const t of traces) r.bm25.addDocument(t.id, t.content);
255+
const result = await r.retrieve('probe', neutralMood, scope, { recallTopK: 10 });
256+
expect(reranker.callCount).toBe(2);
257+
expect(reranker.calls[1].docIds).toHaveLength(6);
258+
expect(result.diagnostics.splitOnAmbiguous?.candidateCount).toBe(3);
259+
expect(result.diagnostics.splitOnAmbiguous?.replacedIds).toEqual([]);
260+
});
261+
262+
it('replaces content only when winning half outscores original', async () => {
263+
const longContent = 'probe first half sentence. probe second half sentence with lots of extra padding so splitting works right here.';
264+
const traces = [
265+
mkTrace('t0', 0.9, longContent),
266+
mkTrace('t1', 0.8, longContent),
267+
mkTrace('t2', 0.7, longContent),
268+
];
269+
const memoryStore = new FakeMemoryStore(traces);
270+
const neural: Record<string, number> = { t0: 0.9, t1: 0.5, t2: 0.2 };
271+
neural['t2::a'] = 0.1;
272+
neural['t2::b'] = 0.7;
273+
neural['t1::a'] = 0.0;
274+
neural['t1::b'] = 0.0;
275+
const reranker = new TableDrivenReranker(neural);
276+
const r = new HybridRetriever({
277+
memoryStore: memoryStore as unknown as MemoryStore,
278+
rerankerService: reranker as unknown as RerankerService,
279+
splitAmbiguousThreshold: 0.34,
280+
});
281+
for (const t of traces) r.bm25.addDocument(t.id, t.content);
282+
const result = await r.retrieve('probe', neutralMood, scope, { recallTopK: 10 });
283+
const replaced = result.retrieved.find((t) => t.id === 't2');
284+
expect(replaced).toBeDefined();
285+
expect(replaced!.content).toBe('probe second half sentence with lots of extra padding so splitting works right here.');
286+
const unchanged = result.retrieved.find((t) => t.id === 't1');
287+
expect(unchanged).toBeDefined();
288+
expect(unchanged!.content).toBe(longContent);
289+
expect(result.diagnostics.splitOnAmbiguous?.replacedIds).toEqual(['t2']);
290+
});
291+
292+
it('split disabled (threshold=0 or undefined) → no second rerank call', async () => {
293+
const traces = [mkTrace('t0', 0.9, 'probe some content'), mkTrace('t1', 0.8, 'probe other content')];
294+
const memoryStore = new FakeMemoryStore(traces);
295+
const neural: Record<string, number> = { t0: 0.9, t1: 0.8 };
296+
const reranker = new TableDrivenReranker(neural);
297+
const r = new HybridRetriever({
298+
memoryStore: memoryStore as unknown as MemoryStore,
299+
rerankerService: reranker as unknown as RerankerService,
300+
});
301+
for (const t of traces) r.bm25.addDocument(t.id, t.content);
302+
const result = await r.retrieve('probe', neutralMood, scope, { recallTopK: 10 });
303+
expect(reranker.callCount).toBe(1);
304+
expect(result.diagnostics.splitOnAmbiguous).toBeUndefined();
305+
});
306+
});

0 commit comments

Comments
 (0)