为了构建面向集装箱式锂离子电池储能系统火灾爆炸事故报告的RAG（检索增强生成）问答系统，我们使用以下代码实现：

注：我们选取2019年美国亚利桑那州McMicken电池储能系统火灾爆炸事故、2020年英国利物浦Carnegie Road电池储能系统火灾事故、2021年澳大利亚维多利亚大电池（Victorian Big Battery）火灾事故、2023年瑞典哥德堡集装箱电池储能系统火灾事故为例进行研究。相关事故调查报告（PDF文件）存放在Google Drive的 /content/drive/MyDrive/Colab Notebooks/ 文件夹中。

以下代码在Colab中运行（运行前需先挂载Google Drive，以便读取上述PDF文件）。

（1）安装相关包

In [None]:
!pip install langchain_openai
!pip install langchain_chroma
!pip install langchain_community
!pip install pypdf
!pip install chromadb
!pip install langgraph

（2）设置环境变量

In [None]:
import os

# Placeholder credentials: replace both values with real keys before running.
# (Hard-coded keys end up in the saved notebook; Colab secrets are safer.)
os.environ.update({
    'OPENAI_API_KEY': '你的OPENAI_API_KEY',
    'LANGCHAIN_API_KEY': '你的LANGCHAIN_API_KEY',
})

（3）将文件添加至Chroma向量数据库（此处先将PDF按页加载，再切分为较小的文本块后写入；也可以不切分）

In [None]:
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import os
import openai
# Module-level key for the `openai` SDK, taken from the environment set in step (2).
openai.api_key = os.getenv("OPENAI_API_KEY")
print("run_api")

import chromadb
from chromadb.config import Settings

# Persistent on-disk Chroma store; replace the placeholder with a real directory.
reports_client_path = r"指定的db路径"
# allow_reset=True is required for the reset() call below.
reports_client = chromadb.PersistentClient(path=reports_client_path, settings=Settings(allow_reset=True))
# WARNING: wipes every collection in this store so ingestion rebuilds it from scratch.
reports_client.reset()

collection_name = "accident_reports_collection"
# One embedding model instance, shared with the vector store below (it used to be
# created here but never used, while a second instance was built inline).
embedding_function = OpenAIEmbeddings()
accident_reports_collection = reports_client.get_or_create_collection(name=collection_name)

vectorstore = Chroma(
    client=reports_client,
    collection_name=collection_name,
    embedding_function=embedding_function,
)
print("run_vectorstore")

def load_pdf(file_path):
    """Load the PDF at *file_path* and return its pages as a list of Documents.

    Pages are pulled lazily from PyPDFLoader and collected in order.
    """
    pages = list(PyPDFLoader(file_path=file_path).lazy_load())
    print("run_loader")
    return pages

def split_pdf(data):
    """Split page Documents into smaller chunks (500 characters, 50 overlap)."""
    splitter = RecursiveCharacterTextSplitter(chunk_overlap=50, chunk_size=500)
    print("run_split_pdf")
    chunks = splitter.split_documents(data)
    return chunks


def data_process(file_path, file_information, batch_size=10):
    """Load a PDF, split it into chunks and add the chunks to `vectorstore`.

    Pages are processed in batches of *batch_size* so each add_documents call
    stays small.

    Args:
        file_path: Path of the PDF report to ingest.
        file_information: Tag prefixed to every chunk id (e.g. "2019_US") so ids
            stay unique across reports.
        batch_size: Number of pages split and added per batch (default 10).
    """
    data = load_pdf(file_path)
    print(len(data))
    for start_index in range(0, len(data), batch_size):
        # Bug fix: the batch start used to be reset to 0 on every iteration, so
        # only the first `batch_size` pages were ever indexed (re-added under the
        # same ids each time) and the rest of the report was silently dropped.
        end_index = min(start_index + batch_size, len(data))
        splits = split_pdf(data[start_index:end_index])
        if not splits:
            # Pages without extractable text yield no chunks; adding an empty
            # batch would make Chroma reject the empty id list.
            continue
        ids = [f"{file_information}_{start_index}_{end_index}_{n}" for n in range(len(splits))]
        # The vector store is already bound to `collection_name`; no extra kwarg needed.
        vectorstore.add_documents(documents=splits, ids=ids)
        print("run_add")
    print("run_data_process")

def add_reports():
    """Ingest the four accident investigation reports into the vector store."""
    reports = [
        (r"/content/drive/MyDrive/Colab Notebooks/Four_Firefighters_Injured_In_Lithium_Ion_Battery_ESS_Explosion_Arizona_0.pdf", "2019_US"),
        (r"/content/drive/MyDrive/Colab Notebooks/2020-09-15 UK, Liverpool - Investigative Report.pdf", "2020_UK"),
        (r"/content/drive/MyDrive/Colab Notebooks/2021-07-30 Australia, Victoria, Moorabool - Investigation Report.pdf", "2021_Australia"),
        (r"/content/drive/MyDrive/Colab Notebooks/2023-04-26 Sweden, Gothenburg - Investigation Report (English Translation).pdf", "2023_Sweden"),
    ]
    for number, (report_path, report_tag) in enumerate(reports, start=1):
        data_process(report_path, report_tag)
        print(f"run_{number}")


if __name__ == "__main__":
    add_reports()

（4）检查集合是否创建成功

示例输出：[Collection(name=accident_reports_collection)]

In [None]:
import chromadb
from chromadb.config import Settings
# Re-open the same persistent store (without resetting it) and show its collections.
reports_client_path = r"指定的db路径"
reports_client = chromadb.PersistentClient(
    path=reports_client_path,
    settings=Settings(allow_reset=True),
)
collections = reports_client.list_collections()
print(collections)

（5）使用LangChain Hub中的提示模板

该提示模板（rlm/rag-prompt）要求大模型最多使用三句话作答。

In [None]:
"""使用hub中的提示模版"""
from langchain import hub
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import os
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")
from langsmith import Client
api_key = os.getenv('LANGCHAIN_API_KEY')
# Security fix: the raw LangSmith key used to be printed here. Notebook output is
# saved and often shared, so only report whether the key is present.
print("LANGCHAIN_API_KEY set:", bool(api_key))
client = Client(api_key=api_key)
print("LangSmith client initialized successfully!")
print("run_api")

import chromadb
from chromadb.config import Settings

# Re-open the persistent store; no reset() here, so previously ingested chunks are kept.
reports_client_path = r"指定的db路径"
reports_client = chromadb.PersistentClient(path=reports_client_path, settings=Settings(allow_reset=True))

collection_name = "accident_reports_collection"
# Single embedding model instance, shared with the vector store below.
embedding_function = OpenAIEmbeddings()
accident_reports_collection = reports_client.get_or_create_collection(name=collection_name)

vectorstore = Chroma(
    client=reports_client,
    collection_name=collection_name,
    embedding_function=embedding_function,
)
print("run_vectorstore")

def run():
    """Interactive RAG loop using the Hub prompt "rlm/rag-prompt".

    Reads questions from stdin until the user types "end". For each question,
    retrieves the 15 most similar chunks from `vectorstore` and prints them,
    then prints the LLM answer.
    """
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.9, max_tokens=4000)
    # Pull the prompt once instead of making a Hub network request per question.
    prompt = hub.pull("rlm/rag-prompt")
    # 用户提问循环 (question loop)
    while True:
        query = input("请输入问题，例如：维多利亚大电池火灾事故发生的原因是什么？\n")
        if query.strip().lower() == "end":
            print("程序结束！")
            break
        if not query.strip():
            continue  # nothing to embed or search for

        print(f"问题：{query}")
        retrieved_docs = vectorstore.similarity_search(query=query, k=15)
        print("相似性文档如下：", retrieved_docs)
        docs_content = "\n".join(doc.page_content for doc in retrieved_docs)
        # Fill the prompt template with the question and the retrieved context.
        messages = prompt.invoke({"question": query, "context": docs_content})
        response = llm.invoke(messages)
        print(f"回答：{response.content}")


if __name__ == "__main__":
    run()

（6）使用LangChain Hub中的提示模板，并定义状态类（基于LangGraph构建检索—生成流程）

In [None]:
"""使用hub中的提示模版，并定义状态类"""
from langchain import hub
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import os
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")
from langsmith import Client
api_key = os.getenv('LANGCHAIN_API_KEY')
# Security fix: the raw LangSmith key used to be printed here. Notebook output is
# saved and often shared, so only report whether the key is present.
print("LANGCHAIN_API_KEY set:", bool(api_key))
client = Client(api_key=api_key)
print("LangSmith client initialized successfully!")
print("run_api")

import chromadb
from chromadb.config import Settings

# Re-open the persistent store; no reset() here, so previously ingested chunks are kept.
reports_client_path = r"指定的db路径"
reports_client = chromadb.PersistentClient(path=reports_client_path, settings=Settings(allow_reset=True))

collection_name = "accident_reports_collection"
# Single embedding model instance, shared with the vector store below.
embedding_function = OpenAIEmbeddings()
accident_reports_collection = reports_client.get_or_create_collection(name=collection_name)

vectorstore = Chroma(
    client=reports_client,
    collection_name=collection_name,
    embedding_function=embedding_function,
)
print("run_vectorstore")

def run():
    """Interactive RAG loop built as a LangGraph retrieve -> generate pipeline.

    For each question the user also enters the report path to restrict retrieval
    to. The path is matched against the "source" metadata of the stored chunks,
    and an empty path searches all reports. Type "end" to quit.
    """
    from langchain_core.documents import Document
    from typing_extensions import List, TypedDict
    from langgraph.graph import START, StateGraph

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)
    prompt = hub.pull("rlm/rag-prompt")

    class State(TypedDict):
        question: str
        source: str  # report path used as a metadata filter ("" = no filter)
        context: List[Document]
        answer: str

    def retrieve(state: State):
        # Read the filter from the graph state instead of closing over a loop variable.
        source = state.get("source")
        search_filter = {"source": source} if source else None
        retrieved_docs = vectorstore.similarity_search(state["question"], k=15, filter=search_filter)
        print("相似性文档如下：", retrieved_docs)
        return {"context": retrieved_docs}

    def generate(state: State):
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = prompt.invoke({"question": state["question"], "context": docs_content})
        response = llm.invoke(messages)
        return {"answer": response.content}

    # The graph does not depend on the question, so build and compile it once.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()

    # 用户提问循环 (question loop)
    while True:
        query = input("请输入问题，例如：维多利亚大电池火灾事故发生的原因是什么？\n")
        if query.strip().lower() == "end":
            print("程序结束！")
            break
        if not query.strip():
            continue
        print(f"问题：{query}")
        # The example paths must match the paths ingested in step (3), otherwise the
        # metadata filter matches nothing (report 1 previously showed a wrong name).
        source = input(
            "请输入数据来源（留空则检索全部报告），例如：/content/drive/MyDrive/Colab Notebooks/Four_Firefighters_Injured_In_Lithium_Ion_Battery_ESS_Explosion_Arizona_0.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2020-09-15 UK, Liverpool - Investigative Report.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2021-07-30 Australia, Victoria, Moorabool - Investigation Report.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2023-04-26 Sweden, Gothenburg - Investigation Report (English Translation).pdf\n"
        ).strip()
        print(f"数据来源：{source}")

        response = graph.invoke({"question": query, "source": source})
        print(response["answer"])


if __name__ == "__main__":
    run()

（7）使用自定义提示模板（效果较好）

In [None]:
"""使用自定义提示模版，性能较好"""
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from typing_extensions import List, TypedDict
from langgraph.graph import START, StateGraph
from langchain_core.prompts import PromptTemplate

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import os
import openai
from langsmith import Client

# Credentials are read from the environment variables set in step (2).
openai.api_key = os.getenv("OPENAI_API_KEY")
api_key = os.getenv('LANGCHAIN_API_KEY')
client = Client(api_key=api_key)
print("run_api")

import chromadb
from chromadb.config import Settings

# Re-open the persistent store (not reset, so ingested chunks remain) and wrap
# the report collection in a LangChain vector store.
collection_name = "accident_reports_collection"
reports_client_path = r"指定的db路径"
reports_client = chromadb.PersistentClient(
    path=reports_client_path,
    settings=Settings(allow_reset=True),
)
embedding_function = OpenAIEmbeddings()
accident_reports_collection = reports_client.get_or_create_collection(name=collection_name)
vectorstore = Chroma(
    embedding_function=OpenAIEmbeddings(),
    collection_name=collection_name,
    client=reports_client,
)
print("run_vectorstore")

def run():
    """Interactive RAG loop with a custom Chinese prompt, built as a LangGraph pipeline.

    For each question the user also enters the report path to restrict retrieval
    to. The path is matched against the "source" metadata of the stored chunks,
    and an empty path searches all reports. The filled prompt and the raw LLM
    response are printed for debugging. Type "end" to quit.
    """
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)
    # Custom replacement for hub "rlm/rag-prompt": answer in three sentences; if the
    # answer is unknown, summarise the question and context instead.
    template = """你是一个问答任务的助手。请使用以下信息回答问题。如果不知道，则直接总结question和context的内容。请使用三句话回答并确保回答的准确性。

    Question: {question}

    Context: {context}

    Helpful Answer:"""
    custom_rag_prompt = PromptTemplate.from_template(template)

    class State(TypedDict):
        question: str
        source: str  # report path used as a metadata filter ("" = no filter)
        context: List[Document]
        answer: str

    def retrieve(state: State):
        # Read the filter from the graph state instead of closing over a loop variable.
        source = state.get("source")
        search_filter = {"source": source} if source else None
        retrieved_docs = vectorstore.similarity_search(state["question"], k=15, filter=search_filter)
        print("相似性文档如下：", retrieved_docs)
        return {"context": retrieved_docs}

    def generate(state: State):
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = custom_rag_prompt.invoke({
            "question": state["question"],
            "context": docs_content})
        print(messages)
        response = llm.invoke(messages)
        print(response)
        return {"answer": response.content}

    # The graph does not depend on the question, so build and compile it once.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()

    # 用户提问循环 (question loop)
    while True:
        query = input("请输入问题，例如：维多利亚大电池火灾事故发生的原因是什么？\n")
        if query.strip().lower() == "end":
            print("程序结束！")
            break
        if not query.strip():
            continue
        print(f"问题：{query}")
        # The example paths must match the paths ingested in step (3), otherwise the
        # metadata filter matches nothing (report 1 previously showed a wrong name).
        source = input(
            "请输入数据来源（留空则检索全部报告），例如：/content/drive/MyDrive/Colab Notebooks/Four_Firefighters_Injured_In_Lithium_Ion_Battery_ESS_Explosion_Arizona_0.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2020-09-15 UK, Liverpool - Investigative Report.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2021-07-30 Australia, Victoria, Moorabool - Investigation Report.pdf\n"
            " 或 /content/drive/MyDrive/Colab Notebooks/2023-04-26 Sweden, Gothenburg - Investigation Report (English Translation).pdf\n"
        ).strip()
        print(f"数据来源：{source}")

        response = graph.invoke({"question": query, "source": source})
        print(response["answer"])


if __name__ == "__main__":
    run()

（8）直接调用OpenAI API（使用openai客户端发送请求）

In [None]:
"""使用OpenAI的API请求"""
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from typing_extensions import List, TypedDict
from langgraph.graph import START, StateGraph
from langchain_core.prompts import PromptTemplate

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import os
from langsmith import Client
from openai import OpenAI

# LangSmith client plus a raw OpenAI client; run() calls the chat completions API
# through `llm_client` directly instead of a LangChain chat model.
langsmith_api_key = os.getenv('LANGCHAIN_API_KEY')
client = Client(api_key=langsmith_api_key)
openai_api_key = os.getenv("OPENAI_API_KEY")
llm_client = OpenAI(api_key=openai_api_key)

import chromadb
from chromadb.config import Settings

# Re-open the persistent store (not reset, so ingested chunks remain) and wrap
# the report collection in a LangChain vector store.
collection_name = "accident_reports_collection"
reports_client_path = r"指定的db路径"
reports_client = chromadb.PersistentClient(
    path=reports_client_path,
    settings=Settings(allow_reset=True),
)
embedding_function = OpenAIEmbeddings()
accident_reports_collection = reports_client.get_or_create_collection(name=collection_name)
vectorstore = Chroma(
    embedding_function=OpenAIEmbeddings(),
    collection_name=collection_name,
    client=reports_client,
)

def run():
    """Interactive RAG loop that calls the OpenAI chat completions API directly.

    For each question the user picks a report by number (1-4). Retrieval is
    restricted to that report's chunks via the "source" metadata filter, and the
    answer is generated with gpt-3.5-turbo. Type "end" to quit.
    """
    template = """你是一个问答任务的助手。请使用以下信息用中文回答问题。如果无法使用以下信息回答，则输出：无法根据信息回答。并直接根据Question内容回答。请使用500字左右回答并确保回答的准确性。

    Question: {question}

    Context: {context}

    Helpful Answer:"""
    custom_rag_prompt = PromptTemplate.from_template(template)

    # Report number -> PDF path; the paths must match those ingested in step (3).
    sources = {
        1: "/content/drive/MyDrive/Colab Notebooks/Four_Firefighters_Injured_In_Lithium_Ion_Battery_ESS_Explosion_Arizona_0.pdf",
        2: "/content/drive/MyDrive/Colab Notebooks/2020-09-15 UK, Liverpool - Investigative Report.pdf",
        3: "/content/drive/MyDrive/Colab Notebooks/2021-07-30 Australia, Victoria, Moorabool - Investigation Report.pdf",
        4: "/content/drive/MyDrive/Colab Notebooks/2023-04-26 Sweden, Gothenburg - Investigation Report (English Translation).pdf",
    }

    class State(TypedDict):
        question: str
        source: str  # report path used as the metadata filter
        context: List[Document]
        answer: str

    def retrieve(state: State):
        retrieved_docs = vectorstore.similarity_search(
            state["question"], k=15, filter={"source": state["source"]})
        return {"context": retrieved_docs}

    def generate(state: State):
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = custom_rag_prompt.invoke({
            "question": state["question"],
            "context": docs_content})
        completion = llm_client.chat.completions.create(
            model="gpt-3.5-turbo",
            temperature=0.7,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                # The filled template is sent as plain text.
                {"role": "user", "content": messages.text},
            ],
        )
        return {"answer": completion.choices[0].message.content}

    # The graph does not depend on the question, so build and compile it once.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()

    # 用户提问循环 (question loop)
    while True:
        query = input("请输入问题，例如：维多利亚大电池火灾事故发生的原因是什么？\n")
        if query.strip().lower() == "end":
            print("程序结束！")
            break
        if not query.strip():
            continue
        print(f"问题：{query}")

        # Re-prompt until a valid number is entered. Previously non-numeric input
        # crashed the loop, and an out-of-range number left `source` unbound (or
        # silently reused the previous report).
        source = None
        while source is None:
            choice = input(
                "请输入数据来源编号（1：2019年美国亚利桑那；2：2020年英国利物浦；"
                "3：2021年澳大利亚维多利亚；4：2023年瑞典哥德堡）：\n"
            ).strip()
            source = sources.get(int(choice)) if choice.isdecimal() else None
            if source is None:
                print("无效的数据来源编号，请输入1、2、3或4。")
        print(f"数据来源：{source}")

        response = graph.invoke({"question": query, "source": source})
        print(response["answer"])


if __name__ == "__main__":
    run()