In [None]:
import os
import sys
from typing import Any, Dict, Union


# Add the parent of the current working directory to the import path so the
# `src.grag` import below can be resolved (assumes the notebook is started from
# a direct subdirectory of the project root - TODO confirm).
sys.path.append(os.path.join(os.getcwd(), ".."))


import pandas as pd
from pprint import pprint
from dotenv import load_dotenv
from langchain_neo4j import Neo4jGraph
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
from langchain_huggingface import HuggingFaceEmbeddings
from ragas.dataset_schema import EvaluationDataset, EvaluationResult
from ragas.messages import (
    HumanMessage,
    ToolCall
)
from src.grag import (
    create_vector_cypher_retriever_tool,
    create_text2cypher_retriever_tool,
    run_tools_selection_workflow,
    evaluate_tools_selection,
)


# Load variables from a .env file into the process environment; later cells read
# ANTHROPIC_API_KEY and the NEO4J_* settings through os.environ.
load_dotenv()

## **Preparation**

In [None]:
# Locations (relative to the current working directory) of the result folder
# and of the spreadsheet holding the test questions.
OUTPUT_PATH = os.path.join("results", "llm_tools_selection")
DATASET_PATH = os.path.join("data", "testing_dataset.xlsx")

# The result folder is created up front so the save helper can write into it.
os.makedirs(OUTPUT_PATH, exist_ok=True)

df = pd.read_excel(DATASET_PATH)

# One sample per spreadsheet row: the question is the only human message, and
# the expected tool call uses that same question as its "query" argument.
dataset = [
    {
        "user_input": [HumanMessage(content=str(row["user_input"]))],
        "reference_tool_calls": [
            ToolCall(
                name=str(row["reference_tool_call"]),
                args={"query": str(row["user_input"])},
            ),
        ],
    }
    for _, row in df.iterrows()
]

evaluation_dataset = EvaluationDataset.from_list(dataset)

# Sanity check: number of samples loaded.
print(len(evaluation_dataset))

## **Evaluation**

In [None]:
# Identifiers of the two chat models under test and of the embedding model
# shared by both runs.
CLAUDE_LLM_MODEL_NAME = "claude-3-5-haiku-20241022"
LLAMA_LLM_MODEL_NAME = "llama3.1:8b-instruct-q4_K_M"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-large"

# Large Language Model
# Claude client. The API key is read from the environment, so a missing
# ANTHROPIC_API_KEY raises KeyError at this point.
claude_llm = ChatAnthropic(
    model_name=CLAUDE_LLM_MODEL_NAME,
    max_tokens_to_sample=4096,
    temperature=0.0,
    timeout=None,
    api_key=os.environ["ANTHROPIC_API_KEY"],
)

# Llama client served through Ollama; uses the same temperature (0.0) and the
# same 4096-token generation limit as the Claude client above.
llama_llm = ChatOllama(
    model=LLAMA_LLM_MODEL_NAME,
    num_ctx=32768,
    num_predict=4096,
    temperature=0.0,
)

# Embedding Model
# A single embedder instance is created here and reused by every test case.
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

In [None]:
# Neo4j connection settings. Host and credentials come from the environment
# (a missing variable raises KeyError); the database name is fixed in code.
URI = os.environ["NEO4J_HOST"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
DATABASE = "db-large"


def save_experiment_dataset_or_result(
    dataset: Union[EvaluationDataset, EvaluationResult],
    experiment_name: str,
) -> None:
    """Write a dataset or evaluation result to ``<OUTPUT_PATH>/<experiment_name>.json``.

    The object is converted with ``to_pandas()``; its ``reference_tool_calls``
    column is replaced by the ``"name"`` of the first call in each row, and the
    frame is serialised as a JSON array of records.

    Args:
        dataset: Object exposing ``to_pandas()`` whose frame contains a
            ``reference_tool_calls`` column.
        experiment_name: File stem of the JSON file. An existing file with the
            same name is overwritten.
    """

    def first_call_name(tool_calls):
        # Only the first reference call of each sample is kept.
        return tool_calls[0]["name"]

    frame = dataset.to_pandas()
    frame["reference_tool_calls"] = frame["reference_tool_calls"].apply(
        first_call_name
    )

    target_file = os.path.join(OUTPUT_PATH, f"{experiment_name}.json")
    frame.to_json(target_file, orient="records")


def run_test_case(test_case: Dict[str, Any]) -> Dict[str, Any]:
    """Run the tool-selection workflow and its evaluation for one chat model.

    A Neo4j connection is opened for the duration of the run and released
    afterwards, even when the workflow or the evaluation raises.

    Args:
        test_case: Mapping with an ``"llm_model"`` entry holding the chat model
            under test. The model must expose a ``model`` attribute; that name,
            with ``:`` replaced by ``-``, is used as the experiment name.

    Returns:
        A dict with ``"experiment_name"``, ``"args"`` (``{"llm": <model name>}``)
        and ``"evaluation_result"`` (the value returned by
        ``evaluate_tools_selection``).
    """
    llm = test_case["llm_model"]
    model_name = llm.model
    # ":" is replaced because the experiment name is used as a file name.
    experiment_name = model_name.replace(":", "-")

    neo4j_graph = Neo4jGraph(
        url=URI,
        username=USERNAME,
        password=PASSWORD,
        database=DATABASE,
        enhanced_schema=True,
    )

    try:
        vector_cypher_retriever = create_vector_cypher_retriever_tool(
            neo4j_graph=neo4j_graph,
            embedder_model=embedding_model,
        )

        text2cypher_retriever = create_text2cypher_retriever_tool(
            neo4j_graph=neo4j_graph,
            cypher_llm=llm,
            embedder_model=embedding_model,
        )

        evaluation_dataset_completed = run_tools_selection_workflow(
            evaluation_dataset,
            experiment_name,
            model=llm,
            tools=[vector_cypher_retriever, text2cypher_retriever],
        )

        # Checkpoint 1: persist the completed dataset before scoring it.
        save_experiment_dataset_or_result(
            evaluation_dataset_completed,
            experiment_name=experiment_name,
        )

        evaluation_result = evaluate_tools_selection(
            evaluation_dataset_completed,
            experiment_name=experiment_name,
        )

        # Checkpoint 2: same experiment name, so this replaces the file written
        # at checkpoint 1 with the evaluation result.
        save_experiment_dataset_or_result(
            evaluation_result,
            experiment_name=experiment_name,
        )
    finally:
        # Release the connection so repeated runs in a long-lived kernel do not
        # accumulate open drivers. The lookup is guarded because `close()` may
        # not exist on older Neo4jGraph releases.
        close_graph = getattr(neo4j_graph, "close", None)
        if callable(close_graph):
            close_graph()

    return {
        "experiment_name": experiment_name,
        "args": {"llm": model_name},
        "evaluation_result": evaluation_result,
    }

In [None]:
# Order matters, later cells index into this list:
#   0 -> Llama (local), 1 -> Claude (API).
test_cases = [{"llm_model": llm} for llm in (llama_llm, claude_llm)]

### **Test Case 1**

- Llama 3.1 8B Instruct

In [None]:
# Run the workflow and evaluation for the first test case (the local Llama model).
test_result_1 = run_test_case(test_cases[0])

In [None]:
# Show the experiment name, model argument and evaluation result of test case 1.
pprint(test_result_1)

### **Test Case 2**

- Claude 3.5 Haiku

In [None]:
# Run the workflow and evaluation for the second test case (Claude via the API).
test_result_2 = run_test_case(test_cases[1])

In [None]:
# Show the experiment name, model argument and evaluation result of test case 2.
pprint(test_result_2)