In [None]:
# config_path="../configs/query_encoder/config_PwC-Embedding_expr.json"
# data_schema="/workspace/configs/csv_schema/test_2.json"
# docs_jsonl_path="/workspace/results/searched_docs_db/search_documents_test.jsonl"
# auto_data_load=True


In [None]:
# main.py
import os
import json
from typing import List
from utils.load_json import load_json as load_config
from utils.load_jsonl_and_make_text_for_embedding import (
    load_jsonl_and_make_text_for_embedding as load_jsonl_docs,
)
from utils.create_class_from_schema import create_class_from_schema

# Import the newly created modules
from features.embedding_processor import generate_batch_embeddings
from data_handler.for_embedding import prepare_documents, save_results


def build_vectordb_search(
    config_path="../configs/query_encoder/config_PwC-Embedding_expr.json",
    data_schema="/workspace/configs/csv_schema/test_2.json",
    docs_jsonl_path="/workspace/results/searched_docs_db/search_documents_test.jsonl",
    auto_data_load=True,
):
    """Embed a JSONL document collection and hand it to ``save_results``.

    Pipeline:
      1. Load the JSON config at ``config_path`` and build a ``Document``
         class from the schema at ``data_schema``.
      2. Load the source documents from a JSONL file.
      3. Embed them with the model named by the config's ``"model_name"``.
      4. Combine raw records and embeddings into document objects.
      5. Pass everything (including ``config_path``) to ``save_results``.

    Args:
        config_path: Path to the JSON config. Must provide ``"model_name"``;
            may provide ``"jsonl_path"``.
        data_schema: Schema path passed to ``create_class_from_schema``.
        docs_jsonl_path: JSONL document file. Used directly when
            ``auto_data_load`` is False, and as the fallback when the config
            has no ``"jsonl_path"`` entry.
        auto_data_load: When True, prefer the config's ``"jsonl_path"``.

    Returns:
        None. Every failure path prints a message and returns early.
    """
    # 1. Load configurations and create dynamic document class
    config = load_config(config_path)
    try:
        DynamicDocument = create_class_from_schema("Document", data_schema)
    except Exception as e:
        # Pipeline boundary: report and stop rather than propagate.
        print(f"Failed to load schema and create class: {e}")
        return

    # 2. Load source documents.
    # In auto mode the config's "jsonl_path" wins. Without it, fall back to the
    # explicit docs_jsonl_path -- never to config_path, which is the JSON config
    # file itself and not a JSONL document file.
    if auto_data_load:
        jsonl_path = config.get("jsonl_path", docs_jsonl_path)
    else:
        jsonl_path = docs_jsonl_path
    print(f"Loading documents from {jsonl_path}: Auto Loading Data {auto_data_load}")
    documents_data = load_jsonl_docs(jsonl_path)
    print(f"Loaded {len(documents_data)} documents")

    if not documents_data:
        print("No documents found. Exiting.")
        return

    # 3. Generate embeddings (model/vector generation only).
    model_name = config.get("model_name")
    if not model_name:
        print("'model_name' not found in config. Exiting.")
        return
    embeddings = generate_batch_embeddings(documents_data, model_name)

    if embeddings is None:
        print("Embedding generation failed. Exiting.")
        return

    # 4. Prepare document objects for saving (raw records + embeddings).
    documents_to_save = prepare_documents(documents_data, embeddings, DynamicDocument)

    if not documents_to_save:
        print("No documents were successfully prepared for saving. Exiting.")
        return

    # 5. Save results and update config (all file I/O lives in save_results).
    save_results(
        config=config,
        documents_to_save=documents_to_save,
        embedding_shape=embeddings.shape,
        document_class=DynamicDocument,
        config_path=config_path,
        model_name=model_name,
    )


In [None]:
# -*- coding: utf-8 -*-
"""
Dense Retrieval Script with Pluggable Retrievers.
Refactored for modularity and extensibility.
"""
from __future__ import annotations

import argparse
import json
import sys
from typing import Any, Dict, List

from . import data_loader, query_encoder, result_saver
from .retrievers import get_retriever, base as retriever_base

def run_retrieval_for_question(
    q_item: data_loader.QuestionItem,
    encoder: query_encoder.QueryEncoder,
    retriever: retriever_base.Retriever,
    vectordb: data_loader.VectorDB,
    top_k: int,
    query_instruction: str,
) -> List[Dict[str, Any]]:
    """Retrieve the top-k documents for a question and each of its sub-questions.

    The original question and every entry of ``q_item.single_hop_questions``
    are encoded in a single batch and searched together.

    Args:
        q_item: Question to process; ``original_question`` and
            ``single_hop_questions`` (may be empty or None) become the queries.
        encoder: Object whose ``encode(queries, instruction=...)`` returns the
            query vectors.
        retriever: Object whose ``search(vectors, top_k=...)`` returns
            ``(scores, indices)``, indexed as ``[query_row, rank]``.
        vectordb: Supplies ``doc_ids`` and per-document ``metadata`` mappings,
            both addressed by the row indices returned from ``retriever``.
        top_k: Number of documents requested per query.
        query_instruction: Instruction forwarded to ``encoder.encode``; may be None.

    Returns:
        One dict per query, in query order, with keys ``"query"``,
        ``"query_meta"`` (``{"type": "original"}`` or
        ``{"type": "single_hop", "index": i}``) and ``"hits"``. Each hit has
        ``"rank"`` (1-based), ``"score"``, ``"doc_id"`` plus the document's
        metadata fields. Metadata never overrides those three fields.
    """
    sub_questions = list(q_item.single_hop_questions or [])
    all_queries = [q_item.original_question, *sub_questions]
    query_metas = [{"type": "original"}] + [
        {"type": "single_hop", "index": i} for i in range(len(sub_questions))
    ]

    query_vecs = encoder.encode(all_queries, instruction=query_instruction)
    scores, indices = retriever.search(query_vecs, top_k=top_k)

    results = []
    for i, (query_text, query_meta) in enumerate(zip(all_queries, query_metas)):
        hits = []
        for rank in range(scores.shape[1]):
            doc_idx = int(indices[i, rank])
            # FAISS-style searches pad with -1 when the index holds fewer than
            # top_k vectors; -1 would otherwise silently select the last doc.
            if doc_idx < 0:
                continue
            hit = {
                "rank": rank + 1,
                "score": float(scores[i, rank]),
                "doc_id": vectordb.doc_ids[doc_idx],
            }
            # Merge metadata without letting it clobber rank/score/doc_id.
            for key, value in vectordb.metadata[doc_idx].items():
                hit.setdefault(key, value)
            hits.append(hit)
        results.append({"query": query_text, "query_meta": query_meta, "hits": hits})

    return results


def main():
    """Run dense retrieval for a set of questions from the command line.

    Loads the model config, the vector DB CSV and the questions JSONL, and
    optionally narrows the questions with ``--ids`` / ``--range``. It then
    retrieves the top-k documents for every query of each selected question
    and saves one result per question via ``ResultSaver``. Finally it prints a
    JSON summary (saved paths and output folder) to stdout. Status messages
    go to stderr, so stdout carries only that summary.

    Raises:
        ValueError: If the config JSON lacks a non-empty ``"model_name"``.
        SystemExit: Via argparse, on missing or invalid command-line arguments.
    """
    p = argparse.ArgumentParser(description="Dense Retrieval with a configured model.")
    p.add_argument("--vectordb_csv", type=str, required=True, help="Path to the vector DB CSV file.")
    p.add_argument("--config_json", type=str, required=True, help="Path to the model configuration JSON.")
    p.add_argument("--questions_jsonl", type=str, required=True, help="Path to the questions JSONL file.")
    p.add_argument("--ids", type=str, help="Comma-separated question IDs to process.")
    p.add_argument("--range", dest="range_spec", type=str, help="1-based inclusive range like '1-10'.")
    p.add_argument("--top_k", type=int, default=10, help="Number of documents to retrieve.")
    p.add_argument("--device", type=str, default="auto", help="PyTorch device ('auto', 'cpu', 'cuda:0').")
    p.add_argument("--output_root", type=str, default="/workspace/results/retrieval_docs", help="Root directory for outputs.")
    p.add_argument("--query_instruction", type=str, default=None, help="Instruction to prepend to queries. Overrides model default.")
    p.add_argument("--force_numpy", action="store_true", help="Force using NumPy retriever even if FAISS is available.")
    args = p.parse_args()

    # Fail fast on bad arguments before any expensive loading.
    if args.top_k < 1:
        p.error("--top_k must be a positive integer.")

    # Accept "q1, q2" and trailing commas: strip whitespace and drop empty
    # entries. Otherwise " q2" or "" would be passed through and likely match
    # no question ID.
    ids = None
    if args.ids:
        ids = [qid.strip() for qid in args.ids.split(",") if qid.strip()]
        if not ids:
            p.error("--ids was given but contains no question IDs.")

    # --- 1. Load Data ---
    print("[INFO] Loading data...", file=sys.stderr)
    with open(args.config_json, "r", encoding="utf-8") as f:
        config = json.load(f)
    model_name = config.get("model_name")
    if not model_name:
        raise ValueError("'model_name' not found in config JSON.")

    vectordb = data_loader.load_vectordb_from_csv(args.vectordb_csv)
    all_questions = data_loader.load_questions_jsonl(args.questions_jsonl)

    selected_questions = data_loader.select_questions(
        all_questions,
        ids=ids,
        range_spec=args.range_spec,
    )
    print(f"[INFO] Loaded {len(vectordb.doc_ids)} documents.", file=sys.stderr)
    print(f"[INFO] Selected {len(selected_questions)} questions to process.", file=sys.stderr)

    # --- 2. Initialize Components ---
    encoder = query_encoder.QueryEncoder(model_name=model_name, device=args.device)
    retriever = get_retriever(vectordb.embeddings, force_numpy=args.force_numpy)
    saver = result_saver.ResultSaver(args.output_root)

    # --- 3. Run Retrieval Loop ---
    saved_paths = []
    for q_item in selected_questions:
        retrieved_data = run_retrieval_for_question(
            q_item, encoder, retriever, vectordb, args.top_k, args.query_instruction
        )
        path = saver.save(q_item.qid, model_name, retrieved_data, q_item.meta)
        saved_paths.append(path)
        print(f"[OK] Saved result for QID {q_item.qid} to {path}", file=sys.stderr)

    # --- 4. Final Output (stdout only; everything else went to stderr) ---
    print(json.dumps({"saved_files": saved_paths, "output_folder": saver.session_folder}, indent=2))


# Script entry point.
# NOTE(review): when this cell is executed inside a Jupyter kernel, __name__ is
# also "__main__", so main() runs and parses the kernel's own sys.argv (argparse
# will likely exit with a usage error) -- confirm this cell is meant to be run
# only as a script.
if __name__ == "__main__":
    main()
