In [1]:
import ast
import os
import re
from typing import Dict, List

import openai
from FlagEmbedding.bge_m3 import BGEM3FlagModel
from langchain.chains import GraphCypherQAChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_openai import ChatOpenAI

from prompts import LAYER_ONE_PROMPT, LAYER_TWO_PROMPT, LAYER_THREE_PROMPT, QA_PROMPT

# OpenAI API key, read from the environment so no secret is stored in the notebook.
# (The key that was previously hard-coded here has been exposed and should be revoked.)
# May be None if OPENAI_API_KEY is unset; ChatOpenAI then falls back to its own env lookup.
API_KEY = os.environ.get("OPENAI_API_KEY")

class TPT:
    """Bundle both graph-QA tools and a BGE-M3 embedding model around one graph.

    Attributes:
        graphDB: the Neo4jGraph shared by both tools.
        Tool1: a ThreeLayerGPT pipeline bound to ``graphDB``.
        Tool2: a BaseGPT (GraphCypherQAChain) bound to ``graphDB``.
        model: BGEM3FlagModel('BAAI/bge-m3') created with use_fp16=True; it is not
            referenced elsewhere in the code visible in this notebook.
    """

    def __init__(self, graphDB: Neo4jGraph):
        db = graphDB
        self.graphDB = db
        # Both tools talk to the very same graph connection.
        self.Tool1 = ThreeLayerGPT(db)
        self.Tool2 = BaseGPT(db)
        # Constructed eagerly here (presumably loads model weights — can be slow).
        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
        
class ThreeLayerGPT:
    """Answer a natural-language question over a Neo4j graph with a four-step LLM pipeline.

    Steps performed by :meth:`getResult`:
      1. ``_distillation``    - LAYER_ONE_PROMPT; reply is parsed as a Python literal and fed
                                forward as "Query Logic and Relationship".
      2. ``_queryLogic``      - LAYER_TWO_PROMPT; reply fed forward as "Query Logic Extraction".
      3. ``_queryGeneration`` - LAYER_THREE_PROMPT; reply is the query text executed with
                                ``graphDB.query`` (Neo4j, so presumably Cypher).
      4. ``_generateAnswer``  - QA_PROMPT; turns the raw query result into the final answer.
    """

    def __init__(self, graphDB: Neo4jGraph):
        self.prompts = [LAYER_ONE_PROMPT, LAYER_TWO_PROMPT, LAYER_THREE_PROMPT]
        # One deterministic (temperature=0) chat model per pipeline stage.
        self.chat1 = ChatOpenAI(api_key=API_KEY, temperature=0)
        self.chat2 = ChatOpenAI(api_key=API_KEY, temperature=0)
        self.chat3 = ChatOpenAI(api_key=API_KEY, temperature=0)
        self.chat4 = ChatOpenAI(api_key=API_KEY, temperature=0)
        self.graphDB = graphDB

    def __call__(self, message_input):
        """Alias for :meth:`getResult`."""
        return self.getResult(message_input)

    @staticmethod
    def _ask(chat, system_prompt, user_content):
        """Send one system message and one human message to ``chat``; return the reply text."""
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_content),
        ]
        return chat.invoke(messages).content

    @staticmethod
    def _extract_query(text):
        """Return the query inside a markdown code fence if one is present.

        Chat models often wrap generated queries in ```cypher ... ``` fences, which the
        database would reject as a syntax error. Without a fence, ``text`` is returned
        with surrounding whitespace stripped.
        """
        match = re.search(r"```(?:cypher)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        return match.group(1) if match else text.strip()

    def _distillation(self, message):
        """Layer 1: ask LAYER_ONE_PROMPT and parse the reply as a Python literal.

        Raises:
            ValueError, SyntaxError: if the reply is not a valid Python literal.
        """
        reply = self._ask(self.chat1, self.prompts[0], message)
        # The reply is model output, i.e. untrusted text: parse it as a literal rather
        # than executing it with eval(). Stripping mirrors eval's tolerance of leading
        # whitespace (ast.literal_eval rejects it on Python < 3.10).
        return ast.literal_eval(reply.strip())

    def _queryLogic(self, message):
        """Layer 2: return the raw reply to LAYER_TWO_PROMPT."""
        return self._ask(self.chat2, self.prompts[1], message)

    def _queryGeneration(self, message):
        """Layer 3: return the raw reply to LAYER_THREE_PROMPT (the generated query text)."""
        return self._ask(self.chat3, self.prompts[2], message)

    def _generateAnswer(self, message):
        """Final step: ask QA_PROMPT to phrase an answer from the intermediate result.

        Args:
            message: dict with keys ``"question"`` (the user's question) and
                ``"answer"`` (the raw result returned by ``graphDB.query``).
        Returns:
            The chat model's reply text.
        """
        human_prompt_template = PromptTemplate.from_template(
            template="The original question is {question}, the previous LLM has generate intermediate result {answer}")

        human_prompt = human_prompt_template.format(
            question=message["question"],
            answer=message["answer"]
        )
        return self._ask(self.chat4, QA_PROMPT, human_prompt)

    def getResult(self, message_input):
        """Run the whole pipeline for ``message_input`` and return the final answer text.

        Exceptions from the LLM calls, from literal parsing in layer 1, and from
        ``graphDB.query`` propagate to the caller.
        """
        logic = self._distillation(message_input)
        message_layer1 = "User Requirement:" + str(message_input) + "\n" + "Query Logic and Relationship:" + str(logic)
        logic_extraction = self._queryLogic(message_layer1)
        updatemessage = "User Requirement:" + str(message_input) + "\n" + "Query Logic Extraction:" + str(logic_extraction)
        query = self._extract_query(self._queryGeneration(updatemessage))
        data = self.graphDB.query(query)
        user_message = {"question": message_input, "answer": data}
        return self._generateAnswer(user_message)

class BaseGPT:
    """Baseline tool: LangChain's GraphCypherQAChain over ``graphDB`` with a temperature-0 model."""

    def __init__(self, graphDB: Neo4jGraph):
        self.graphDB = graphDB
        self.chain = GraphCypherQAChain.from_llm(
            llm=ChatOpenAI(temperature=0, api_key=API_KEY),
            graph=self.graphDB,
            verbose=False,
        )

    def __call__(self, message):
        """Answer ``message`` with the chain.

        Returns:
            The chain's ``'result'`` text, or the string "Error Query" if running the
            chain raises an ordinary exception (e.g. invalid generated Cypher).
        """
        try:
            # Chain.__call__ is deprecated; invoke with the chain's declared input key.
            res = self.chain.invoke({self.chain.input_key: message})
        except Exception:
            # Best-effort fallback kept from the original, but no longer a bare
            # ``except:`` — KeyboardInterrupt/SystemExit now propagate.
            return "Error Query"
        return res['result']
    

RuntimeError: Failed to import transformers.data.data_collator because of the following error (look up to see its traceback):
Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [None]:
# Neo4j connection settings come from the environment instead of being hard-coded.
# (The password previously committed in this cell has been exposed and should be rotated.)
graph = Neo4jGraph(
    url=os.environ.get("NEO4J_URI", "neo4j+s://9014d857.databases.neo4j.io"),
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ["NEO4J_PASSWORD"],
)
tool1 = ThreeLayerGPT(graph)

message = "How many reports are there in the database"

res1 = tool1(message)

print(res1)