In [None]:
#参考：https://zhuanlan.zhihu.com/p/682641846

In [None]:
# ----------------- 导入必要的package ----------------- #
import torch
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain import PromptTemplate
from langchain_community.document_transformers import (
    LongContextReorder,
)
from langchain import HuggingFacePipeline
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import LLMChain, StuffDocumentsChain

# ----------------- Configuration ---------------------------- #
model_path = "../../models/Baichuan2-13B-Chat"
embed_path = "../../models/bge-large-zh-v1.5"
# Use the GPU for the embedding model when one is present; fall back to CPU so
# this cell does not crash outright on a machine without CUDA (the original
# hard-coded "cuda").
embed_device = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------- Load the embedding model ----------------- #
embeddings = HuggingFaceEmbeddings(
    model_name=embed_path,
    model_kwargs={"device": embed_device},
    # Produce unit-length vectors, so inner-product similarity equals cosine similarity.
    encode_kwargs={"normalize_embeddings": True},
)
# ----------------- Load the LLM -------------------------- #
# device_map / torch_dtype are model-loading options and have no effect on a
# tokenizer, so they are only passed to the model below.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # Let transformers/accelerate choose device placement automatically.
    device_map="auto",
)

# Keep the pipeline object under its own name. Assigning it to `pipeline` would
# shadow the `transformers.pipeline` factory imported at the top, so running this
# cell a second time would fail with "object is not callable".
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): kept as in the original. LangChain's HuggingFacePipeline
    # appears to strip the prompt prefix from `generated_text` itself; verify
    # against the installed version before switching this to False.
    return_full_text=True,
)

llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# ----------------- Load and split the document --------------------------- #
# `load()` returns the raw pages. `load_and_split()` would already run a default
# text splitter, so every page would be chunked twice: once with default
# settings and once with the splitter below.
loader = PyPDFLoader("../data/中华人民共和国证券法(2019修订).pdf")
documents = loader.load()
# Split on the Chinese full stop first. Fall back to the Chinese semicolon, then
# newlines, then single characters. This way a span containing no "。" (e.g. a
# table-of-contents page) is still cut to about chunk_size characters instead of
# becoming one oversized chunk.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["。", "；", "\n", ""],
    chunk_size=512,
    chunk_overlap=32,
)
texts_chunks = text_splitter.split_documents(documents)
# 将文本块存入向量库并创建retriever，每次检索返回相关性最高的10个文本块。

# ----------------- Store chunks in the vector DB and create the retriever ------------ #
# Give every chunk a deterministic id. The collection is persisted under `db/`,
# and without explicit ids each run of this cell stores a fresh copy of every
# chunk (random ids). The top-k results would then fill up with duplicates.
# With fixed ids a re-run overwrites the same entries instead.
# NOTE: after changing the chunking parameters, delete `db/` so that stale
# chunks with higher indices are not left behind.
chunk_ids = [f"chunk-{i}" for i in range(len(texts_chunks))]
vectorstore = Chroma.from_documents(
    texts_chunks, embeddings, ids=chunk_ids, persist_directory="db"
)
# Return the 10 most similar chunks for each query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
# 用retriever.get_relevant_documents()检索相关文档，再用LongContextReorder()重排序：相关性高的文档放在首尾两端，相关性低的放在中间。

# ----------------- Document reordering -------------------------- #
# Reorderer: moves the most relevant documents to both ends of the list and the
# least relevant ones into the middle.
reordering = LongContextReorder()

# Example question ("What conditions must a company meet for an initial public
# offering of new shares?"). Retrieve candidate chunks for it, then reorder them.
query = "公司首次公开发行新股，应当符合哪些条件？"
docs = retriever.get_relevant_documents(query)
reordered_docs = reordering.transform_documents(docs)

In [None]:
# ----------------- Build the prompt templates -------------------------- #
# Main prompt (Chinese): the assistant answers the user's question from the given
# context. {context} receives the retrieved documents, {question} the user query.
template = """你是一名智能助手，可以根据上下文回答用户的问题。

已知内容：
{context}

问题：
{question}
"""
prompt = PromptTemplate(input_variables=["context", "question"], template=template)

# Name of the main-prompt variable that receives the documents; it must match
# the {context} placeholder in `template`.
document_variable_name = "context"
# Per-document prompt: each document is rendered as its raw page_content.
document_prompt = PromptTemplate(
    template="{page_content}",
    input_variables=["page_content"],
)

In [None]:
# ----------------- Build the chain and run it -------------------------- #
llm_chain = LLMChain(prompt=prompt, llm=llm)
# Render each document with document_prompt, place them into the prompt variable
# named by document_variable_name, and call the LLM chain.
chain = StuffDocumentsChain(
    document_variable_name=document_variable_name,
    document_prompt=document_prompt,
    llm_chain=llm_chain,
)
result = chain.run(question=query, input_documents=reordered_docs)
# print() rather than a bare expression so newlines in the answer render as text.
print(result)