Faiss Vector Store的使用

In [1]:
import logging
import sys

# Route INFO-level log records to stdout so they show up in the cell output.
# basicConfig() already attaches a stdout StreamHandler to the root logger;
# attaching a second one by hand makes every record print twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [2]:
import torch

def torch_gc():
    """Release cached accelerator memory held by PyTorch (best effort).

    With CUDA available, the CUDA cache is emptied and IPC memory is
    collected. Otherwise, if the MPS backend exists and is available,
    ``torch.mps.empty_cache`` is called. In every other case this is a no-op.

    Returns:
        None
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        return

    # torch.backends.mps is missing on older PyTorch builds; treat a missing
    # backend like an unavailable one instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None or not mps_backend.is_available():
        return

    try:
        from torch.mps import empty_cache
        empty_cache()
    except Exception as e:
        # Deliberately best-effort: report the problem and carry on.
        print(e)
        print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本，以支持及时清理 torch 产生的内存占用。")

import LLM

In [3]:
from transformers import AutoTokenizer, AutoModel
# Load the tokenizer for the ChatGLM-6B INT4 checkpoint.
# trust_remote_code=True lets transformers run the Python code shipped in the
# model repository.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b-int4",
    trust_remote_code=True,
)
# Load the model, cast it to half precision, move it to the GPU and switch it
# to evaluation mode (eval() returns the module itself, so it can be chained).
model = (
    AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)

  from .autonotebook import tqdm as notebook_tqdm


No compiled kernel found.
Compiling kernels : /home/xiaodong/.cache/huggingface/modules/transformers_modules/THUDM/chatglm-6b-int4/02a065cf2797029c036a02cac30f1da1a9bc49a3/quantization_kernels_parallel.c
Compiling gcc -O3 -fPIC -pthread -fopenmp -std=c99 /home/xiaodong/.cache/huggingface/modules/transformers_modules/THUDM/chatglm-6b-int4/02a065cf2797029c036a02cac30f1da1a9bc49a3/quantization_kernels_parallel.c -shared -o /home/xiaodong/.cache/huggingface/modules/transformers_modules/THUDM/chatglm-6b-int4/02a065cf2797029c036a02cac30f1da1a9bc49a3/quantization_kernels_parallel.so
Load kernel : /home/xiaodong/.cache/huggingface/modules/transformers_modules/THUDM/chatglm-6b-int4/02a065cf2797029c036a02cac30f1da1a9bc49a3/quantization_kernels_parallel.so
Setting CPU quantization kernel threads to 6
Using quantization cache
Applying quantization to glm layers


Define CustomLLM

In [4]:
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from llama_index import LLMPredictor

class CustomLLM(LLM):
    """LangChain LLM wrapper around the notebook-level ChatGLM ``model``.

    Relies on the global ``model`` and ``tokenizer`` objects created in an
    earlier cell.
    """

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a reply for ``prompt`` with ``model.chat``.

        Args:
            prompt: Text sent to the model. Each call starts from an empty
                chat history.
            stop: Optional stop sequences. The reply is cut just before the
                earliest occurrence of any of them.

        Returns:
            The model's reply, truncated at the first stop sequence if one
            occurs in it.
        """
        # The updated history returned by chat() is not needed here.
        response, _ = model.chat(tokenizer, prompt, history=[])
        if stop:
            # Honour the stop sequences requested by the caller instead of
            # silently ignoring them; empty stop strings are skipped.
            hits = [
                pos
                for pos in (response.find(token) for token in stop if token)
                if pos != -1
            ]
            if hits:
                response = response[:min(hits)]
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM."""
        return {"name_of_model": "THUDM/chatglm-6b-int4"}

    @property
    def _llm_type(self) -> str:
        """Type tag reported for this LLM."""
        return "custom"
    
# Wrap a CustomLLM instance in a llama_index LLMPredictor.
llm_predictor = LLMPredictor(llm=CustomLLM())


INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
NumExpr defaulting to 8 threads.


Another LLM option: Dolly 2.0

embedding model setup

In [5]:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings


# Embedding model name and the device string passed to it via model_kwargs.
em_model = "GanymedeNil/text2vec-large-chinese"
em_device = "cuda"

# Embedding object reused later to build and to reload the FAISS store.
hgf_embeddings = HuggingFaceEmbeddings(
    model_name=em_model,
    model_kwargs=dict(device=em_device),
)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: GanymedeNil/text2vec-large-chinese
Load pretrained SentenceTransformer: GanymedeNil/text2vec-large-chinese
No sentence-transformers model found with name /home/xiaodong/.cache/torch/sentence_transformers/GanymedeNil_text2vec-large-chinese. Creating a new one with MEAN pooling.


In [6]:
from langchain.vectorstores import FAISS
from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

import re

# Maximum sentence length (in characters) before a sentence is split further
SENTENCE_SIZE = 100

class ChineseTextSplitter(CharacterTextSplitter):
    """Text splitter that breaks Chinese text into sentence-sized pieces.

    Attributes are documented on ``__init__``. ``split_text`` is the main
    entry point; ``split_text1`` is a simpler, purely terminator-based
    alternative.
    """

    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
        """Create the splitter.

        Args:
            pdf: When true, whitespace in the input is normalised before
                splitting (see the ``if self.pdf`` branches).
            sentence_size: Length in characters above which ``split_text``
                breaks a sentence down further.
            **kwargs: Forwarded unchanged to ``CharacterTextSplitter``.
        """
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Split ``text`` on sentence terminators only.

        Separator matches are glued back onto the preceding sentence, so each
        returned item keeps its terminator (and closing quotes).

        Args:
            text: Text to split.

        Returns:
            List of non-empty sentence strings.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r'\s', ' ', text)
            # No newline survives the whitespace substitution above, so this
            # replacement currently has no effect; kept as in the original.
            text = text.replace("\n\n", "")
        # Colon and semicolon are deliberately not treated as terminators here.
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗．。！？]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into sentences, refining any that are too long.

        The text is first cut at sentence terminators. Every sentence longer
        than ``self.sentence_size`` is then refined in up to three passes,
        each applied only to pieces that are still too long: after commas and
        periods, after runs of two or more spaces, and finally after single
        spaces. Pieces that remain too long after the last pass are returned
        as they are.

        Args:
            text: Text to split.

        Returns:
            List of non-empty pieces in their original order.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r'\s', " ", text)
            # No newline survives the whitespace substitution above, so this
            # substitution currently has no effect; kept as in the original.
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;；.!?。！？\?])([^”’])', r"\1\n\2", text)  # single-character terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        # If a terminator is followed by closing quotes, the quotes are the
        # real end of the sentence, so the line break goes after them; the
        # rules above are written to leave those quotes in place.
        text = re.sub(r'([;；!?。！？\?]["’”」』]{0,2})([^;；!?，。！？\?])', r'\1\n\2', text)
        text = text.rstrip()  # drop any trailing line break at the end of the text

        limit = self.sentence_size
        pieces: List[str] = []
        # Build the result in a single pass. The previous implementation
        # spliced replacements back into the lists it was iterating, using
        # list.index() and slicing, which is quadratic and hard to follow.
        for sentence in text.split("\n"):
            if not sentence:
                continue
            if len(sentence) <= limit:
                pieces.append(sentence)
                continue
            # Pass 1: break after commas / periods (plus closing quotes).
            by_comma = re.sub(r'([,，.]["’”」』]{0,2})([^,，.])', r'\1\n\2', sentence)
            for clause in by_comma.split("\n"):
                if len(clause) <= limit:
                    if clause:
                        pieces.append(clause)
                    continue
                # Pass 2: break after runs of two or more spaces.
                by_gap = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', clause)
                for chunk in by_gap.split("\n"):
                    if len(chunk) <= limit:
                        if chunk:
                            pieces.append(chunk)
                        continue
                    # Pass 3: break after single spaces; no further passes.
                    by_space = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', chunk)
                    pieces.extend(part for part in by_space.split("\n") if part)
        return pieces

# Load the FAQ text file and split it into documents with the Chinese splitter.
filepath = "data/faq/ecommerce_faq.txt"
loader = TextLoader(filepath)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=SENTENCE_SIZE)
docs = loader.load_and_split(textsplitter)

# vs_path is the directory the index is saved to and reloaded from in a later cell.
vs_path = "./vector_storage"
# Build the FAISS store from the split documents using hgf_embeddings.
faiss_vector_store = FAISS.from_documents(docs, hgf_embeddings)


Batches: 100%|██████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71it/s]

INFO:faiss.loader:Loading faiss with AVX2 support.
Loading faiss with AVX2 support.
INFO:faiss.loader:Successfully loaded faiss with AVX2 support.
Successfully loaded faiss with AVX2 support.





In [7]:
# Persist the index to vs_path, then reload it from disk with the same embeddings.
faiss_vector_store.save_local(vs_path)

faiss_vector_store = FAISS.load_local(vs_path, hgf_embeddings)
# Run a similarity search against the reloaded store and print the top hit.
query_faiss = "配送范围是？"
result_docs = faiss_vector_store.similarity_search(query=query_faiss)
# NOTE(review): in the recorded output the top hit is only the FAQ question
# line, not its answer — confirm the chunking granularity is what is wanted.
print(result_docs[0].page_content)

Batches: 100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 124.70it/s]

Q: 支持哪些省份配送？





In [8]:
from langchain import VectorDBQA

# Build a question-answering chain over the FAISS store, using a new CustomLLM
# (ChatGLM) instance as the LLM.
faq_chain = VectorDBQA.from_chain_type(llm=CustomLLM(), vectorstore=faiss_vector_store, verbose=True)





In [9]:
# Ask the chain a question and print the answer it returns.
question = '''能送到西藏吗？大概需要几天？'''
result = faq_chain.run(question)
print(result)



[1m> Entering new VectorDBQA chain...[0m


Batches: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 60.75it/s]

The dtype of attention mask (torch.int64) is not bool






[1m> Finished chain.[0m
I'm sorry, but I don't have information on specific快递公司 or delivery times for locations like西藏. The delivery time for a specific order will depend on various factors such as the order's location, the service provider's network, and the delivery method used. Can you please provide more information about the order you are referring to so I can better assist you?
