In [1]:
import chromadb

In [2]:
from chromadb import PersistentClient

# 方式1：使用 PersistentClient（推荐，Chroma v0.4+ 新版本）
client = PersistentClient(
    path="./chroma_data"  # 数据存储在本地 ./chroma_data 文件夹
)

""" 方式2：如果用 LangChain 集成的 Chroma（常见场景）
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings  # 或其他嵌入模型

vector_db = Chroma(
    collection_name="my_collection",
    embedding_function=OpenAIEmbeddings(),
    persist_directory="./chroma_data"  # 关键：指定本地存储路径，无需服务端
)
"""

' 方式2：如果用 LangChain 集成的 Chroma（常见场景）\nfrom langchain.vectorstores import Chroma\nfrom langchain.embeddings import OpenAIEmbeddings  # 或其他嵌入模型\n\nvector_db = Chroma(\n    collection_name="my_collection",\n    embedding_function=OpenAIEmbeddings(),\n    persist_directory="./chroma_data"  # 关键：指定本地存储路径，无需服务端\n)\n'

In [3]:
#chroma_clinet = chromadb.HttpClient(host = "localhost", port=8000)

In [4]:
from chromadb.utils import embedding_functions
model_path = "../model/gte-large-zh"
em_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name = model_path)

In [5]:
collection = client.get_or_create_collection(name='rag_db', embedding_function=em_fn, metadata={"hnsw:space": "cosine"})

In [6]:
documents=["检索增强生成（Retrieval-augmented Generation），简称RAG，是当下热门的大模型前沿技术之一。",
           "检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，",
           "它会先从一个庞大的文档集合中检索出相关的信息，然后利用这些检索到的信息来指导文本的生成，从而提高预测的质量和准确性 。"]

In [7]:
collection.add(documents = documents,
              ids = ["id1", "id2", "id3"],
              metadatas=[{"chapter":3, "verse":16},
                        {"chapter":4, "verse":5},
                        {"chapter":12, "verse":5}])

In [8]:
collection.count()

3

In [9]:
#collection.peek(limit=1)

In [10]:
get_collection = client.get_collection(name="rag_db", embedding_function=em_fn)

In [11]:
id_result = get_collection.get(ids=['id2'], include=["documents", "embeddings","metadatas"])

In [12]:
print(id_result["documents"])

['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，']


In [13]:
import numpy as np
np.array(id_result["embeddings"]).shape

(1, 1024)

In [14]:
query = "检索增强技术简称是什么？"

In [15]:
get_collection.query(query_texts=query,
                    n_results=2,
                    include=["documents","metadatas"])

{'ids': [['id2', 'id1']],
 'embeddings': None,
 'documents': [['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，',
   '检索增强生成（Retrieval-augmented Generation），简称RAG，是当下热门的大模型前沿技术之一。']],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [[{'chapter': 4, 'verse': 5}, {'chapter': 3, 'verse': 16}]],
 'distances': None}

In [16]:
get_collection.query(query_texts=query,
                    n_results=2,
                    include=["documents","metadatas"],
                    where={"verse":5})

{'ids': [['id2', 'id3']],
 'embeddings': None,
 'documents': [['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，',
   '它会先从一个庞大的文档集合中检索出相关的信息，然后利用这些检索到的信息来指导文本的生成，从而提高预测的质量和准确性 。']],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [[{'verse': 5, 'chapter': 4}, {'chapter': 12, 'verse': 5}]],
 'distances': None}

In [17]:
# $eq: equal to
# $ne not equal to
# $gt greater than
# $gte greater than or equal to
# $lt -less than
# $lte less than or equal to
get_collection.query(query_texts=query,
                    n_results=2,
                    include=["documents","metadatas"],
                    where={"chapter":{"$lt":10}})

{'ids': [['id2', 'id1']],
 'embeddings': None,
 'documents': [['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，',
   '检索增强生成（Retrieval-augmented Generation），简称RAG，是当下热门的大模型前沿技术之一。']],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [[{'verse': 5, 'chapter': 4}, {'verse': 16, 'chapter': 3}]],
 'distances': None}

In [18]:
get_collection.query(query_texts=query,
                    n_results=2,
                    include=["documents","metadatas"],
                    where={"$and":[{"chapter":{"$lt":10}},
                                  {"verse":{"$eq":5}}
                                  ]}
                    )

{'ids': [['id2']],
 'embeddings': None,
 'documents': [['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，']],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [[{'chapter': 4, 'verse': 5}]],
 'distances': None}

In [19]:
get_collection.query(query_texts=query,
                    n_results=2,
                    include=["documents","metadatas"],
                    where_document={"$contains":"检索"}
                    )

{'ids': [['id2', 'id1']],
 'embeddings': None,
 'documents': [['检索增强生成模型结合了语言模型和信息检索技术。具体来说，当模型需要生成文本或者回答问题时，',
   '检索增强生成（Retrieval-augmented Generation），简称RAG，是当下热门的大模型前沿技术之一。']],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [[{'verse': 5, 'chapter': 4}, {'chapter': 3, 'verse': 16}]],
 'distances': None}

In [21]:
# use langchain

#from langchain.embeddings.huggingface import HuggingFaceEmbeddings
#from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
model_path = "../model/gte-large-zh"
model  =HuggingFaceEmbeddings(model_name = model_path,
                             model_kwargs={'device':"cpu"})
embeddings = model.embed_documents(documents)

In [22]:
print(embeddings)

[[0.0022275266237556934, 0.010900181718170643, -0.06066453829407692, -0.0022906209342181683, 0.0027456318493932486, -0.043523918837308884, 0.0008226658101193607, -0.03207528963685036, -0.043596260249614716, -0.06867467612028122, 0.011284132488071918, -0.020719369873404503, -0.0024540473241358995, 0.022484200075268745, 0.003132208716124296, 0.017350932583212852, 0.010060084983706474, -0.017020225524902344, -0.012759997509419918, -0.020756786689162254, -0.028273653239011765, -0.014233355410397053, 0.013630951754748821, 0.035547252744436264, 0.025466082617640495, -0.013927300460636616, -0.03082045167684555, 0.01546204648911953, 0.034657564014196396, -0.018972575664520264, 0.04704611375927925, -0.0033103569876402617, -0.05470903962850571, 0.04977907985448837, -0.02578929252922535, -0.016000354662537575, -0.01177658699452877, -0.029126813635230064, 0.023769671097397804, 0.10458578169345856, 0.005735491402447224, 0.0776151493191719, 0.008250946179032326, 0.0387597493827343, -0.02995404787361