In [99]:
!pip install langchain openai neo4j



In [100]:
import os
import re
from getpass import getpass
from typing import List, Dict, Optional
from typing import Any, List, Tuple
from typing import Optional, Type
from typing import Optional, Type

from langchain.graphs import Neo4jGraph
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_function
from langchain_community.chat_models import ChatOpenAI

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
# Import things that are needed generically
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool

In [101]:
# Credentials are taken from the environment (or prompted for interactively)
# instead of being written into the cell: notebooks get shared and committed,
# and a literal key/password travels with them. Any key that was previously
# pasted into this cell should be considered leaked and rotated.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")

graph = Neo4jGraph(url=NEO4J_URL, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)

In [102]:
from typing import List, Dict, Optional
from pydantic import BaseModel, Field

# Assuming 'graph' is an initialized graph database connection

def generate_full_text_query(input: str, fuzziness: float = 0.8) -> str:
    """Build a Lucene full-text query that fuzzily matches every word of ``input``.

    Each whitespace-separated word becomes a fuzzy term ``word~fuzziness`` and
    the terms are AND-ed together, e.g. "Budget Summary" ->
    "Budget~0.8 AND Summary~0.8".

    Lucene special characters are replaced by spaces first. Otherwise a stray
    ``-`` or ``?`` in a name such as "Global Economic Outlook - December 2023"
    is interpreted as query syntax (e.g. a term ``-~0.8``).

    Args:
        input: Free text to search for. ``None`` is treated as empty.
        fuzziness: Fuzzy-match parameter appended to each term.

    Returns:
        The query string, or ``""`` when ``input`` has no searchable words.
    """
    # Lucene classic query-parser special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
    cleaned = re.sub(r'[+\-&|!(){}\[\]^"~*?:\\/]', " ", input or "")
    return " AND ".join(f"{word}~{fuzziness}" for word in cleaned.split())

candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery)
YIELD node
RETURN coalesce(node.name, node.title) AS candidate, [el in labels(node) WHERE el IN ['section', 'model'] | el][0] AS label
LIMIT toInteger($limit)
"""

description_query = """
MATCH (entity)
WHERE entity.name = $candidate AND ('section' IN labels(entity) OR 'model' IN labels(entity))
WITH entity, 
     CASE WHEN 'model' IN labels(entity) THEN 'model'
          WHEN 'section' IN labels(entity) THEN 'section'
          ELSE 'unknown'
     END as entityType
OPTIONAL MATCH (entity)-[r]-(related)
WITH entity, entityType, 
     CASE WHEN entityType = 'model' THEN collect({relation: type(r), relatedNode: related.name})
          ELSE []
     END as modelRelations,
     CASE WHEN entityType = 'section' THEN collect({relation: type(r), relatedNode: related.name})
          ELSE []
     END as sectionRelations
RETURN 
    CASE 
        WHEN entityType = 'model' THEN 
            'Type: ' + entityType + '\nName: ' + entity.name + 
            '\nAuthor: ' + entity.author +
            '\nCreated At: ' + entity.created_at +
            '\nInput Columns: ' + entity.input_columns + 
            '\nOutput Columns: ' + entity.output_column +
            '\nParameters: ' + entity.parameters + 
            '\nPerformance Metrics: ' + entity.performance_metrics
        WHEN entityType = 'section' THEN
            'Type: ' + entityType + '\nName: ' + entity.name
        ELSE 'Entity type not recognized.'
    END as context
LIMIT 1
"""

def classify_query(input: str) -> str:
    """Map a free-text entity type onto the full-text index/label to search.

    Matching is case-insensitive and substring based. "model" wins over
    "section"/"fields" when several keywords appear.

    Args:
        input: Entity type as supplied by the caller (e.g. the LLM's
            ``entity_type`` argument). ``None`` is treated as empty.

    Returns:
        "model", "section", or "unknown" when no keyword matches.
    """
    text = (input or "").lower()
    if "model" in text:
        return "model"
    if "section" in text or "fields" in text:
        return "section"
    return "unknown"

def get_candidates(input: str, type: str, limit: int = 3) -> List[Dict[str, str]]:
    """Fuzzy-search the full-text index named ``type`` for entities matching ``input``.

    Args:
        input: Entity name (or approximate name) to look for.
        type: Full-text index name; callers pass the result of
            ``classify_query`` ("model" or "section").
        limit: Maximum number of candidates to return.

    Returns:
        Rows from ``candidate_query`` (``candidate`` and ``label`` keys). The
        result is ``[]`` when ``input`` is blank or has no searchable words.
    """
    # With nothing to search for, return no candidates rather than building or
    # sending an empty/invalid full-text query.
    if not input or not input.strip():
        return []
    ft_query = generate_full_text_query(input)
    if not ft_query:
        return []
    return graph.query(
        candidate_query, {"fulltextQuery": ft_query, "index": type, "limit": limit}
    )

def get_information(entity: str, query_type: str) -> str:
    """Look up a model or section in the graph and describe it for the agent.

    Args:
        entity: Name (or approximate name) of the model/section to look up.
        query_type: Free-text entity type. It is classified with
            ``classify_query`` into "model", "section" or "unknown".

    Returns:
        One of the following:
        - a human-readable description of the entity;
        - a prompt listing the fuzzy candidates when the match is ambiguous;
        - an explanatory message when the type is unknown or nothing is found.
    """
    no_match = "No information was found about the section or model in the database"

    entity_type = classify_query(query_type)
    if entity_type == "unknown":
        return "Query classification failed. Please provide more specific information."

    candidates = get_candidates(entity, entity_type)
    if not candidates:
        return no_match

    # A fuzzy search can return near-duplicates of the requested name. When one
    # candidate is exactly what was asked for, use it instead of asking the
    # user to disambiguate.
    wanted = (entity or "").strip().lower()
    exact = [
        c for c in candidates
        if str(c.get("candidate") or "").strip().lower() == wanted
    ]
    if exact:
        chosen = exact[0]
    elif len(candidates) > 1:
        options = "\n".join(str(c) for c in candidates)
        return f"Need additional information, which of these did you mean: \n{options}"
    else:
        chosen = candidates[0]

    data = graph.query(description_query, params={"candidate": chosen["candidate"]})
    # The description query may return no row (no node matches the candidate
    # there) or a null context. Report "not found" instead of raising
    # IndexError or handing None back to the agent.
    if not data or not data[0].get("context"):
        return no_match
    return data[0]["context"]

# Update the InformationInput and InformationTool classes as needed

In [103]:
print(get_information("Budget Summary", "section"))

Type: section
Name: Budget Summary


In [104]:
print(get_information("model_782951116", "model"))

Type: model
Name: model_782951116
Author: ["Alex", "Martin"]
Created At: 2024-02-12T20:59:17Z
Input Columns: ["revenue", "cost", "profitMargin", "tax_col"]
Output Columns: revenue
Parameters: [{name: "example_parameter", value: 77}]
Performance Metrics: [
    {name: "mean absolute error", value: 10228},
    {name: "root mean squared error", value: 101},
    {name: "R-squared", value: 0.95}
  ]


In [105]:
# Arguments the LLM must supply when calling the Information tool. The field
# descriptions are sent to the model as part of the function schema, so they
# must name the entity kinds get_information actually understands ('model'
# and 'section').
class InformationInput(BaseModel):
    entity: str = Field(description="model or section mentioned in the question")
    entity_type: str = Field(
        description="type of the entity. Available options are 'model' or 'section'"
    )


class InformationTool(BaseTool):
    """Agent tool that describes a model or section stored in the Neo4j graph."""

    name = "Information"
    # The description goes into the function spec shown to the LLM. It must
    # match what get_information can look up, or the model will call the tool
    # for the wrong questions (or not at all).
    description = (
        "useful for when you need to answer questions about models or sections "
        "stored in the graph, such as a model's author, creation date, input and "
        "output columns, parameters or performance metrics"
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Look up ``entity`` of kind ``entity_type`` and return its description."""
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async entry point; returns the same result as ``_run``.

        NOTE: get_information (and the Neo4j queries it runs) is synchronous,
        so this call blocks the event loop for its duration.
        """
        return get_information(entity, entity_type)

In [106]:
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
tools = [InformationTool()]  # Only include the InformationTool

llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant that finds information about sections. "
            "If the tool requires follow-up questions, make sure to ask the user "
            "for clarification. Include any available options that need to be "
            "clarified in the follow-up questions. Do only the things the user "
            "specifically requested.",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

def _format_chat_history(chat_history: List[Tuple[str, str]]):
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

# Runnable agent pipeline (LangChain Expression Language):
#   1. map the executor's input dict onto the prompt variables,
#   2. render the chat prompt,
#   3. call the LLM with the Information tool bound as a function,
#   4. parse the reply into the next agent step (presumably a tool call or a
#      final answer - see OpenAIFunctionsAgentOutputParser).
agent = (
    {
        # The user's question, passed through unchanged.
        "input": lambda x: x["input"],
        # Optional prior turns as (human, ai) pairs; a missing or empty
        # history becomes an empty message list.
        "chat_history": lambda x: _format_chat_history(x["chat_history"])
        if x.get("chat_history")
        else [],
        # Earlier tool calls/results of the current run (assumed to be supplied
        # by AgentExecutor as "intermediate_steps"), rendered as function
        # messages so the LLM sees what it already did.
        "agent_scratchpad": lambda x: format_to_openai_function_messages(
            x["intermediate_steps"]
        ),
    }
    | prompt
    | llm_with_tools
    | OpenAIFunctionsAgentOutputParser()
)

# Add typing for input
# Input schema advertised by the executor via with_types.
# chat_history defaults to an empty list for two reasons:
# - the agent already treats it as optional (it falls back to [] when the key
#   is missing or empty);
# - the demo cells invoke the executor with only "input".
# The "widget" metadata is kept for chat-style UIs.
class AgentInput(BaseModel):
    input: str
    chat_history: List[Tuple[str, str]] = Field(
        default_factory=list,
        extra={"widget": {"type": "chat", "input": "input", "output": "output"}},
    )

# Output schema advertised by the executor: the agent's final answer.
class Output(BaseModel):
    output: Any


# Executor that drives the agent/tool loop, printing each step (verbose=True),
# then annotated with the input/output schemas declared above.
_untyped_agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
agent_executor = _untyped_agent_executor.with_types(
    input_type=AgentInput,
    output_type=Output,
)

In [107]:
if __name__ == "__main__":
    #original_query = "What do you know about the model model_782951116?"
    original_query = "What is the performance metrics of the model model_782951116?"
    print(agent_executor.invoke({"input": original_query}))



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Information` with `{'entity': 'model_782951116', 'entity_type': 'model'}`


[0m[36;1m[1;3mType: model
Name: model_782951116
Author: ["Alex", "Martin"]
Created At: 2024-02-12T20:59:17Z
Input Columns: ["revenue", "cost", "profitMargin", "tax_col"]
Output Columns: revenue
Parameters: [{name: "example_parameter", value: 77}]
Performance Metrics: [
    {name: "mean absolute error", value: 10228},
    {name: "root mean squared error", value: 101},
    {name: "R-squared", value: 0.95}
  ][0m[32;1m[1;3mThe performance metrics of model_782951116 are:
- Mean Absolute Error: 10228
- Root Mean Squared Error: 101
- R-squared: 0.95[0m

[1m> Finished chain.[0m
{'input': 'What is the performance metrics of the model model_782951116?', 'output': 'The performance metrics of model_782951116 are:\n- Mean Absolute Error: 10228\n- Root Mean Squared Error: 101\n- R-squared: 0.95'}


In [109]:
if __name__ == "__main__":
    #original_query = "What do you know about the model model_782951116?"
    original_query = "what predictive models are used in the report Global Economic Outlook - December 2023？"
    print(agent_executor.invoke({"input": original_query}))



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Information` with `{'entity': 'Global Economic Outlook - December 2023', 'entity_type': 'report'}`


[0m[36;1m[1;3mQuery classification failed. Please provide more specific information.[0m[32;1m[1;3mI'm sorry, but I couldn't retrieve the specific models used in the report "Global Economic Outlook - December 2023." If you have any other specific questions or details you'd like me to look into, please let me know![0m

[1m> Finished chain.[0m
{'input': 'what models are used in the report Global Economic Outlook - December 2023？', 'output': 'I\'m sorry, but I couldn\'t retrieve the specific models used in the report "Global Economic Outlook - December 2023." If you have any other specific questions or details you\'d like me to look into, please let me know!'}
