In [1]:
import os

import torch
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate

from lib.chatglm_llm import ChatGLM, AlpacaGLM
from lib.embeds import MyEmbeddings
from lib.faiss import FAISSVS
# Star import supplies the config constants used below (EMBEDDING_MODEL, VECTORSTORE_PATH,
# PROMPT_TEMPLATE, llm_model_dict, ...). Kept after the named lib imports, as originally.
from lib.config import *
from lib.utils import get_docs


class LocalDocQA:
    """Retrieval-augmented question answering over a local FAISS knowledge base.

    Construction loads the LLM (``AlpacaGLM``), the embedding model
    (``MyEmbeddings``) and the vector store named ``vs_name``.
    :meth:`get_knowledge_based_answer` then retrieves passages for a query and
    "stuffs" them into ``PROMPT_TEMPLATE`` for the LLM to answer.

    Args:
        embedding_model: key into ``embedding_model_dict`` selecting the embedding model.
        embedding_device: stored as ``self.embedding_device``.
            NOTE(review): not passed to ``MyEmbeddings`` -- confirm intent.
        llm_model: key into ``llm_model_dict`` selecting the LLM checkpoint.
        llm_device: forwarded to ``AlpacaGLM.load_model``.
        llm_history_len: stored as ``self.llm_history_len``.
            NOTE(review): not forwarded to the LLM -- confirm intent.
        top_k: number of passages retrieved per query.
        vs_name: name (appended to ``VECTORSTORE_PATH``) of the store loaded at start-up.
    """

    def __init__(self,
                 embedding_model=EMBEDDING_MODEL,
                 embedding_device=EMBEDDING_DEVICE,
                 llm_model=LLM_MODEL,
                 llm_device=LLM_DEVICE,
                 llm_history_len=LLM_HISTORY_LEN,
                 top_k=VECTOR_SEARCH_TOP_K,
                 vs_name=VS_NAME,
                 ) -> None:
        # Release cached GPU memory (e.g. left by a previous instance in this kernel)
        # before loading weights. One call is enough; it does nothing if CUDA is not initialised.
        torch.cuda.empty_cache()

        self.embedding_model = embedding_model
        self.llm_model = llm_model
        self.embedding_device = embedding_device
        self.llm_device = llm_device
        self.llm_history_len = llm_history_len
        self.top_k = top_k
        self.vs_name = vs_name

        self.llm = AlpacaGLM()
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)

        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])
        self.load_vector_store(vs_name)

        self.prompt = PromptTemplate(
            template=PROMPT_TEMPLATE,
            input_variables=["context", "question"],
        )
        # Parameters for the (currently disabled) SerpAPI web-search path below.
        # The key comes from the environment rather than being hard-coded in the notebook.
        self.search_params = {
            "engine": "bing",
            "gl": "us",
            "hl": "en",
            "serpapi_api_key": os.environ.get("SERPAPI_API_KEY", ""),
        }

    def _vs_path(self, vs_name: str) -> str:
        """Return the on-disk location of vector store ``vs_name``.

        Plain string concatenation, so ``VECTORSTORE_PATH`` is expected to end
        with a path separator.
        """
        return VECTORSTORE_PATH + vs_name

    def init_knowledge_vector_store(self, vs_name: str) -> None:
        """Build vector store ``vs_name`` from the documents in ``KNOWLEDGE_PATH`` and save it.

        If ``vs_name`` is the store currently in use, the in-memory copy is
        replaced as well so subsequent queries see the new index.
        """
        docs = get_docs(KNOWLEDGE_PATH)
        vector_store = FAISSVS.from_documents(docs, self.embeddings)
        vector_store.save_local(self._vs_path(vs_name))
        if vs_name == self.vs_name:
            self.vector_store = vector_store

    def add_knowledge_to_vector_store(self, vs_name: str) -> None:
        """Embed the documents in ``ADD_KNOWLEDGE_PATH`` and merge them into store ``vs_name``.

        The merged store is written back to disk. If ``vs_name`` is the store
        currently in use, the in-memory copy is refreshed too. Does nothing when
        ``ADD_KNOWLEDGE_PATH`` yields no documents.
        """
        docs = get_docs(ADD_KNOWLEDGE_PATH)
        if not docs:
            # Nothing to add; building an index from zero documents would presumably raise.
            return
        new_vector_store = FAISSVS.from_documents(docs, self.embeddings)
        vs_path = self._vs_path(vs_name)
        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
        vector_store.merge_from(new_vector_store)
        vector_store.save_local(vs_path)
        if vs_name == self.vs_name:
            self.vector_store = vector_store

    def load_vector_store(self, vs_name: str) -> None:
        """Load vector store ``vs_name`` from disk and make it the active store."""
        self.vector_store = FAISSVS.load_local(self._vs_path(vs_name), self.embeddings)
        # Remember which store is active so the build/merge methods can refresh it.
        self.vs_name = vs_name

    # TODO: web-search answer path, currently disabled. If re-enabled, note that
    # SerpAPIWrapper is not imported here and, per the LangChain docs, `search.run`
    # returns a str rather than the list of Documents `input_documents` expects.
    # def get_search_based_answer(self, query):
    #     search = SerpAPIWrapper(params=self.search_params)
    #     docs = search.run(query)
    #     search_chain = load_qa_chain(self.llm, chain_type="stuff")
    #     answer = search_chain.run(input_documents=docs, question=query)
    #     return answer

    def get_knowledge_based_answer(self, query: str):
        """Answer ``query`` using passages retrieved from the active vector store.

        Returns:
            tuple: ``(answer, docs, history)`` -- the LLM's answer, the retrieved
            ``Document`` objects (scores dropped) and ``self.llm.history`` whose
            last turn has been set to ``[query, answer]``.
        """
        # This FAISSVS returns (Document, score) pairs (see the printed output), so
        # scores are stripped below. `top_k` controls how many passages are kept.
        scored_docs = self.vector_store.max_marginal_relevance_search(query, k=self.top_k)
        # Log line reads: "Retrieved documents and similarity scores:".
        print(f'召回的文档和相似度分数：{scored_docs}')
        docs = [pair[0] for pair in scored_docs]

        document_prompt = PromptTemplate(
            input_variables=["page_content"], template="Context:\n{page_content}"
        )
        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
        )
        answer = combine_documents_chain.run(input_documents=docs, question=query)

        history = self.llm.history
        if history:
            # Presumably the LLM logged the full stuffed prompt as the user turn;
            # keep the bare query instead so the chat history stays readable.
            history[-1][0] = query
            history[-1][-1] = answer
        else:
            # NOTE(review): the LLM recorded no turn; add one instead of raising IndexError.
            self.llm.history = [[query, answer]]
        return answer, docs, self.llm.history

In [2]:
# Build the QA pipeline: loads the LLM weights (presumably the source of the
# "Loading checkpoint shards" bar below), the embedding model and the vector store VS_NAME.
qa = LocalDocQA()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Ask a question against the loaded knowledge base. The history is bound to a real
# name: assigning to `_` in IPython stops the kernel updating its `_`/`__`/`___`
# output-history variables.
query = "make a brief introduction of APT?"
ans, docs, chat_history = qa.get_knowledge_based_answer(query)

召回的文档和相似度分数：[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n ’ t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h 

In [4]:
# Show the generated answer (the cell's last expression is rendered as its output).
ans

'\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'