In [None]:
import os
import openai
import glob
from typing import List

from chromadb.utils import embedding_functions
from langchain.chains import RetrievalQA
from langchain_openai import AzureChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from langchain.memory import ConversationBufferMemory

from dotenv import load_dotenv, find_dotenv
# Load variables from the nearest .env file (located by find_dotenv) into the
# process environment. Presumably this is where AZURE_OPENAI_API_KEY comes from.
_ = load_dotenv(find_dotenv())

# Alternative: set the key explicitly (never commit a real key).
#API_KEY = "YOUR_TONG_API_KEY"
#os.environ["AZURE_OPENAI_API_KEY"] = API_KEY

# The service endpoint is the proxy base URL followed by the Azure region.
API_BASE = "https://api.tonggpt.mybigai.ac.cn/proxy"
REGION = "canadaeast"
ENDPOINT = "/".join((API_BASE, REGION))

# Publish the endpoint and API version through the environment and on the
# openai module itself.
os.environ.update({
    "AZURE_OPENAI_ENDPOINT": ENDPOINT,
    "OPENAI_API_VERSION": "2024-02-01",
})
openai.azure_endpoint = ENDPOINT

# Build prompt
# Both templates must expose the {context} and {question} placeholders; they
# are filled in when a template is wrapped in a PromptTemplate.

# English prompt.
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:"""

# Chinese prompt. Compared with the previous text, this fixes the typo
# "清回答" -> "请回答" ("please answer") and replaces a stray unmatched double
# quote at the end of the instruction sentence with a full stop.
template_cn = """
你是一个问答机器人，使用下面提供的信息进行归纳和总结来回答最后的问题。如果你不知道答案，请回答“我不知道”，不要试图编造答案或使用未被提供的信息。最多输出3句话，尽可能让回答简单明了。
{context}
问题：{question}
回答："""

# Module-level chat model used by QAChain. temperature=0 is requested for the
# deployment named below.
llm = AzureChatOpenAI(
    temperature=0,
    azure_deployment="gpt-35-turbo-0125",  # or your deployment
)

class DefaultChromaEmbedding(Embeddings):
    """Adapter exposing Chroma's default embedding function as LangChain ``Embeddings``.

    The wrapped callable is ``embedding_functions.DefaultEmbeddingFunction()``.
    Its output is normalised to plain Python lists of floats so that the
    values returned here match the declared ``List[List[float]]`` /
    ``List[float]`` types regardless of the container type the underlying
    function yields (recent chromadb releases may return NumPy arrays).
    """

    def __init__(self):
        self.default_ef = embedding_functions.DefaultEmbeddingFunction()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: The strings to embed.

        Returns:
            One embedding (a list of floats) per input text, in input order.
            An empty input yields an empty list without calling the model.
        """
        if not texts:
            # Nothing to embed; skip the model call entirely.
            return []
        vectors = self.default_ef(list(texts))
        return [[float(value) for value in vector] for vector in vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector as a list of floats."""
        return self.embed_documents([text])[0]

chroma_embedding = DefaultChromaEmbedding()

class KnowledgeBaseBuilder:
    """Builds a Chroma vector store from PDF files.

    Files matching a glob pattern are loaded with ``PyPDFLoader``, split into
    chunks by ``text_splitter``, embedded with ``embedding_func`` and stored
    in a Chroma collection persisted under ``persist_dir``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, text_splitter=None, embedding_func=None, persist_dir="docs/data/chroma_dev/"):
        """Configure the builder.

        Args:
            text_splitter: Object providing ``split_documents(docs)``. When
                ``None``, a ``RecursiveCharacterTextSplitter`` with 1000-char
                chunks, 100-char overlap and Latin + CJK separators is used.
            embedding_func: Embedding object handed to Chroma. When ``None``,
                a ``DefaultChromaEmbedding`` is created.
            persist_dir: Directory passed to Chroma as ``persist_directory``.
        """
        # Test against None explicitly so a caller-supplied object is never
        # silently replaced just because it happens to evaluate as falsy.
        if text_splitter is None:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=100,
                separators=[
                    "\n\n",
                    "\n",
                    " ",
                    ".",
                    ",",
                    "\u200b",  # Zero-width space
                    "\uff0c",  # Fullwidth comma
                    "\u3001",  # Ideographic comma
                    "\uff0e",  # Fullwidth full stop
                    "\u3002",  # Ideographic full stop
                    "",
                ],
            )
        if embedding_func is None:
            embedding_func = DefaultChromaEmbedding()
        self.text_splitter = text_splitter
        self.embedding_func = embedding_func
        self.persist_dir = persist_dir
        # Most recently built vector store; None until
        # knowledge_base_build_chain() has succeeded.
        self.vectordb = None

    def docs_preprocess(self, path_name):
        """Load every PDF matching ``path_name`` and split it into chunks.

        Args:
            path_name: Glob pattern, e.g. ``"docs/bigai/*.pdf"``.

        Returns:
            Whatever ``text_splitter.split_documents`` returns for the loaded
            documents; the splitter receives an empty list when no file
            matches the pattern.
        """
        docs = []
        # glob returns paths in arbitrary order; sort so ingestion order is
        # reproducible from run to run.
        for doc_path in sorted(glob.glob(path_name)):
            docs.extend(PyPDFLoader(doc_path).load())
        return self.text_splitter.split_documents(docs)

    def knowledge_base_build_chain(self, path_name):
        """Build the Chroma store from the PDFs matching ``path_name``.

        The resulting store is also kept on ``self.vectordb``.

        Args:
            path_name: Glob pattern selecting the PDF files to index.

        Returns:
            The Chroma vector store created from the document chunks.

        Raises:
            ValueError: If the pattern produced no document chunks (no
                matching files, or files without extractable text).
        """
        splits = self.docs_preprocess(path_name)
        if not splits:
            # Fail here with an actionable message instead of handing an
            # empty batch to the vector store.
            raise ValueError(
                f"No document chunks were produced from {path_name!r}; "
                "check that the pattern matches at least one readable PDF."
            )
        self.vectordb = Chroma.from_documents(
            documents=splits,
            embedding=self.embedding_func,
            persist_directory=self.persist_dir
        )
        return self.vectordb


class QAChain:
    """Retrieval-augmented question answering over a vector store.

    Wraps a ``RetrievalQA`` chain built from the module-level ``llm`` and the
    ``template_cn`` prompt. No conversation memory is attached to the chain.
    """

    def __init__(self, vectordb):
        """Create the chain.

        Args:
            vectordb: Vector store providing ``as_retriever()``; its retriever
                supplies the context passed to the prompt.
        """
        # The previous version also built a ConversationBufferMemory here that
        # was never passed to the chain; that unused object has been dropped.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=vectordb.as_retriever(),
            return_source_documents=True,
            output_key="result",
            chain_type_kwargs={"prompt": PromptTemplate.from_template(template_cn)}
        )

    def query(self, question):
        """Run the chain on ``question`` and return the answer.

        Args:
            question: The user's question, sent to the chain under the
                ``"query"`` key.

        Returns:
            The value stored under ``"result"`` in the chain output. The
            source documents the chain was asked to return are discarded.
        """
        result = self.qa_chain.invoke({"query": question})
        return result["result"]

In [None]:
import shutil

persist_directory = "docs/data/chroma_dev"

# True  -> delete any existing store and re-index the PDFs (previous behaviour).
# False -> reopen the store already persisted under `persist_directory`
#          instead of re-embedding every document.
REBUILD_KNOWLEDGE_BASE = True

if REBUILD_KNOWLEDGE_BASE:
    # Remove old database files if any, so the store is created from scratch.
    if os.path.exists(persist_directory):
        shutil.rmtree(persist_directory)

    knowledge_base_builder = KnowledgeBaseBuilder(persist_dir=persist_directory)
    vectordb = knowledge_base_builder.knowledge_base_build_chain("docs/bigai/*.pdf")
    # NOTE(review): the explicit persist call was already disabled; confirm
    # whether the installed Chroma version still needs it.
    #vectordb.persist()
else:
    # Load the existing database.
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=chroma_embedding
    )

qa_chain_dev = QAChain(vectordb)

In [None]:
# Example query; the question means "Responsibilities of the probation-period mentor".
question = "试用期导师的职责"
# query() returns only the chain's "result" value, which is printed below.
result = qa_chain_dev.query(question)
print(result)

In [None]:
# Example query; the question means "Channels for recruiting new employees".
question = "招聘新员工的渠道"
# query() returns only the chain's "result" value, which is printed below.
result = qa_chain_dev.query(question)
print(result)