In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import os

class SemanticMatcher:
    """
    Semantic search over a dictionary of text entries using sentence embeddings.

    At construction time the chosen text field of every entry is embedded with
    a ``"passage: "`` prefix. ``match`` embeds a query with a ``"query: "``
    prefix and ranks the entries by cosine similarity to it.
    """

    # Cache directory used by the original notebook. It is only used when it
    # actually exists on the current machine, so the class keeps working
    # elsewhere by falling back to the library's default cache location.
    _LEGACY_CACHE_FOLDER = '/Users/shou/Code/huggingface_models'

    def __init__(self, entries, *, model=None,
                 model_name='intfloat/multilingual-e5-large',
                 cache_folder=None, text_field='identification'):
        """
        Initialize the semantic matcher and embed every entry.

        :param entries: Dict, {key: {text_field: text, ...}}. Iteration order
            is preserved and decides the order of tied matches.
        :param model: Optional pre-loaded encoder exposing ``encode()``. When
            given, no model is loaded and ``model_name`` / ``cache_folder``
            are ignored (lets callers reuse one loaded model across matchers).
        :param model_name: Name of the SentenceTransformer model to load when
            ``model`` is not supplied.
        :param cache_folder: Directory for downloaded model files. ``None``
            selects the notebook's original folder if it exists on this
            machine, otherwise the library default. ``~`` is expanded.
        :param text_field: Key inside each entry holding the text to embed.
        :raises KeyError: If an entry does not contain ``text_field``.
        """
        if model is None:
            if cache_folder is None and os.path.isdir(self._LEGACY_CACHE_FOLDER):
                cache_folder = self._LEGACY_CACHE_FOLDER
            if cache_folder is not None:
                cache_folder = os.path.expanduser(cache_folder)
            # local_files_only=False allows downloading the model when it is
            # not already present in the cache folder.
            model = SentenceTransformer(model_name,
                                        cache_folder=cache_folder,
                                        local_files_only=False)
        self.model = model

        self.entries = entries

        # Collect all passages first so they can be encoded in one batch.
        keys = list(entries)
        passages = []
        for key in keys:
            try:
                entry_text = entries[key][text_field]
            except KeyError:
                raise KeyError(
                    f"entry {key!r} has no {text_field!r} field"
                ) from None
            passages.append(f"passage: {entry_text}")

        # {key: 1-D embedding vector}
        self.entry_embeddings = {}
        if passages:
            # A single batched encode() call instead of one call per entry.
            vectors = np.asarray(self.model.encode(passages))
            self.entry_embeddings = dict(zip(keys, vectors))

    def match(self, query, top_k=3, threshold=0.5):
        """
        Find the entries most similar to a query.

        :param query: Query text.
        :param top_k: Maximum number of results to return (the k most similar).
        :param threshold: Minimum cosine similarity a result must reach.
        :return: List of ``(key, similarity)`` tuples sorted by descending
            similarity; ties keep the entries' insertion order. Entries whose
            similarity is undefined (zero-norm embedding or zero-norm query)
            are never returned.
        """
        if not self.entry_embeddings:
            return []

        # Embed the query.
        query_embedding = np.asarray(self.model.encode(f"query: {query}"))

        # Cosine similarity against all entries at once.
        keys = list(self.entry_embeddings)
        matrix = np.stack([np.asarray(self.entry_embeddings[key]) for key in keys])
        norm_products = (
            np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
        )
        # A zero norm makes the cosine undefined (0/0). Mask those entries out
        # so NaN never reaches the sort, where it would corrupt the ordering.
        defined = norm_products > 0
        with np.errstate(divide='ignore', invalid='ignore'):
            similarities = (matrix @ query_embedding) / norm_products

        # Keep only results at or above the threshold, then rank them.
        matches = [
            (key, similarity)
            for key, similarity, is_defined in zip(keys, similarities, defined)
            if is_defined and similarity >= threshold
        ]
        # list.sort is stable, so equal scores stay in insertion order.
        matches.sort(key=lambda item: item[1], reverse=True)

        return matches[:top_k]


In [40]:
# Read the bird identification records from the JSON file in the working directory
with open("ebird_data.json", mode='r', encoding='UTF-8') as f:
    entries = json.loads(f.read())

# Build the matcher from the loaded records
matcher = SemanticMatcher(entries)

In [45]:
# Sample queries to run through the matcher.
# The Japanese text reads: "adult male has a blue-grey head, brown back and orange belly".
test_queries = ["オスの成鳥は、頭が青灰色、背が褐色、腹がオレンジ色"]
for query in test_queries:
    print(f"\nSearch: {query}")
    results = matcher.match(query)

    for key, similarity in results:
        print(f"Matched: {key}, Similarity: {similarity:.4f}")
        # Also show details of the matched entry: binomial name and URL
        matched_entry = entries[key]
        print("Detail:", matched_entry["binomialName"], matched_entry["url"])


Search: オスの成鳥は、頭が青灰色、背が褐色、腹がオレンジ色
Matched: Izu Thrush, Similarity: 0.8120
Detail: Turdus celaenops https://ebird.org/species/izuthr1/JP-13
Matched: Slaty-backed Gull, Similarity: 0.8084
Detail: Larus schistisagus https://ebird.org/species/slbgul/JP-13
Matched: Eyebrowed Thrush, Similarity: 0.8069
Detail: Turdus obscurus https://ebird.org/species/eyethr/JP-13
