In [0]:
%pip install databricks-vectorsearch langchain databricks-sdk pdfplumber pymupdf tiktoken pyyaml
%restart_python

In [0]:
import os

from config.loader import init_config, get_config, save_config

# Base YAML config for this notebook. Set GEMRAG_CONFIG_PATH to point at another
# copy, so other users can run the notebook without editing a personal
# workspace path into this cell.
config_path = os.environ.get(
    "GEMRAG_CONFIG_PATH",
    "/Workspace/Users/benjamin.wynn@peraton.com/GEMRAG/config/base_config.yaml",
)
init_config(config_path)

In [0]:

import os
import re
import time
import datetime
import logging
import yaml
import pandas as pd
import matplotlib.pyplot as plt
from logging_utils.logger import get_logger
from logging_utils.logging_config import setup_logging
# === Project Modules ===

from pdf_importer import run_pdf_ingestion_pipeline
from utils.vector_search_utils import create_endpoint, delete_index, get_index, create_index, index_exists

from databricks.vector_search.client import VectorSearchClient
from utils.rag_analytics import *
from logging_utils.run_logger import setup_run_logger

In [0]:
# === Experiment Directory ===
EXPERIMENT_ROOT = "experiments"
os.makedirs(EXPERIMENT_ROOT, exist_ok=True)

setup_logging(level="INFO")
logger = get_logger(__name__)
logger.info("Notebook started")

from pyspark.sql.types import StructType, StructField, StringType

config = get_config()

schema = StructType([
    StructField("file_path", StringType()),
    StructField("title", StringType()),
    StructField("author", StringType()),
    StructField("subject", StringType()),
    StructField("keywords", StringType()),
    StructField("creation_date", StringType()),
    StructField("mod_date", StringType()),
    StructField("content", StringType())
])

In [0]:
def construct_query(row, columns):
    """Build the free-text similarity-search query for one claim line.

    Each column contributes ``"<column> <value>dm, "``. All characters that are
    neither word characters nor whitespace are then stripped. As a result, the
    commas disappear and the result is a space-separated run of column names
    and values, ending with a trailing space.

    Args:
        row: Record indexable by column label (e.g. the pandas ``Series``
            yielded by ``DataFrame.iterrows``).
        columns: Iterable of column labels to include, in order.

    Returns:
        str: Query text containing only ``\\w`` characters and whitespace.
    """
    # NOTE(review): the literal "dm" glued onto every value looks like a typo
    # (it merges into the value token, e.g. "2dm"). It is left unchanged because
    # removing it changes search results — confirm before dropping it.
    query = "".join(f"{column} {row[column]}dm, " for column in columns)
    return re.sub(r'[^\w\s]', '', query)

In [0]:

# === Data Processing Pipeline ===
def process_policy_data(
    source_table="sandbox_catalog.default.gem_text",
    deletion_timeout_s=600.0,
    poll_interval_s=5.0,
):
    """Ingest the policy PDFs and create (or fetch) the Vector Search index.

    Runs the PDF ingestion pipeline into the configured text/chunk tables.
    It then ensures the Vector Search endpoint exists and optionally drops the
    existing index when ``vector_search.delete_old_index`` is set. Finally it
    calls ``create_index``.

    Reads the module-level ``config``: ``catalog_name``, ``schema_name``, and
    the ``document_parsing`` and ``vector_search`` sections.

    Args:
        source_table: Fully-qualified table passed to ``create_index`` as the
            index source. Defaults to the table previously hard-coded here.
        deletion_timeout_s: Maximum seconds to wait for the old index to
            disappear after ``delete_index``.
        poll_interval_s: Seconds to sleep between existence checks while waiting.

    Returns:
        Whatever ``create_index`` returns (the index handle).

    Raises:
        TimeoutError: If the old index still exists after ``deletion_timeout_s``.
    """
    vs_config = config["vector_search"]
    doc_config = config["document_parsing"]
    # Named schema_name so it does not shadow the notebook-level Spark `schema`.
    catalog_name, schema_name = config["catalog_name"], config["schema_name"]
    endpoint_name = vs_config["endpoint_name"]
    index_name = vs_config["index_name"]

    # NOTE(review): the volume path hard-codes sandbox_catalog/default instead of
    # catalog_name/schema_name — confirm this is intended.
    doc_vol_path = f"/Volumes/sandbox_catalog/default/{doc_config['document_volume']}/"
    text_table_path = f"{catalog_name}.{schema_name}.{doc_config['text_table']}"
    chunk_table_path = f"{catalog_name}.{schema_name}.{doc_config['chunk_table']}"
    run_pdf_ingestion_pipeline(doc_vol_path, text_table_path, chunk_table_path)

    # Setup vector search endpoint or fetch existing endpoint
    client = VectorSearchClient(disable_notice=True)
    create_endpoint(client, endpoint_name)

    # Deleting the old index is only necessary if columns are added or removed.
    index_present = index_exists(client, endpoint_name, index_name)
    logger.info(f"Index {index_name} exists: {index_present}")
    if vs_config["delete_old_index"] and index_present:
        delete_index(client, index_name)
        # Poll until the index is gone. The wait is bounded so that a stuck
        # deletion cannot hang the notebook forever.
        deadline = time.monotonic() + deletion_timeout_s
        while index_exists(client, endpoint_name, index_name):
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Index {index_name} still exists {deletion_timeout_s}s after delete_index"
                )
            logger.info("Waiting for index deletion...")
            time.sleep(poll_interval_s)

    # NOTE(review): source_table defaults to gem_text in sandbox_catalog.default,
    # not text_table_path/chunk_table_path written above — confirm which table
    # should back the index.
    index = create_index(
        client,
        endpoint_name,
        index_name,
        source_table,
        vs_config["primary_key"],
        vs_config["source_column"],
        vs_config["indexed_columns"],
        vs_config["embedding_endpoint"],
    )
    return index




In [0]:
# === Search and Evaluation ===
def test_search_accuracy(index, config):
    """Query the index once per test claim line and score the returned matches.

    Args:
        index: Index handle exposing ``similarity_search`` (as returned by
            ``get_index`` / ``create_index``).
        config: Loaded configuration mapping. The ``experiment`` section
            (``test_data_path``, ``n_results``, ``search_type``) and the
            ``vector_search`` section (``indexed_columns``) are read.

    Returns:
        dict: ``"results"`` is a DataFrame with one row per returned match,
        tagged with ``claim`` (the claim_id) and ``query_time`` (seconds for
        that claim's query). ``"metrics"`` is the output of
        ``calculate_metrics_parallel``.

    Raises:
        ValueError: If the test CSV contains no claim lines.
    """
    experiment = config["experiment"]
    vs_config = config["vector_search"]
    n_results = experiment["n_results"]

    test_df = pd.read_csv(experiment["test_data_path"])
    policy = test_df["policy"]
    # Every non-label column (including claim_id) feeds the query text.
    claim_lines = test_df.drop(columns="policy")

    total = claim_lines.shape[0]
    logger.info(f"Testing {total} claim lines...")
    if total == 0:
        # Previously this surfaced as a ZeroDivisionError on the average runtime.
        raise ValueError(f"No claim lines found in {experiment['test_data_path']}")

    query_columns = claim_lines.columns
    results_list = []
    last_percent = -1
    test_start_time = time.perf_counter()

    # enumerate gives the positional row number; the iterrows label is only
    # positional for a default RangeIndex.
    for position, (_, row) in enumerate(claim_lines.iterrows()):
        percent = int((position / total) * 100)
        if percent != last_percent and percent % 5 == 0:
            logger.info(f"Progress: {percent}%")
            last_percent = percent

        query_text = construct_query(row, query_columns)
        start_time = time.perf_counter()
        response = index.similarity_search(
            query_text=query_text,
            columns=vs_config["indexed_columns"],
            num_results=n_results,
            query_type=experiment["search_type"],
            disable_notice=True
        )
        elapsed = time.perf_counter() - start_time

        column_specs = response.get("manifest", {}).get("columns", [])
        column_names = [spec.get("name") for spec in column_specs]
        raw_results = response.get("result", {}).get("data_array", [])

        # to_dict("records") keeps per-column values; iterrows would coerce each
        # row to a single common dtype.
        matches = pd.DataFrame(raw_results[:n_results], columns=column_names)
        for match in matches.to_dict(orient="records"):
            results_list.append({
                "claim": row["claim_id"],
                "query_time": elapsed,
                **match,
            })

    test_runtime = time.perf_counter() - test_start_time

    logger.info("Total query test runtime: %s seconds", test_runtime)
    logger.info("Average query test runtime: %s seconds", test_runtime / total)
    logger.info("Completed Query Testing")
    results_df = pd.DataFrame(results_list)

    logger.info("Calculating Query Metrics")
    metrics_df = calculate_metrics_parallel(claim_lines, policy, results_df, n_results)

    return {"results": results_df, "metrics": metrics_df}

In [0]:

# === Main Pipeline Test Runner ===
def run_pipeline_test(allow_index_creation=False, save_tables=True):
    """Run one experiment end to end: snapshot config, get the index, evaluate search.

    Creates ``EXPERIMENT_ROOT/<experiment name>_<YYYYmmdd_HHMMSS>``. The folder
    holds the run log and a ``config.yaml`` snapshot. When
    ``experiment.test_search`` is set, it also holds the results/metrics CSVs
    and the plots.

    Args:
        allow_index_creation: Gate for (re)building the index. Creation used to
            be hard-disabled (``if False and ...``). It now runs only when this
            flag is True AND ``experiment.create_index`` is set. The default
            keeps the old behaviour of always fetching the existing index.
        save_tables: Write ``results.csv`` and ``metrics.csv`` into the run
            folder (the log lines reporting them as saved are emitted only
            when they actually are).

    Returns:
        The ``{"results", "metrics"}`` dict from ``test_search_accuracy`` when
        ``experiment.test_search`` is set, otherwise ``None``.
    """
    # Step 1: Load config
    experiment = config["experiment"]
    vs_config = config["vector_search"]

    # Step 2: Setup experiment output directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = experiment.get("name", "unnamed_experiment")
    output_dir = os.path.join(EXPERIMENT_ROOT, f"{experiment_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Step 3: Setup experiment-specific logging
    # NOTE(review): run_id is not unique per run (no timestamp). If setup_run_logger
    # keys handlers on it, repeated runs in one kernel may duplicate handlers — verify.
    run_logger = setup_run_logger(run_id=experiment_name, log_dir=output_dir)
    run_logger.info(f"Running experiment: {experiment_name}")
    run_logger.info(f"Logs saved to {output_dir}")
    # Step 4: Save config
    save_config(os.path.join(output_dir, "config.yaml"))

    # Step 5: Build or load index
    # .get so a config without create_index still works, matching the old
    # short-circuited `False and ...` which never read the key.
    wants_index_creation = experiment.get("create_index", False)
    if allow_index_creation and wants_index_creation:
        run_logger.info("Creating index")
        start_time = time.time()
        index = process_policy_data()
        end_time = time.time()

        run_logger.info("Total policy ingestion test runtime: %s seconds" % (end_time - start_time))
    else:
        if wants_index_creation:
            run_logger.info(
                "experiment.create_index is set but allow_index_creation=False; "
                "using the existing index."
            )
        run_logger.info("Skipping index creation, fetching existing index.")
        client = VectorSearchClient(disable_notice=True)
        index = get_index(
            client,
            vs_config["endpoint_name"],
            vs_config["index_name"]
        )

    # Step 6: Perform search and evaluation
    if not experiment["test_search"]:
        return None

    results_obj = test_search_accuracy(index=index, config=config)

    run_logger.info("Saving results to experiments directory.")
    results_path = os.path.join(output_dir, "results.csv")
    metrics_path = os.path.join(output_dir, "metrics.csv")
    plot_path = os.path.join(output_dir, "plot")

    if save_tables:
        # Assumes calculate_metrics_parallel returns a DataFrame (it is bound as
        # metrics_df in test_search_accuracy) — confirm.
        results_obj["results"].to_csv(results_path, index=False)
        results_obj["metrics"].to_csv(metrics_path, index=False)

    # NOTE(review): all four plot helpers receive the same plot_path — presumably
    # each adds its own suffix; verify in utils.rag_analytics.
    save_metrics_plot(results_obj["metrics"], plot_path)
    save_timings_plot(results_obj["results"]['query_time'], plot_path)
    save_precision_recall_plot(results_obj["metrics"], plot_path)
    save_metric_distributions(results_obj["metrics"], plot_path)

    if save_tables:
        run_logger.info(f"Results saved to {results_path}")
        run_logger.info(f"Metrics saved to {metrics_path}")

    return results_obj


In [0]:

# Run the experiment described by the loaded config. The result is the
# {"results", "metrics"} dict, or None when experiment.test_search is disabled.
test_res = run_pipeline_test()