diff --git a/cmd/engine/main.go b/cmd/engine/main.go
index e3dd112..6be4ee1 100644
--- a/cmd/engine/main.go
+++ b/cmd/engine/main.go
@@ -129,6 +129,34 @@ func run() error {
 		}
 	}
 
+	// ReRanker: opt-in Phase 2.3. Instantiated whenever an LLM client
+	// is wired — the per-request `enable_rerank` body field overrides
+	// the config, mirroring the planner pattern.
+	//
+	// The configured model is handed over as-is. An empty value must
+	// stay empty: the API layer reads it as "inherit the request's
+	// model, then the engine default". Resolving the default here would
+	// pin every request to it and make the request model unreachable.
+	var reRanker *retrieval.ReRanker
+	if llmClient != nil {
+		reRanker = retrieval.NewReRanker(llmClient, cfg.Retrieval.ReRank.Model)
+		if cfg.Retrieval.ReRank.MaxContentChars > 0 {
+			reRanker.MaxContentChars = cfg.Retrieval.ReRank.MaxContentChars
+		}
+		if cfg.Retrieval.ReRank.Enabled {
+			// Log the model a request without its own model will hit.
+			reRankModel := cfg.Retrieval.ReRank.Model
+			if reRankModel == "" {
+				reRankModel = modelFor(cfg.LLM)
+			}
+			logger.Info("retrieval: rerank enabled",
+				"model", reRankModel,
+				"max_content_chars", reRanker.MaxContentChars,
+				"top_k", cfg.Retrieval.ReRank.TopK,
+			)
+		}
+	}
+
 	pipeline := ingest.NewPipeline(ingest.Pipeline{
 		DB:      pool,
 		Storage: store,
@@ -157,6 +185,8 @@ func run() error {
 		Answer:   cfg.Retrieval.Answer,
 		Planner:  planner,
 		Planning: cfg.Retrieval.Planning,
+		ReRanker: reRanker,
+		ReRank:   cfg.Retrieval.ReRank,
 	}
 
 	srv := &http.Server{
diff --git a/config.example.yaml b/config.example.yaml
index ffb43a1..53b5a79 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -152,6 +152,36 @@ retrieval:
     # isolation (plan returned, but retrieval uses the original query).
     decompose: true
 
+  # rerank: Phase 2.3 content-aware re-rank pass. After the retrieval
+  # strategy returns candidate sections and their content is loaded,
+  # one extra LLM call scores each section (0-100) against the query
+  # and the engine reorders descending by score.
+  #
+  # This is the safety net for the case where the strategy reasoned
+  # over title + summary + HyDE candidate questions and got fooled
+  # by surface-level matches. Reading the actual content closes that
+  # gap. ~3-5k input tokens per query on gemini-2.5-flash; ~$0.0003
+  # per call at typical rates.
+ # + # OPT-IN. Default disabled. Per-request `enable_rerank` body field + # overrides this block. Failures never drop sections — at worst the + # strategy's order is preserved. + rerank: + enabled: false + # Override the re-rank model; empty inherits the request's model + # (or the engine default). Keep this on a small/fast model — the + # re-rank prompt is short and shouldn't burn the flagship model. + model: "" + # Per-candidate content budget. Higher = more context for the + # model to judge with, lower = tighter cost. 2000 chars ≈ 500 + # tokens, comfortable for typical section sizes. + max_content_chars: 2000 + # Truncate the post-rerank candidate list to the top K. 0 means + # keep all candidates (re-rank only reorders). Useful when the + # strategy returns a wide candidate list and you want the + # re-rank pass to do the final selection. + top_k: 0 + ingest: # The summarize and HyDE stages run concurrently. This caps the total # number of LLM calls in flight across both stages combined, so the diff --git a/internal/api/server.go b/internal/api/server.go index 38dadce..2a4f557 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -73,6 +73,16 @@ type Deps struct { // `enable_planning` field on /v1/query and /v1/answer overrides // Planning.Enabled. Planning config.PlanningBlock + + // ReRanker runs Phase 2.3 content-aware re-rank on the strategy's + // candidate sections (one extra LLM call per query). Nil disables + // re-rank even when a request opts in via `enable_rerank`. + ReRanker *retrieval.ReRanker + + // ReRank carries the server-side re-rank config. The body-level + // `enable_rerank` field on /v1/query and /v1/answer overrides + // ReRank.Enabled. TopK truncates the post-rerank candidate list. + ReRank config.ReRankBlock } // Router builds and returns the chi router wired with v1 routes. @@ -400,6 +410,10 @@ func (d Deps) handleQuery(w http.ResponseWriter, r *http.Request) { // planner. 
A pointer so we can distinguish "absent" from // "explicit false" — absent falls back to the server config. EnablePlanning *bool `json:"enable_planning"` + // EnableReRank opts this request into the Phase 2.3 + // content-aware re-rank pass. Pointer for the same reason as + // EnablePlanning. Overrides retrieval.rerank.enabled. + EnableReRank *bool `json:"enable_rerank"` } if err := json.NewDecoder(r.Body).Decode(&body); err != nil { writeErr(w, http.StatusBadRequest, "invalid json: "+err.Error()) @@ -471,6 +485,15 @@ func (d Deps) handleQuery(w http.ResponseWriter, r *http.Request) { enriched = append(enriched, sectionWithContent{sec: sec, content: content}) } + // Optional: content-aware re-rank pass. One LLM call that scores + // each loaded section against the query and re-orders the slice + // descending by score. TopK truncates the survivors. Failures + // never drop sections — at worst the strategy's order is + // preserved (see retrieval.ReRanker.ReRank). + if d.reRankEnabled(body.EnableReRank) { + enriched, _ = d.runReRank(r.Context(), enriched, body.Query, body.Model) + } + // Optional: per-section answer-span extraction. Opt-in via config — // one LLM call per returned section. Failures are non-fatal; the // section is returned without a span. @@ -499,11 +522,18 @@ func (d Deps) handleQuery(w http.ResponseWriter, r *http.Request) { } // sectionWithContent bundles a tree section with its loaded content -// and an optional answer-span. Used by /v1/query and /v1/answer. +// and optional re-rank score / answer-span. Used by /v1/query and +// /v1/answer. type sectionWithContent struct { sec *tree.Section content string span *retrieval.AnswerSpan + + // hasScore reports whether score was populated by a re-rank pass. + // Distinct from score == 0 since 0 is a legitimate score the + // model can return. + hasScore bool + score float64 } // sectionWithContentToMap renders the section as the API map shape. 
@@ -528,6 +558,9 @@ func sectionWithContentToMap(e sectionWithContent) map[string]any { if e.span != nil { s["answer_span"] = e.span } + if e.hasScore { + s["score"] = e.score + } return s } @@ -620,6 +653,10 @@ func (d Deps) handleAnswer(w http.ResponseWriter, r *http.Request) { // EnablePlanning opts this request into the Phase 2.1 query // planner. See handleQuery for the same field's semantics. EnablePlanning *bool `json:"enable_planning"` + // EnableReRank opts this request into the Phase 2.3 re-rank + // pass. Synthesis then sees the re-ranked top-k. Overrides + // retrieval.rerank.enabled. + EnableReRank *bool `json:"enable_rerank"` } if err := json.NewDecoder(r.Body).Decode(&body); err != nil { writeErr(w, http.StatusBadRequest, "invalid json: "+err.Error()) @@ -699,6 +736,16 @@ func (d Deps) handleAnswer(w http.ResponseWriter, r *http.Request) { enriched = append(enriched, sectionWithContent{sec: sec, content: content}) } + // Optional: content-aware re-rank before synthesis sees the + // evidence. When TopK is set the synthesis prompt only ever sees + // the post-rerank top-k, keeping the answer focused on the + // best-evidence sections. + if d.reRankEnabled(body.EnableReRank) { + var reRankUsage retrieval.Usage + enriched, reRankUsage = d.runReRank(r.Context(), enriched, body.Query, body.Model) + totalUsage.Add(reRankUsage) + } + // Always extract spans for /v1/answer — they ground each citation. 
spanExtractor := d.spanExtractor(body.Model) runSpansConcurrent(r.Context(), spanExtractor, body.Query, enriched, d.AnswerSpan.MaxConcurrency, d.Logger) @@ -747,6 +794,9 @@ func (d Deps) handleAnswer(w http.ResponseWriter, r *http.Request) { c["quote_end"] = e.span.End } } + if e.hasScore { + c["score"] = e.score + } citations = append(citations, c) } @@ -1132,6 +1182,123 @@ func (d Deps) shouldDecompose(plan *retrieval.Plan) bool { return d.Planning.Decompose } +// --- re-rank helpers --- + +// reRankEnabled reports whether the request should go through the +// re-rank pass. The per-request body field (when present) wins over +// the server-side config; a nil body field falls back to the config. +// +// Returns false when no LLM client is wired or when no ReRanker is +// configured, regardless of intent — re-rank without an LLM is +// physically impossible. +func (d Deps) reRankEnabled(bodyOverride *bool) bool { + if d.ReRanker == nil || d.LLM == nil { + return false + } + if bodyOverride != nil { + return *bodyOverride + } + return d.ReRank.Enabled +} + +// runReRank executes the re-rank pass over the loaded section slice +// and returns the reordered slice plus the LLM Usage spent. On any +// failure the original slice is returned (with the same hasScore +// values it had on input — i.e. unchanged) so the caller never has +// to think about partial state. The error is LOGGED, not returned — +// re-rank is best-effort and a failure must never abort the request. +// +// requestModel is the model the request asked for. When the +// ReRanker has its own Model set (the config-level override), that +// wins; the request model is the fall-through. 
+func (d Deps) runReRank(ctx context.Context, enriched []sectionWithContent, query, requestModel string) ([]sectionWithContent, retrieval.Usage) {
+	if d.ReRanker == nil || d.LLM == nil || len(enriched) == 0 {
+		return enriched, retrieval.Usage{}
+	}
+
+	// Apply the model fall-through: config override → request model →
+	// engine default. We don't mutate d.ReRanker since Deps is shared
+	// across requests; instead build a shallow copy with the chosen
+	// model. This is the same pattern spanExtractor() uses.
+	ranker := *d.ReRanker
+	if ranker.Model == "" {
+		ranker.Model = requestModel
+		if ranker.Model == "" {
+			ranker.Model = d.LLMModel
+		}
+	}
+
+	candidates := make([]retrieval.SectionContent, len(enriched))
+	for i, e := range enriched {
+		candidates[i] = retrieval.SectionContent{
+			ID:      e.sec.ID,
+			Title:   e.sec.Title,
+			Content: e.content,
+		}
+	}
+
+	scored, usage, err := ranker.ReRank(ctx, query, candidates)
+	if err != nil {
+		if d.Logger != nil {
+			d.Logger.Warn("rerank: failed; preserving strategy order", "err", err)
+		}
+		// ReRank returns input order on error so we *could* apply it
+		// (it'd just stamp score=0 on everything). Skip — the caller
+		// shouldn't see score=0 on every section when re-rank
+		// physically failed.
+		return enriched, usage
+	}
+	if len(scored) == 0 {
+		return enriched, usage
+	}
+
+	// ReRank can also degrade with a nil error: when every retry
+	// produced unparseable JSON it hands back the input order with
+	// score=0 and no reason on every entry. That result carries no
+	// ranking signal, so it gets the same treatment as a hard failure.
+	// Stamping score=0 everywhere would misreport a failed pass as a
+	// real verdict, and the TopK cut below would drop sections on the
+	// strength of a ranking that never happened.
+	hasSignal := false
+	for _, s := range scored {
+		if s.Score != 0 || s.Reason != "" {
+			hasSignal = true
+			break
+		}
+	}
+	if !hasSignal {
+		if d.Logger != nil {
+			d.Logger.Warn("rerank: no usable scores; preserving strategy order")
+		}
+		return enriched, usage
+	}
+
+	reordered := reorderByScore(enriched, scored)
+	if d.ReRank.TopK > 0 && len(reordered) > d.ReRank.TopK {
+		reordered = reordered[:d.ReRank.TopK]
+	}
+	return reordered, usage
+}
+
+// reorderByScore takes the loaded section slice and the model's
+// scored output (already sorted descending by score by the
+// ReRanker), and returns a new slice in the same order as scored
+// with each entry carrying the per-section score.
+//
+// Defensive: every input enriched section appears in the output
+// exactly once, in the order dictated by scored.
If scored is +// missing an input ID (shouldn't happen — ReRank's contract is to +// surface every input ID), that section is appended at the end with +// hasScore=false so the response stays complete. +func reorderByScore(enriched []sectionWithContent, scored []retrieval.ScoredSection) []sectionWithContent { + byID := make(map[tree.SectionID]int, len(enriched)) + for i, e := range enriched { + byID[e.sec.ID] = i + } + + out := make([]sectionWithContent, 0, len(enriched)) + taken := make([]bool, len(enriched)) + for _, s := range scored { + idx, ok := byID[s.ID] + if !ok || taken[idx] { + continue + } + taken[idx] = true + e := enriched[idx] + e.hasScore = true + e.score = s.Score + out = append(out, e) + } + // Append anything ReRank didn't surface — invariant says this + // should be empty, but a defence-in-depth check costs nothing. + for i, e := range enriched { + if !taken[i] { + out = append(out, e) + } + } + return out +} + // writePlanHints appends a short, model-readable "Planner notes" block // describing the structured plan. Synthesis uses this to orient itself // before reading the evidence. diff --git a/openapi.yaml b/openapi.yaml index f215556..8b35548 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -462,6 +462,19 @@ components: retrieval fans out one selection call per sub-question and unions the results. Overrides the server's `retrieval.planning.enabled` setting for this request only. + enable_rerank: + type: boolean + description: | + Opt this request into the Phase 2.3 content-aware re-rank + pass. After the retrieval strategy returns candidate + section IDs and the engine loads their content, one extra + LLM call scores each section (0-100) against the query and + sections are reordered descending by score. When + `retrieval.rerank.top_k` is set on the server, only the + top-k sections survive. Failures preserve the strategy's + original order — re-rank never drops sections. 
Overrides + the server's `retrieval.rerank.enabled` setting for this + request only. QueryResponse: type: object @@ -513,6 +526,14 @@ components: description: Full section content from storage. answer_span: $ref: "#/components/schemas/AnswerSpan" + score: + type: number + description: | + Re-rank relevance score on a 0-100 scale, populated only + when the request opted into the Phase 2.3 re-rank pass + (`enable_rerank`) or the server has + `retrieval.rerank.enabled=true`. Sections are returned + sorted descending by score. Omitted when no re-rank ran. AnswerSpan: type: object @@ -560,6 +581,14 @@ components: full semantics. When enabled, the synthesis prompt also sees the planner's structured intent and entity hints, and the response carries a top-level `plan` field. + enable_rerank: + type: boolean + description: | + Opt this request into the Phase 2.3 content-aware re-rank + pass. See QueryRequest.enable_rerank for full semantics. + When the pass runs, the synthesis prompt sees the + re-ranked top-k (capped by `retrieval.rerank.top_k`), and + each citation in the response carries a `score` field. AnswerResponse: type: object @@ -647,6 +676,9 @@ components: found one). `quote_start`/`quote_end` give byte offsets into the source section's content. `page_start`/`page_end` are the section's page range — omitted for non-paginated formats. + `score` carries the re-rank relevance score on a 0-100 scale, + present only when the request opted into the Phase 2.3 + re-rank pass. properties: section_id: type: string @@ -662,3 +694,9 @@ components: type: integer page_end: type: integer + score: + type: number + description: | + Re-rank relevance score on a 0-100 scale. Omitted when no + re-rank ran. Higher means the section is more directly + relevant to the query. 
diff --git a/pkg/config/config.go b/pkg/config/config.go index 28ad5ac..abfb8ce 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -204,6 +204,44 @@ type RetrievalConfig struct { AnswerSpan AnswerSpanBlock `yaml:"answer_span"` Answer AnswerBlock `yaml:"answer"` Planning PlanningBlock `yaml:"planning"` + ReRank ReRankBlock `yaml:"rerank"` +} + +// ReRankBlock configures the Phase 2.3 content-aware re-rank pass. +// +// When enabled, every /v1/query and /v1/answer request that returns +// candidate sections runs one extra LLM call: the candidates' first +// MaxContentChars chars of content are fed to the model with the +// query, and the model returns a per-section relevance score +// (0-100). Sections are reordered by score; if TopK > 0 the response +// is truncated to the top K. +// +// The pass is opt-in. Per-request `enable_rerank` body field +// overrides this block. +// +// Re-rank failures never drop sections — at worst the original +// strategy order is preserved and the request returns unchanged from +// the no-rerank path. See pkg/retrieval/rerank.go for the exact +// degradation contract. +type ReRankBlock struct { + // Enabled toggles re-rank at the server level. Default: false. + Enabled bool `yaml:"enabled"` + + // Model overrides the re-rank LLM model. Empty means use the + // request's model (which itself falls back to the engine default). + // Point this at a small/fast model — the re-rank prompt is short + // and shouldn't burn the flagship model's budget. + Model string `yaml:"model"` + + // MaxContentChars caps how many characters of each candidate's + // content are sent to the model. Default: 2000. + MaxContentChars int `yaml:"max_content_chars"` + + // TopK caps the number of sections kept after re-ranking. 0 means + // keep all candidates (re-rank only reorders). Useful when the + // strategy is configured to return a wide candidate list and the + // re-rank pass picks the focused top-k for synthesis. 
+ TopK int `yaml:"top_k"` } // PlanningBlock configures Phase 2.1 query planning + Phase 2.2 multi-hop @@ -368,6 +406,11 @@ func Default() Config { CacheSize: 128, Decompose: true, }, + ReRank: ReRankBlock{ + Enabled: false, + MaxContentChars: 2000, + TopK: 0, + }, }, Ingest: IngestConfig{ GlobalLLMConcurrency: 12, @@ -556,6 +599,27 @@ func applyEnvOverrides(c *Config) { c.Retrieval.Planning.Decompose = false } } + if v := os.Getenv("VLE_RETRIEVAL_RERANK_ENABLED"); v != "" { + switch strings.ToLower(strings.TrimSpace(v)) { + case "1", "true", "yes", "on": + c.Retrieval.ReRank.Enabled = true + case "0", "false", "no", "off": + c.Retrieval.ReRank.Enabled = false + } + } + if v := os.Getenv("VLE_RETRIEVAL_RERANK_MODEL"); v != "" { + c.Retrieval.ReRank.Model = v + } + if v := os.Getenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + c.Retrieval.ReRank.MaxContentChars = n + } + } + if v := os.Getenv("VLE_RETRIEVAL_RERANK_TOP_K"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n >= 0 { + c.Retrieval.ReRank.TopK = n + } + } } // Validate checks that required fields for the selected drivers are set. 
@@ -642,5 +706,12 @@ func (c Config) Validate() error { return fmt.Errorf("retrieval.planning.cache_size must be >= 0, got %d", c.Retrieval.Planning.CacheSize) } + if c.Retrieval.ReRank.MaxContentChars < 0 { + return fmt.Errorf("retrieval.rerank.max_content_chars must be >= 0, got %d", c.Retrieval.ReRank.MaxContentChars) + } + if c.Retrieval.ReRank.TopK < 0 { + return fmt.Errorf("retrieval.rerank.top_k must be >= 0, got %d", c.Retrieval.ReRank.TopK) + } + return nil } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index fb60973..64ba7d3 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -46,11 +46,123 @@ func TestDefaultValues(t *testing.T) { if !cfg.Retrieval.Planning.Decompose { t.Error("retrieval.planning.decompose should default to true (when planning is enabled)") } + if cfg.Retrieval.ReRank.Enabled { + t.Error("retrieval.rerank.enabled should default to false (opt-in)") + } + if cfg.Retrieval.ReRank.MaxContentChars != 2000 { + t.Errorf("retrieval.rerank.max_content_chars = %d, want 2000", cfg.Retrieval.ReRank.MaxContentChars) + } + if cfg.Retrieval.ReRank.TopK != 0 { + t.Errorf("retrieval.rerank.top_k = %d, want 0 (keep all)", cfg.Retrieval.ReRank.TopK) + } if cfg.Log.Level != "info" { t.Errorf("log.level = %q, want info", cfg.Log.Level) } } +func TestReRankEnvOverride(t *testing.T) { + // Not parallel — mutates env. Restore on exit. 
+ prevEnabled := os.Getenv("VLE_RETRIEVAL_RERANK_ENABLED") + prevModel := os.Getenv("VLE_RETRIEVAL_RERANK_MODEL") + prevMax := os.Getenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS") + prevTopK := os.Getenv("VLE_RETRIEVAL_RERANK_TOP_K") + defer func() { + os.Setenv("VLE_RETRIEVAL_RERANK_ENABLED", prevEnabled) + os.Setenv("VLE_RETRIEVAL_RERANK_MODEL", prevModel) + os.Setenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS", prevMax) + os.Setenv("VLE_RETRIEVAL_RERANK_TOP_K", prevTopK) + }() + + os.Setenv("VLE_RETRIEVAL_RERANK_ENABLED", "true") + os.Setenv("VLE_RETRIEVAL_RERANK_MODEL", "gemini-2.0-flash") + os.Setenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS", "1500") + os.Setenv("VLE_RETRIEVAL_RERANK_TOP_K", "3") + + cfg := Default() + applyEnvOverrides(&cfg) + + if !cfg.Retrieval.ReRank.Enabled { + t.Error("VLE_RETRIEVAL_RERANK_ENABLED=true should enable rerank") + } + if cfg.Retrieval.ReRank.Model != "gemini-2.0-flash" { + t.Errorf("rerank model = %q, want gemini-2.0-flash", cfg.Retrieval.ReRank.Model) + } + if cfg.Retrieval.ReRank.MaxContentChars != 1500 { + t.Errorf("rerank max_content_chars = %d, want 1500", cfg.Retrieval.ReRank.MaxContentChars) + } + if cfg.Retrieval.ReRank.TopK != 3 { + t.Errorf("rerank top_k = %d, want 3", cfg.Retrieval.ReRank.TopK) + } +} + +func TestReRankEnvOverrideDisable(t *testing.T) { + // Toggle off via env: start from a config that defaults to false, + // then set =false explicitly; verify the path executes (not just + // that the default value is preserved). 
+ prev := os.Getenv("VLE_RETRIEVAL_RERANK_ENABLED") + defer os.Setenv("VLE_RETRIEVAL_RERANK_ENABLED", prev) + + cfg := Default() + cfg.Retrieval.ReRank.Enabled = true // simulate a YAML-set true + os.Setenv("VLE_RETRIEVAL_RERANK_ENABLED", "false") + applyEnvOverrides(&cfg) + if cfg.Retrieval.ReRank.Enabled { + t.Error("VLE_RETRIEVAL_RERANK_ENABLED=false should disable rerank even when YAML set it true") + } +} + +func TestReRankEnvOverrideRejectsBad(t *testing.T) { + // Garbage env values should be rejected, not silently zero the field. + prevMax := os.Getenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS") + prevTopK := os.Getenv("VLE_RETRIEVAL_RERANK_TOP_K") + defer func() { + os.Setenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS", prevMax) + os.Setenv("VLE_RETRIEVAL_RERANK_TOP_K", prevTopK) + }() + + os.Setenv("VLE_RETRIEVAL_RERANK_MAX_CONTENT_CHARS", "not-a-number") + os.Setenv("VLE_RETRIEVAL_RERANK_TOP_K", "abc") + + cfg := Default() + applyEnvOverrides(&cfg) + if cfg.Retrieval.ReRank.MaxContentChars != 2000 { + t.Errorf("bad max_content_chars env should preserve default, got %d", cfg.Retrieval.ReRank.MaxContentChars) + } + if cfg.Retrieval.ReRank.TopK != 0 { + t.Errorf("bad top_k env should preserve default, got %d", cfg.Retrieval.ReRank.TopK) + } +} + +func TestValidateReRankNegatives(t *testing.T) { + t.Parallel() + + // Negative max_content_chars rejected. + cfg := Default() + cfg.Database.URL = "postgres://localhost/test" + cfg.Retrieval.ReRank.MaxContentChars = -1 + if err := cfg.Validate(); err == nil { + t.Error("negative max_content_chars should fail validation") + } + + // Negative top_k rejected. + cfg2 := Default() + cfg2.Database.URL = "postgres://localhost/test" + cfg2.Retrieval.ReRank.TopK = -1 + if err := cfg2.Validate(); err == nil { + t.Error("negative top_k should fail validation") + } + + // Zero on both is valid (TopK=0 means "keep all", MaxContentChars=0 + // means "use default"). 
+ cfg3 := Default() + cfg3.Database.URL = "postgres://localhost/test" + cfg3.Retrieval.ReRank.MaxContentChars = 0 + cfg3.Retrieval.ReRank.TopK = 0 + if err := cfg3.Validate(); err != nil { + t.Errorf("zero rerank values should pass validation: %v", err) + } +} + func TestPlanningEnvOverride(t *testing.T) { // Not parallel — mutates env. Restore on exit. prevEnabled := os.Getenv("VLE_RETRIEVAL_PLANNING_ENABLED") diff --git a/pkg/retrieval/rerank.go b/pkg/retrieval/rerank.go new file mode 100644 index 0000000..c32363b --- /dev/null +++ b/pkg/retrieval/rerank.go @@ -0,0 +1,431 @@ +package retrieval + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/hallelx2/llmgate" + + "github.com/hallelx2/vectorless-engine/pkg/tree" +) + +// ScoredSection is one entry in the re-ranker's output: a section ID, +// the model's relevance score (0-100), and an optional short reason. +// +// Scores are returned as float64 so the caller can sort stably and so +// future re-rankers (e.g. combining the model score with a cheap +// lexical prior) don't have to re-encode the field. The 0-100 scale is +// what the prompt asks the model to use; rerank.go preserves whatever +// the model returns (clamped to non-negative) rather than rescaling. +type ScoredSection struct { + ID tree.SectionID `json:"id"` + Score float64 `json:"score"` + Reason string `json:"reason,omitempty"` +} + +// SectionContent is one candidate passed to the re-ranker. The caller is +// responsible for loading Content from storage; the re-ranker does not +// touch the storage layer itself. +// +// Title is included separately from Content so the prompt can present +// both even when the section's body is empty (e.g. a structural-only +// section whose children carry the real text). Models tend to ignore a +// section whose body is blank if the title isn't surfaced explicitly. 
+type SectionContent struct {
+	// ID identifies the section; the model is asked to echo it back
+	// verbatim in its scores.
+	ID tree.SectionID
+	// Title is rendered next to the ID in the prompt, even when
+	// Content is empty.
+	Title string
+	// Content is the section body. Only a leading excerpt (bounded by
+	// ReRanker.MaxContentChars) is sent to the model.
+	Content string
+}
+
+// ReRanker re-orders a strategy's candidate sections by reading the first
+// chunk of each section's content and asking the LLM which sections
+// actually answer the query. This rescues the case where the retrieval
+// strategy reasoned over titles + summaries + HyDE candidate questions
+// alone and got fooled by surface-level signals.
+//
+// One LLM call per query, regardless of candidate count. Cost is bounded
+// by MaxContentChars per section × candidate count, plus a small
+// prompt overhead.
+//
+// Re-rank is intentionally tolerant about model output: bad JSON,
+// unknown IDs, and missing IDs are all handled gracefully so a single
+// model blip never drops a candidate. See ReRank for the exact contract.
+type ReRanker struct {
+	// LLM is the client used for the re-rank call. ReRank treats a nil
+	// client as a no-op and returns the input order.
+	LLM llmgate.Client
+
+	// Model is the model name passed to the LLM. Callers should point
+	// this at a small/fast model — the re-rank call is short and
+	// running it on the flagship model would defeat the cost story.
+	Model string
+
+	// MaxContentChars caps how much of each section's content is sent
+	// to the model. Default: defaultReRankMaxContentChars; zero or a
+	// negative value falls back to that default. The prompt builder
+	// compares this limit against len(content), so it counts bytes:
+	// multi-byte text gets fewer characters than the number suggests.
+	// Set higher when sections are long and the query needs more
+	// context to decide; set lower to tighten the budget.
+	MaxContentChars int
+
+	// MaxRetries bounds JSON-parse retries on a single re-rank call.
+	// Zero (or a negative value) defaults to defaultReRankRetries.
+	MaxRetries int
+}
+
+// defaultReRankMaxContentChars is the per-section content budget. ~2000
+// chars × 5 candidates ≈ 10k chars ≈ 2.5k tokens — comfortable inside
+// any modern model's context window and cheap on gemini-2.5-flash.
+const defaultReRankMaxContentChars = 2000
+
+// defaultReRankRetries mirrors the planning + selection retry counts.
+// Models occasionally drop out of JSON mode; one retry usually recovers.
+const defaultReRankRetries = 2
+
+// Output-token budget for the re-rank call. Every candidate yields one
+// JSON entry (id + score + a reason of up to ~120 chars), so the reply
+// grows linearly with the candidate count. A fixed cap truncates the
+// JSON on wide candidate lists, which makes every retry fail to parse
+// and silently degrades the pass to "input order". The per-candidate
+// figure is a rough allowance with headroom, not a measured value.
+const (
+	reRankMinOutputTokens          = 1024
+	reRankOutputTokensPerCandidate = 96
+)
+
+// NewReRanker constructs a ReRanker with sensible defaults. Callers can
+// override MaxContentChars/MaxRetries on the returned struct.
+func NewReRanker(client llmgate.Client, model string) *ReRanker {
+	return &ReRanker{
+		LLM:             client,
+		Model:           model,
+		MaxContentChars: defaultReRankMaxContentChars,
+		MaxRetries:      defaultReRankRetries,
+	}
+}
+
+// ReRank scores and reorders candidates by relevance to query. Returns
+// a slice the same length as candidates (in descending score order),
+// the cumulative LLM Usage, and an optional error.
+//
+// Failure semantics — the whole point of this method is to be safer
+// than a hard re-rank that can drop sections on model flakes:
+//
+//   - Empty candidates → returns (nil, zero Usage, nil) without any
+//     LLM call. Callers can pass an empty list unconditionally.
+//   - LLM transport failure → returns the input order (each ID with
+//     score=0) plus a non-nil error. The caller logs and keeps moving.
+//   - All retry attempts return un-parseable JSON → returns the input
+//     order with score=0 and a nil error. This mirrors how
+//     runSelectionWithRetry degrades: a single JSON glitch must not
+//     500 the request.
+//   - Response references unknown IDs → those entries are dropped;
+//     only IDs present in candidates surface.
+//   - Response is missing some input IDs → those IDs get score=0 and
+//     appear at the bottom of the output in their original relative
+//     order.
+//
+// In all cases, every input ID appears in the output exactly once.
+// This is the load-bearing invariant: re-rank can reorder, but it
+// never drops candidates.
+//
+// The requested output budget (MaxTokens) scales with len(candidates)
+// and never drops below reRankMinOutputTokens, so a wide candidate
+// list does not get its JSON reply cut off mid-array.
+func (r *ReRanker) ReRank(ctx context.Context, query string, candidates []SectionContent) ([]ScoredSection, Usage, error) {
+	if r == nil || r.LLM == nil {
+		// A nil re-ranker is treated as a no-op (returns input order)
+		// so production wiring can pass nil when re-rank is disabled.
+		return inputOrderScored(candidates), Usage{}, nil
+	}
+	if len(candidates) == 0 {
+		return nil, Usage{}, nil
+	}
+
+	// Non-positive knobs mean "use the default".
+	maxChars := r.MaxContentChars
+	if maxChars <= 0 {
+		maxChars = defaultReRankMaxContentChars
+	}
+	maxRetries := r.MaxRetries
+	if maxRetries <= 0 {
+		maxRetries = defaultReRankRetries
+	}
+
+	// Small lists keep the historical 1024-token cap; larger ones get
+	// room for one scored entry per candidate.
+	maxTokens := reRankMinOutputTokens
+	if need := len(candidates) * reRankOutputTokensPerCandidate; need > maxTokens {
+		maxTokens = need
+	}
+
+	prompt := buildReRankPrompt(query, candidates, maxChars)
+	baseReq := llmgate.Request{
+		Model: r.Model,
+		Messages: []llmgate.Message{
+			{Role: llmgate.RoleSystem, Content: reRankSystemPrompt},
+			{Role: llmgate.RoleUser, Content: prompt},
+		},
+		MaxTokens:   maxTokens,
+		Temperature: 0,
+		JSONMode:    true,
+		JSONSchema:  []byte(reRankJSONSchema),
+	}
+
+	scored, usage, err := runReRankWithRetry(ctx, r.LLM, baseReq, maxRetries)
+	if err != nil {
+		// Transport failure: preserve input order, surface the error
+		// so the caller can decide whether to log loud or quiet.
+		return inputOrderScored(candidates), usage, fmt.Errorf("rerank llm call: %w", err)
+	}
+	if scored == nil {
+		// All retries failed to parse. Degrade gracefully — same shape
+		// as a nil-rerank result so the response stays consistent.
+		return inputOrderScored(candidates), usage, nil
+	}
+
+	return mergeScored(candidates, scored), usage, nil
+}
+
+// reRankSystemPrompt frames the task. The 0-100 scale was picked over
+// 0-1 because models are noticeably better at returning a coarse
+// integer score than a fine-grained float, and the downstream code
+// only needs ordering. "The answer is in this section" is deliberately
+// phrased as "directly answers OR provides the load-bearing evidence
+// for" — the goal is to surface sections that close out the query, not
+// sections that merely mention the topic.
+const reRankSystemPrompt = `You are a precise relevance scorer. Given a user query and a list of candidate document sections (each shown with its ID, title, and the first portion of its content), score how well each section actually answers the query.
+ +Rules: +- Score each section on a 0-100 integer scale where: + - 90-100: the section directly answers the query OR provides the load-bearing evidence the answer relies on. + - 60-89: the section is highly relevant — it discusses the topic the query is about and likely contributes to the answer. + - 30-59: the section is tangentially related — it mentions the topic but probably does not answer the query. + - 0-29: the section is not useful for this query. Generic mentions, off-topic content, or content that only matches on a shared keyword without real relevance. +- Score every section in the input list. Do not skip any. +- Use the section IDs exactly as provided. Do not invent IDs. +- For each score include a one-line "reason" (≤120 chars) explaining the score. Keep it concrete — quote a phrase from the section when possible. + +Return only the JSON object described in the schema. No prose, no markdown.` + +const reRankJSONSchema = `{ + "type": "object", + "properties": { + "scores": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "score": {"type": "number"}, + "reason": {"type": "string"} + }, + "required": ["id", "score"] + } + } + }, + "required": ["scores"] +}` + +// buildReRankPrompt renders the user message. Sections are presented as +// "[id] Title\nContent excerpt" blocks, separated by blank lines so the +// model sees clear boundaries. Content is truncated at maxChars with an +// ellipsis when cut so the model can tell a section was long without +// having to count. 
+func buildReRankPrompt(query string, candidates []SectionContent, maxChars int) string {
+	var b strings.Builder
+	b.WriteString("User query:\n")
+	b.WriteString(query)
+	b.WriteString("\n\nCandidate sections:\n")
+
+	for i, c := range candidates {
+		if i > 0 {
+			b.WriteString("\n")
+		}
+		fmt.Fprintf(&b, "[%s] %s\n", string(c.ID), c.Title)
+		excerpt := strings.TrimSpace(c.Content)
+		if excerpt == "" {
+			b.WriteString("(section body is empty)\n")
+			continue
+		}
+		if len(excerpt) > maxChars {
+			// maxChars is a byte budget. Slicing at an arbitrary byte
+			// can split a multi-byte UTF-8 sequence and hand the model
+			// a broken character, so back off to the start of the rune
+			// the cut landed in. A rune has at most three continuation
+			// bytes (10xxxxxx), which bounds the walk even when the
+			// input is not valid UTF-8.
+			cut := maxChars
+			if cut < 0 {
+				// A negative budget would panic on the slice below.
+				cut = 0
+			}
+			for back := 0; back < 3 && cut > 0 && excerpt[cut]&0xC0 == 0x80; back++ {
+				cut--
+			}
+			excerpt = excerpt[:cut] + "…"
+		}
+		b.WriteString(excerpt)
+		b.WriteByte('\n')
+	}
+
+	b.WriteString("\nReturn a JSON object with a `scores` array. Each entry has `id` (string, exactly as shown above), `score` (number, 0-100), and `reason` (string, ≤120 chars). Score every section in the candidate list.")
+	return b.String()
+}
+
+// reRankPayload is the expected JSON shape.
+type reRankPayload struct {
+	Scores []reRankItem `json:"scores"`
+}
+
+// reRankItem is one entry of reRankPayload.Scores as emitted by the
+// model, before IDs are trimmed and scores clamped.
+type reRankItem struct {
+	ID     string  `json:"id"`
+	Score  float64 `json:"score"`
+	Reason string  `json:"reason"`
+}
+
+// ParseReRank extracts a re-rank result from an LLM JSON response.
+// Tolerates code-fence wrappers and leading/trailing prose, like
+// ParseSelection and ParsePlan. Returns the parsed []ScoredSection
+// (un-merged with the input candidate list — that's mergeScored's job)
+// and any parse error.
+func ParseReRank(raw string) ([]ScoredSection, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, fmt.Errorf("empty rerank response") + } + if strings.HasPrefix(raw, "```") { + if i := strings.Index(raw, "\n"); i >= 0 { + raw = raw[i+1:] + } + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + } + if i := strings.Index(raw, "{"); i > 0 { + raw = raw[i:] + } + if j := strings.LastIndex(raw, "}"); j >= 0 && j < len(raw)-1 { + raw = raw[:j+1] + } + + var p reRankPayload + if err := json.Unmarshal([]byte(raw), &p); err != nil { + return nil, fmt.Errorf("unmarshal rerank: %w", err) + } + out := make([]ScoredSection, 0, len(p.Scores)) + for _, it := range p.Scores { + id := strings.TrimSpace(it.ID) + if id == "" { + continue + } + score := it.Score + if score < 0 { + score = 0 + } + out = append(out, ScoredSection{ + ID: tree.SectionID(id), + Score: score, + Reason: strings.TrimSpace(it.Reason), + }) + } + return out, nil +} + +// runReRankWithRetry issues the re-rank LLM call, retrying up to +// maxRetries additional times when the response does not parse. Mirrors +// runSelectionWithRetry / runPlanningWithRetry. Returns (nil, usage, nil) +// when retries are exhausted so the caller falls back to input order. +func runReRankWithRetry(ctx context.Context, client llmgate.Client, baseReq llmgate.Request, maxRetries int) ([]ScoredSection, Usage, error) { + var totalUsage Usage + var lastParseErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + req := baseReq + if attempt > 0 { + msgs := make([]llmgate.Message, len(baseReq.Messages)) + copy(msgs, baseReq.Messages) + tail := len(msgs) - 1 + msgs[tail] = llmgate.Message{ + Role: msgs[tail].Role, + Content: msgs[tail].Content + "\n\nIMPORTANT: respond with ONLY a JSON object matching the schema. 
Do not include prose, explanation, or markdown fences.", + } + req.Messages = msgs + } + resp, err := client.Complete(ctx, req) + if err != nil { + return nil, totalUsage, err + } + totalUsage.Add(Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.TotalTokens, + CostUSD: resp.Usage.CostUSD, + LLMCalls: 1, + }) + scored, parseErr := ParseReRank(resp.Content) + if parseErr == nil { + return scored, totalUsage, nil + } + lastParseErr = parseErr + } + log.Printf("retrieval: rerank parse failed after %d attempts (%v); preserving input order", maxRetries+1, lastParseErr) + return nil, totalUsage, nil +} + +// inputOrderScored returns the candidates as ScoredSection entries with +// score=0, preserving the input order. Used as the safe fallback when +// the re-ranker can't produce a real score for any reason. +func inputOrderScored(candidates []SectionContent) []ScoredSection { + if len(candidates) == 0 { + return nil + } + out := make([]ScoredSection, len(candidates)) + for i, c := range candidates { + out[i] = ScoredSection{ID: c.ID, Score: 0} + } + return out +} + +// mergeScored combines the input candidates with the model's scored +// output. Every input ID appears in the result exactly once: +// +// - Known IDs (present in both input + model output) carry the +// model's score and reason, sorted descending by score. +// - Unknown IDs from the model output are dropped — the model +// hallucinated. +// - Input IDs missing from the model output get score=0 and appear +// at the bottom in input order (so the response is still useful +// to the caller). +// +// Ties on score preserve the original input order. This makes the +// output deterministic when the model returns equal scores and keeps +// reasonable behaviour when the model returns uniform scores (e.g. +// "everything is a 50") — the strategy's original order wins. 
+func mergeScored(candidates []SectionContent, scored []ScoredSection) []ScoredSection { + if len(candidates) == 0 { + return nil + } + // Position of each input ID, for stable ordering on ties. + pos := make(map[tree.SectionID]int, len(candidates)) + for i, c := range candidates { + pos[c.ID] = i + } + + // Index scored entries by ID. If the model returned the same ID + // twice, the first wins — defensive against duplicate entries. + byID := make(map[tree.SectionID]ScoredSection, len(scored)) + for _, s := range scored { + if _, known := pos[s.ID]; !known { + continue // hallucinated ID + } + if _, seen := byID[s.ID]; seen { + continue + } + byID[s.ID] = s + } + + out := make([]ScoredSection, 0, len(candidates)) + missing := make([]ScoredSection, 0) + for _, c := range candidates { + if s, ok := byID[c.ID]; ok { + out = append(out, s) + } else { + missing = append(missing, ScoredSection{ID: c.ID, Score: 0}) + } + } + + // Stable descending sort by score with original-order tiebreak. + // Hand-rolled because we want strict stability across equal + // scores and Go's sort.Slice is not stable. + insertionSortByScore(out, pos) + + // Append missing IDs (all score=0) at the bottom, in input order. + out = append(out, missing...) + return out +} + +// insertionSortByScore sorts ss descending by Score, with ties broken +// by the input position recorded in pos (lower pos → earlier). O(n²) +// is fine here: re-rank candidates are typically ≤20. 
+func insertionSortByScore(ss []ScoredSection, pos map[tree.SectionID]int) { + for i := 1; i < len(ss); i++ { + cur := ss[i] + curPos := pos[cur.ID] + j := i - 1 + for j >= 0 { + cmp := ss[j] + cmpPos := pos[cmp.ID] + if cmp.Score > cur.Score || (cmp.Score == cur.Score && cmpPos <= curPos) { + break + } + ss[j+1] = ss[j] + j-- + } + ss[j+1] = cur + } +} diff --git a/pkg/retrieval/rerank_test.go b/pkg/retrieval/rerank_test.go new file mode 100644 index 0000000..d20eadb --- /dev/null +++ b/pkg/retrieval/rerank_test.go @@ -0,0 +1,514 @@ +package retrieval_test + +import ( + "context" + "encoding/json" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/hallelx2/llmgate" + + "github.com/hallelx2/vectorless-engine/pkg/retrieval" + "github.com/hallelx2/vectorless-engine/pkg/tree" +) + +// contains is a tiny shim so we don't have to litter assertions with +// strings.Contains; the external _test package cannot share span_test.go's +// internal helper. +func contains(s, sub string) bool { return strings.Contains(s, sub) } + +// rerankMock is a minimal llmgate client that returns scripted replies +// in order. Each Complete() call advances the counter; when the +// scripted replies are exhausted the last reply is reused so a +// single-element script behaves like a fixed response. +// +// Kept separate from planner/single_pass mocks because the re-rank +// schema is distinct and the test surface is small enough that a +// dedicated mock keeps the assertions simple. 
+type rerankMock struct { + mu sync.Mutex + replies []string + err error + + calls int32 + prompts []string +} + +func (m *rerankMock) Complete(ctx context.Context, req llmgate.Request) (*llmgate.Response, error) { + m.mu.Lock() + defer m.mu.Unlock() + atomic.AddInt32(&m.calls, 1) + for _, msg := range req.Messages { + if msg.Role == llmgate.RoleUser { + m.prompts = append(m.prompts, msg.Content) + } + } + if m.err != nil { + return nil, m.err + } + if len(m.replies) == 0 { + return &llmgate.Response{}, nil + } + idx := int(atomic.LoadInt32(&m.calls)) - 1 + if idx >= len(m.replies) { + idx = len(m.replies) - 1 + } + return &llmgate.Response{ + Content: m.replies[idx], + Usage: llmgate.Usage{ + InputTokens: 100, + OutputTokens: 30, + TotalTokens: 130, + CostUSD: 0.0003, + }, + }, nil +} + +func (m *rerankMock) CountTokens(ctx context.Context, s string) (int, error) { + return len(s) / 4, nil +} + +// scoreReply marshals a list of (id, score, reason) tuples into the +// re-rank JSON envelope. +func scoreReply(items ...rerankReplyItem) string { + type payload struct { + Scores []rerankReplyItem `json:"scores"` + } + raw, _ := json.Marshal(payload{Scores: items}) + return string(raw) +} + +type rerankReplyItem struct { + ID string `json:"id"` + Score float64 `json:"score"` + Reason string `json:"reason,omitempty"` +} + +// sampleCandidates builds a small candidate list used by most tests. +// "sec-1" / "sec-2" / "sec-3" — short stable IDs so assertions are +// readable. 
+func sampleCandidates() []retrieval.SectionContent { + return []retrieval.SectionContent{ + {ID: tree.SectionID("sec-1"), Title: "Long-Term Debt", Content: "The company reported long-term debt of $4.2B as of Q4 2024."}, + {ID: tree.SectionID("sec-2"), Title: "Revenue Breakdown", Content: "Apple's fiscal 2023 revenue was $383.3B, down 2.8% YoY."}, + {ID: tree.SectionID("sec-3"), Title: "Risk Factors", Content: "Foreign currency translation may impact future revenue."}, + } +} + +// TestReRanker_HappyPath: model returns reordered scores, output is +// sorted descending by score. +func TestReRanker_HappyPath(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + // sec-2 is the most relevant (92), sec-3 next (45), sec-1 last (10). + // Strategy returned them in 1/2/3 order — re-rank should flip it. + scoreReply( + rerankReplyItem{ID: "sec-1", Score: 10, Reason: "long-term debt not relevant to revenue query"}, + rerankReplyItem{ID: "sec-2", Score: 92, Reason: "directly states fiscal 2023 revenue"}, + rerankReplyItem{ID: "sec-3", Score: 45, Reason: "tangential mention of revenue"}, + ), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + + got, usage, err := r.ReRank(context.Background(), "What was Apple's fiscal 2023 revenue?", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if len(got) != 3 { + t.Fatalf("len(got) = %d, want 3", len(got)) + } + wantOrder := []tree.SectionID{"sec-2", "sec-3", "sec-1"} + for i, w := range wantOrder { + if got[i].ID != w { + t.Errorf("got[%d].ID = %q, want %q (full order: %+v)", i, got[i].ID, w, got) + } + } + if got[0].Score != 92 { + t.Errorf("top score = %v, want 92", got[0].Score) + } + if got[0].Reason == "" { + t.Error("top entry should carry a reason from the model") + } + if usage.LLMCalls != 1 { + t.Errorf("Usage.LLMCalls = %d, want 1", usage.LLMCalls) + } + if usage.CostUSD <= 0 { + t.Errorf("Usage.CostUSD = %v, want > 0", usage.CostUSD) + } +} + +// 
TestReRanker_EmptyInput: nil/empty candidate list short-circuits +// without an LLM call. +func TestReRanker_EmptyInput(t *testing.T) { + t.Parallel() + m := &rerankMock{replies: []string{`{"scores":[]}`}} + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + + got, usage, err := r.ReRank(context.Background(), "irrelevant", nil) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if got != nil { + t.Errorf("empty input should return nil slice, got %+v", got) + } + if usage.LLMCalls != 0 { + t.Errorf("empty input should issue no LLM calls, got %d", usage.LLMCalls) + } + if c := atomic.LoadInt32(&m.calls); c != 0 { + t.Errorf("empty input must NOT call the LLM, got %d calls", c) + } +} + +// TestReRanker_LLMFailure: transport error bubbles up, input order is +// preserved with score=0 so the caller never loses a candidate. +func TestReRanker_LLMFailure(t *testing.T) { + t.Parallel() + m := &rerankMock{err: errors.New("provider 500")} + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + + cands := sampleCandidates() + got, _, err := r.ReRank(context.Background(), "any query", cands) + if err == nil { + t.Fatal("expected transport error, got nil") + } + if len(got) != len(cands) { + t.Fatalf("LLM failure must preserve all candidates: got %d, want %d", len(got), len(cands)) + } + for i, c := range cands { + if got[i].ID != c.ID { + t.Errorf("got[%d].ID = %q, want %q (preserved input order)", i, got[i].ID, c.ID) + } + if got[i].Score != 0 { + t.Errorf("got[%d].Score = %v, want 0 on transport failure", i, got[i].Score) + } + } +} + +// TestReRanker_BadJSONExhaustsRetries: when all retry attempts return +// un-parseable JSON, returns input order + nil error (graceful +// degradation, matching runSelectionWithRetry behaviour). 
+func TestReRanker_BadJSONExhaustsRetries(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + "sorry, here's some prose instead of JSON", + "still talking, not JSON", + "and one more time", + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + r.MaxRetries = 2 // 1 initial + 2 retries = 3 attempts + + cands := sampleCandidates() + got, usage, err := r.ReRank(context.Background(), "any query", cands) + if err != nil { + t.Fatalf("parse-only failure must return nil error, got %v", err) + } + if c := atomic.LoadInt32(&m.calls); c != 3 { + t.Errorf("expected 3 LLM attempts (1 + 2 retries), got %d", c) + } + if usage.LLMCalls != 3 { + t.Errorf("usage.LLMCalls = %d, want 3 (all attempts counted)", usage.LLMCalls) + } + if len(got) != len(cands) { + t.Fatalf("parse failure must preserve all candidates: got %d", len(got)) + } + for i, c := range cands { + if got[i].ID != c.ID { + t.Errorf("got[%d].ID = %q, want %q (input order)", i, got[i].ID, c.ID) + } + if got[i].Score != 0 { + t.Errorf("got[%d].Score = %v, want 0 on parse failure", i, got[i].Score) + } + } +} + +// TestReRanker_BadJSONThenSuccess: a single bad reply followed by a +// good one returns the parsed scores. +func TestReRanker_BadJSONThenSuccess(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + "not json", + scoreReply( + rerankReplyItem{ID: "sec-1", Score: 50}, + rerankReplyItem{ID: "sec-2", Score: 80}, + rerankReplyItem{ID: "sec-3", Score: 20}, + ), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + got, usage, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if usage.LLMCalls != 2 { + t.Errorf("LLMCalls = %d, want 2 (1 failed + 1 ok)", usage.LLMCalls) + } + if got[0].ID != "sec-2" { + t.Errorf("top = %q, want sec-2 (highest score)", got[0].ID) + } +} + +// TestReRanker_UnknownIDDropped: when the model invents an ID, it is +// silently dropped from the output. 
The known IDs still surface. +func TestReRanker_UnknownIDDropped(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + scoreReply( + rerankReplyItem{ID: "sec-1", Score: 30}, + rerankReplyItem{ID: "sec-2", Score: 70}, + rerankReplyItem{ID: "sec-3", Score: 50}, + rerankReplyItem{ID: "sec-bogus", Score: 99, Reason: "hallucinated"}, + ), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + got, _, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if len(got) != 3 { + t.Fatalf("len(got) = %d, want 3 (bogus ID dropped)", len(got)) + } + for _, s := range got { + if s.ID == "sec-bogus" { + t.Errorf("hallucinated ID leaked into output: %+v", s) + } + } + if got[0].ID != "sec-2" { + t.Errorf("top = %q, want sec-2", got[0].ID) + } +} + +// TestReRanker_MissingIDsScoreZero: input IDs the model didn't score +// appear at the bottom with score=0. +func TestReRanker_MissingIDsScoreZero(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + // Model only scored sec-2; sec-1 and sec-3 are missing. + scoreReply(rerankReplyItem{ID: "sec-2", Score: 88}), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + got, _, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if len(got) != 3 { + t.Fatalf("len(got) = %d, want 3 (missing IDs must still surface)", len(got)) + } + if got[0].ID != "sec-2" || got[0].Score != 88 { + t.Errorf("top = %+v, want sec-2 / 88", got[0]) + } + // Missing IDs come back with score=0 in input order. sec-1 was input + // position 0 and sec-3 was input position 2, so we expect sec-1 then sec-3. 
+ if got[1].ID != "sec-1" || got[1].Score != 0 { + t.Errorf("got[1] = %+v, want sec-1 / 0", got[1]) + } + if got[2].ID != "sec-3" || got[2].Score != 0 { + t.Errorf("got[2] = %+v, want sec-3 / 0", got[2]) + } +} + +// TestReRanker_DuplicateIDsInResponse: when the model returns the same +// ID twice the first wins and the duplicate is dropped. +func TestReRanker_DuplicateIDsInResponse(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + scoreReply( + rerankReplyItem{ID: "sec-1", Score: 60, Reason: "first"}, + rerankReplyItem{ID: "sec-1", Score: 10, Reason: "duplicate"}, + rerankReplyItem{ID: "sec-2", Score: 30}, + rerankReplyItem{ID: "sec-3", Score: 20}, + ), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + got, _, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if len(got) != 3 { + t.Fatalf("len(got) = %d, want 3", len(got)) + } + if got[0].ID != "sec-1" || got[0].Score != 60 || got[0].Reason != "first" { + t.Errorf("expected first occurrence to win, got %+v", got[0]) + } +} + +// TestReRanker_NegativeScoreClamped: a negative score from the model +// is clamped to 0 by ParseReRank. +func TestReRanker_NegativeScoreClamped(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{ + scoreReply( + rerankReplyItem{ID: "sec-1", Score: -5}, + rerankReplyItem{ID: "sec-2", Score: 30}, + rerankReplyItem{ID: "sec-3", Score: 0}, + ), + }, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + got, _, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + // Top should be sec-2 (30). sec-1 with score=-5 is clamped to 0; it + // will tie with sec-3 (0) and the stable-sort tiebreak puts sec-1 + // before sec-3 since that's the input order. 
+ if got[0].ID != "sec-2" { + t.Errorf("top = %q, want sec-2", got[0].ID) + } + for _, s := range got { + if s.Score < 0 { + t.Errorf("negative score leaked through: %+v", s) + } + } +} + +// TestReRanker_PromptIncludesContent: the prompt actually carries the +// candidate content (otherwise re-rank is back to title-only). +func TestReRanker_PromptIncludesContent(t *testing.T) { + t.Parallel() + m := &rerankMock{ + replies: []string{scoreReply( + rerankReplyItem{ID: "sec-1", Score: 50}, + rerankReplyItem{ID: "sec-2", Score: 50}, + rerankReplyItem{ID: "sec-3", Score: 50}, + )}, + } + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + _, _, err := r.ReRank(context.Background(), "What was Apple's fiscal 2023 revenue?", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + if len(m.prompts) != 1 { + t.Fatalf("want 1 captured user prompt, got %d", len(m.prompts)) + } + prompt := m.prompts[0] + // All three IDs surfaced. + for _, id := range []string{"sec-1", "sec-2", "sec-3"} { + if !contains(prompt, "["+id+"]") { + t.Errorf("prompt missing ID marker %q", id) + } + } + // Content excerpt for sec-2 must appear (the prompt is title + body). + if !contains(prompt, "Apple's fiscal 2023 revenue was $383.3B") { + t.Error("prompt missing sec-2 content excerpt") + } + if !contains(prompt, "What was Apple's fiscal 2023 revenue?") { + t.Error("prompt missing user query") + } +} + +// TestReRanker_MaxContentCharsTruncates: when content exceeds the cap +// the prompt carries a truncated excerpt with an ellipsis. 
+func TestReRanker_MaxContentCharsTruncates(t *testing.T) { + t.Parallel() + m := &rerankMock{replies: []string{scoreReply(rerankReplyItem{ID: "sec-1", Score: 50})}} + r := retrieval.NewReRanker(m, "gemini-2.5-flash") + r.MaxContentChars = 20 + + longContent := "AAAAAAAAAA BBBBBBBBBB CCCCCCCCCC DDDDDDDDDD EEEEEEEEEE" + _, _, err := r.ReRank(context.Background(), "Q", []retrieval.SectionContent{ + {ID: tree.SectionID("sec-1"), Title: "T", Content: longContent}, + }) + if err != nil { + t.Fatalf("ReRank: %v", err) + } + prompt := m.prompts[0] + if !contains(prompt, "AAAAAAAAAA BBBBBBBBB") { + t.Errorf("prompt missing first 20 chars of long content") + } + if !contains(prompt, "…") { + t.Error("prompt missing ellipsis marker for truncation") + } + if contains(prompt, "EEEEEEEEEE") { + t.Error("prompt unexpectedly contained tail of long content") + } +} + +// TestReRanker_NilLLMNoOp: a re-ranker with a nil LLM client returns +// input order without panicking. This lets server wiring pass a stub +// for the disabled case. +func TestReRanker_NilLLMNoOp(t *testing.T) { + t.Parallel() + r := &retrieval.ReRanker{Model: "any"} + got, usage, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("ReRank with nil LLM: %v", err) + } + if len(got) != 3 { + t.Errorf("nil LLM should preserve input order, got %d entries", len(got)) + } + if usage.LLMCalls != 0 { + t.Errorf("nil LLM should issue no calls, got %d", usage.LLMCalls) + } +} + +// TestReRanker_NilReRankerNoOp: a nil *ReRanker is safe — callers can +// pass nil when re-rank is disabled. 
+func TestReRanker_NilReRankerNoOp(t *testing.T) { + t.Parallel() + var r *retrieval.ReRanker + got, usage, err := r.ReRank(context.Background(), "Q", sampleCandidates()) + if err != nil { + t.Fatalf("nil ReRanker: %v", err) + } + if len(got) != 3 { + t.Errorf("nil ReRanker should preserve input order, got %d entries", len(got)) + } + if usage.LLMCalls != 0 { + t.Errorf("nil ReRanker should issue no calls, got %d", usage.LLMCalls) + } +} + +// TestParseReRank_CodeFence: tolerant parsing matches the other +// JSON-mode parsers in this package. +func TestParseReRank_CodeFence(t *testing.T) { + t.Parallel() + raw := "```json\n" + scoreReply(rerankReplyItem{ID: "sec-1", Score: 75}) + "\n```" + got, err := retrieval.ParseReRank(raw) + if err != nil { + t.Fatalf("ParseReRank: %v", err) + } + if len(got) != 1 || got[0].ID != "sec-1" || got[0].Score != 75 { + t.Errorf("got %+v", got) + } +} + +// TestParseReRank_LeadingProse: leading prose ahead of the JSON object +// is stripped. +func TestParseReRank_LeadingProse(t *testing.T) { + t.Parallel() + raw := "Sure, here are the scores: " + scoreReply(rerankReplyItem{ID: "sec-1", Score: 10}) + got, err := retrieval.ParseReRank(raw) + if err != nil { + t.Fatalf("ParseReRank: %v", err) + } + if len(got) != 1 { + t.Errorf("got %+v", got) + } +} + +// TestParseReRank_Empty: empty / whitespace-only responses error so +// retry can fire. +func TestParseReRank_Empty(t *testing.T) { + t.Parallel() + if _, err := retrieval.ParseReRank(""); err == nil { + t.Error("empty input should parse-error so retry fires") + } + if _, err := retrieval.ParseReRank(" \n\n "); err == nil { + t.Error("whitespace input should parse-error") + } +}