In [4]:
import chromadb
from chromadb.utils import embedding_functions
import numpy as np

class TextMatcher:
    def __init__(self, collection_name="text_matcher", threshold=0.7):
        self.client = chromadb.PersistentClient(path="./vector_db_match/chroma_db")
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
        self.threshold = threshold
        
        try:
            self.collection = self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
        except:
            self.collection = self.client.create_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
    
    def add_texts(self, texts):
        """添加文本集合到向量数据库"""
        ids = [str(i) for i in range(len(texts))]
        self.collection.add(
            documents=texts,
            ids=ids
        )
        print(f"已添加 {len(texts)} 个文本到向量数据库")
    
    def find_closest(self, query_text, n_results=1):
        """查找最相似的文本，如果低于阈值则返回None"""
        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results
        )
        
        if not results['documents'][0]:
            return None
        
        closest_text = results['documents'][0][0]
        distance = results['distances'][0][0]
        
        # Chroma使用余弦距离，距离越小越相似
        # 转换为相似度分数：1 - distance
        similarity = 1 - distance
        
        if similarity >= self.threshold:
            return {
                'text': closest_text,
                'similarity': similarity,
                'distance': distance
            }
        else:
            return None

# 使用示例
if __name__ == "__main__":
    # 初始化匹配器
    matcher = TextMatcher(threshold=0.1)
    
    # 示例文本集合
    sample_texts = [
        "706",
        "701",
        "702",
        "703",
        "704",
        "302",
        "501",
        "601"
    ]
    
    # 添加文本到数据库
    matcher.add_texts(sample_texts)
    
    # 测试查询
    test_queries = [
        "C-02-701",
        "2-703"
    ]
    
    print("\n测试查询结果：")
    for query in test_queries:
        result = matcher.find_closest(query)
        print(f"\n查询: '{query}'")
        if result:
            print(f"匹配结果: '{result['text']}'")
            print(f"相似度: {result['similarity']:.3f}")
        else:
            print("未找到足够相似的文本")

已添加 8 个文本到向量数据库

测试查询结果：

查询: 'C-02-701'
匹配结果: '701'
相似度: 0.227

查询: '2-703'
匹配结果: '703'
相似度: 0.639
