In [None]:
import os
import sys
from ast import literal_eval
from typing import Any, Dict, List, Union


# Make the parent of the kernel's working directory importable so that the
# `from src.grag import ...` line below resolves.
# NOTE(review): assumes the kernel is started one level below the project root — confirm.
sys.path.append(os.path.join(os.getcwd(), ".."))


import pandas as pd
from dotenv import load_dotenv
from langchain_neo4j import Neo4jGraph
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
from langchain_huggingface import HuggingFaceEmbeddings
from ragas.dataset_schema import EvaluationDataset, EvaluationResult
from src.grag import run_text2cypher_workflow, evaluate_retriever


# Populate os.environ from a .env file; later cells read NEO4J_HOST,
# NEO4J_USERNAME, NEO4J_PASSWORD and ANTHROPIC_API_KEY from the environment.
load_dotenv()

## **Preparation**

In [None]:
OUTPUT_PATH = os.path.join("results", "text2cypher_retriever")
DATASET_PATH = os.path.join("data", "testing_dataset.xlsx")

# Results directory must exist before run_test_case writes its JSON checkpoints.
os.makedirs(OUTPUT_PATH, exist_ok=True)

df = pd.read_excel(DATASET_PATH)

# Keep only rows whose `is_valid` cell is truthy. The reference contexts are
# stored in the sheet as Python-literal strings and parsed with literal_eval.
# NOTE(review): plain truthiness is used, so a blank (NaN) `is_valid` cell
# counts as valid — confirm that is intended.
dataset = [
    {
        "user_input": str(record["user_input"]),
        "reference_contexts": literal_eval(record["reference_contexts_1"]),
    }
    for _, record in df.iterrows()
    if record["is_valid"]
]

evaluation_dataset = EvaluationDataset.from_list(dataset)

len(evaluation_dataset)

## **Evaluation**

In [None]:
# Neo4j connection settings taken from the environment (populated by
# load_dotenv in the first cell); a missing variable raises KeyError here.
URI = os.environ["NEO4J_HOST"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]


def print_result_summary(result: Dict[str, Any]) -> None:
    """Print each ``key: value`` pair of a test-case result on its own line.

    The ``cypher_result`` entry is skipped because it holds the full list of
    generated Cypher outputs, which would swamp the summary.
    """
    visible_entries = (
        (entry_name, entry_value)
        for entry_name, entry_value in result.items()
        if entry_name != "cypher_result"
    )
    for entry_name, entry_value in visible_entries:
        print(f"{entry_name}: {entry_value}")


def save_experiment_dataset_or_result(
    dataset: Union[EvaluationDataset, EvaluationResult],
    generated_cypher_results: List[str],
    experiment_name: str
) -> None:
    """Write a dataset/result table plus the generated Cypher to a JSON checkpoint.

    Args:
        dataset: Object exposing ``to_pandas()``; its table is the base of the file.
        generated_cypher_results: One entry per table row, stored in an added
            ``cypher_result`` column.
        experiment_name: File stem; the output is
            ``OUTPUT_PATH/<experiment_name>.json`` as a JSON array of records.
    """
    target_file = os.path.join(OUTPUT_PATH, f"{experiment_name}.json")
    checkpoint_table = dataset.to_pandas().assign(
        cypher_result=generated_cypher_results
    )
    checkpoint_table.to_json(target_file, orient="records")


def run_test_case(test_case: Dict[str, Any]) -> Dict[str, Any]:
    """Run and evaluate one text-to-Cypher experiment.

    Generates Cypher for every sample of the module-level ``evaluation_dataset``
    against the Neo4j database named in ``test_case``, saves a JSON checkpoint,
    evaluates the result with ``evaluate_retriever``, then overwrites the same
    JSON file with the evaluation result.

    Args:
        test_case: Mapping with the keys
            - ``"llm_model"``: chat model passed to the workflow as ``cypher_llm``;
              its ``.model`` attribute names the experiment.
            - ``"embedding_model"``: embedder passed to the workflow, or a falsy
              value (``None``) for a run labelled zero-shot.
            - ``"database_name"``: Neo4j database to connect to.

    Returns:
        Dict with ``experiment_name``, an ``args`` summary (database, llm,
        prompt_type), the ``evaluation_result`` and the raw ``cypher_result``
        list returned by the workflow.
    """
    # The label only reflects whether an embedder was supplied.
    # NOTE(review): presumably run_text2cypher_workflow uses the embedder to
    # select few-shot examples — confirm.
    prompt_type = "few-shot" if test_case["embedding_model"] else "zero-shot"

    # The experiment name becomes a file name under OUTPUT_PATH, so replace
    # characters that are unsafe there: ":" (Ollama tags such as "llama3.1:8b")
    # and path separators that namespaced model names (e.g. "org/model") contain.
    experiment_name = f"{test_case['llm_model'].model}_{prompt_type}"
    for unsafe_char in (":", "/", "\\"):
        experiment_name = experiment_name.replace(unsafe_char, "-")

    neo4j_graph = Neo4jGraph(
        url=URI,
        username=USERNAME,
        password=PASSWORD,
        database=test_case["database_name"],
        enhanced_schema=True,
    )
    try:
        evaluation_dataset_completed, generated_cypher_results = (
            run_text2cypher_workflow(
                evaluation_dataset,
                experiment_name,
                neo4j_graph=neo4j_graph,
                cypher_llm=test_case["llm_model"],
                embedder_model=test_case["embedding_model"],
            )
        )
    finally:
        # neo4j_graph is not used after the workflow; close its driver here
        # (also on failure) instead of leaving the connection open until the
        # object happens to be collected.
        neo4j_graph.close()

    # Checkpoint 1: persist the completed dataset and Cypher before evaluating,
    # so a failure in evaluate_retriever does not lose them.
    save_experiment_dataset_or_result(
        evaluation_dataset_completed,
        generated_cypher_results,
        experiment_name=experiment_name,
    )

    evaluation_result = evaluate_retriever(
        evaluation_dataset_completed,
        experiment_name=experiment_name,
    )

    # Checkpoint 2: same experiment name, so this overwrites checkpoint 1's
    # file with the evaluation result table.
    save_experiment_dataset_or_result(
        evaluation_result,
        generated_cypher_results,
        experiment_name=experiment_name,
    )

    return {
        "experiment_name": experiment_name,
        "args": {
            "database": test_case["database_name"],
            "llm": test_case["llm_model"].model,
            "prompt_type": prompt_type,
        },
        "evaluation_result": evaluation_result,
        "cypher_result": generated_cypher_results,
    }

In [None]:
CLAUDE_LLM_MODEL_NAME = "claude-3-5-haiku-20241022"
LLAMA_LLM_MODEL_NAME = "llama3.1:8b-instruct-q4_K_M"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-large"

# Hosted Claude model. Temperature 0.0 keeps generation as repeatable as the
# provider allows; the API key comes from the environment (KeyError if unset).
claude_llm = ChatAnthropic(
    api_key=os.environ["ANTHROPIC_API_KEY"],
    model_name=CLAUDE_LLM_MODEL_NAME,
    temperature=0.0,
    max_tokens_to_sample=4096,
    timeout=None,
)

# Local Llama served by Ollama: num_ctx sets the context window and
# num_predict caps the number of generated tokens.
llama_llm = ChatOllama(
    model=LLAMA_LLM_MODEL_NAME,
    temperature=0.0,
    num_ctx=32768,
    num_predict=4096,
)

# Embedder handed to the few-shot test cases.
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

In [None]:
# Evaluation grid: every LLM is run without an embedder (zero-shot) and with
# one (few-shot) against the same database. Resulting order:
#   [0] Llama zero-shot, [1] Llama few-shot,
#   [2] Claude zero-shot, [3] Claude few-shot.
test_cases = [
    {
        "llm_model": chat_model,
        "embedding_model": embedder,
        "database_name": "db-large",
    }
    for chat_model in (llama_llm, claude_llm)
    for embedder in (None, embedding_model)
]

### **Test Case 1**

- Llama 3.1 8B Instruct
- Zero-Shot

In [None]:
# Llama 3.1 8B, zero-shot; JSON checkpoints are written under OUTPUT_PATH.
test_result_1 = run_test_case(test_case=test_cases[0])

In [None]:
# Summary without the (long) list of generated Cypher results.
print_result_summary(result=test_result_1)

### **Test Case 2**

- Llama 3.1 8B Instruct
- Few-Shot

In [None]:
# Llama 3.1 8B, few-shot; JSON checkpoints are written under OUTPUT_PATH.
test_result_2 = run_test_case(test_case=test_cases[1])

In [None]:
# Summary without the (long) list of generated Cypher results.
print_result_summary(result=test_result_2)

### **Test Case 3**

- Claude 3.5 Haiku
- Zero-Shot

In [None]:
# Claude 3.5 Haiku, zero-shot; JSON checkpoints are written under OUTPUT_PATH.
test_result_3 = run_test_case(test_case=test_cases[2])

In [None]:
# Summary without the (long) list of generated Cypher results.
print_result_summary(result=test_result_3)

### **Test Case 4**

- Claude 3.5 Haiku
- Few-Shot

In [None]:
# Claude 3.5 Haiku, few-shot; JSON checkpoints are written under OUTPUT_PATH.
test_result_4 = run_test_case(test_case=test_cases[3])

In [None]:
# Summary without the (long) list of generated Cypher results.
print_result_summary(result=test_result_4)