In [1]:
import os

import polars as pl
from neo4j import GraphDatabase

from embedder import Retriever, Gemini_Embeddings

## Database Connections
# Connection settings can be overridden through environment variables so that
# real credentials never have to be committed with the notebook. The fallbacks
# are the previous local-development values, so an unconfigured environment
# behaves exactly as before.
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
AUTH = (
    os.environ.get("NEO4J_USERNAME", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "fairusecases"),
)

# Single driver instance shared by the rest of the notebook.
driver = GraphDatabase.driver(URI, auth=AUTH)

In [None]:
# Retriever built from a Gemini embedding object and the Neo4j driver
# created in the connection cell.
retriever = Retriever(
    Gemini_Embeddings(),
    driver,
)

# Input table; the retrieval cells below read its "Complaint" column.
df = pl.read_csv(
    "../Data/case_complaints.csv",
)

In [None]:
## Helper function to get all the cases
def get_retrieved_cases(text, text_sim, court_weight, cit_weight, retriever=retriever, top_k=10):
    """Run one similar-case search and return the result columns as plain lists.

    Args:
        text: Query text passed to ``retriever.search_similar_cases``.
        text_sim: Forwarded unchanged to the search call. Presumably the
            weight of the text-similarity signal -- confirm against Retriever.
        court_weight: Forwarded unchanged to the search call. Presumably the
            weight of the court signal -- confirm against Retriever.
        cit_weight: Forwarded unchanged to the search call. Presumably the
            weight of the citation signal -- confirm against Retriever.
        retriever: Object exposing ``search_similar_cases``. The default is
            evaluated once, when this cell defines the function, so re-run
            this cell after rebuilding the notebook-level ``retriever``.
        top_k: Second positional argument of ``search_similar_cases``;
            previously hard-coded to 10. Presumably the number of cases to
            return -- TODO confirm against Retriever.

    Returns:
        dict with the keys ``"cases"``, ``"text_sim"``, ``"pagerank"`` and
        ``"court"``, each holding the ``.to_list()`` of the result columns
        ``Case``, ``TextSimilarity``, ``CasePageRank`` and ``CourtName``.
    """
    df_cases = retriever.search_similar_cases(text, top_k, text_sim, court_weight, cit_weight)

    # Output key -> column of the search result it is read from.
    column_for_key = {
        "cases": "Case",
        "text_sim": "TextSimilarity",
        "pagerank": "CasePageRank",
        "court": "CourtName",
    }

    return {key: df_cases[column].to_list() for key, column in column_for_key.items()}

In [None]:
## Standard RAG
# Arguments (1, 0, 0) are passed to get_retrieved_cases as
# text_sim, court_weight and cit_weight respectively.
df_standard_RAG = (
    df
    .with_columns(
        # One Python call per complaint; the returned dict becomes a struct column.
        retrieved=pl.col("Complaint").map_elements(
            lambda complaint: get_retrieved_cases(complaint, 1, 0, 0)
        )
    )
    # Promote the struct fields (cases, text_sim, pagerank, court) to columns.
    .unnest("retrieved")
)

# Exploding the four list columns together yields one row per retrieved case.
# NOTE(review): the input is read from "../Data" but results are written under
# "./Data" -- confirm this difference is intentional.
(
    df_standard_RAG
    .explode(["cases", "text_sim", "pagerank", "court"])
    .write_csv("./Data/StandardRAGRetrieval.csv")
)

In [None]:
## Structured RAG
# Arguments (.33, .33, .33) are passed to get_retrieved_cases as
# text_sim, court_weight and cit_weight respectively.
df_pagerank = (
    df
    .with_columns(
        # One Python call per complaint; the returned dict becomes a struct column.
        retrieved=pl.col("Complaint").map_elements(
            lambda complaint: get_retrieved_cases(complaint, .33, .33, .33)
        )
    )
    # Promote the struct fields (cases, text_sim, pagerank, court) to columns.
    .unnest("retrieved")
)

# Exploding the four list columns together yields one row per retrieved case.
# NOTE(review): the input is read from "../Data" but results are written under
# "./Data" -- confirm this difference is intentional.
(
    df_pagerank
    .explode(["cases", "text_sim", "pagerank", "court"])
    .write_csv("./Data/PRRAGRetrieval.csv")
)