In [55]:
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import os


class SemanticMatcher:
    """
    Embedding-based semantic search over a dictionary of text entries.

    At construction time the text of every entry is encoded with a
    "passage: " prefix; at query time the query is encoded with a "query: "
    prefix and ranked against the stored passages by cosine similarity.
    (These are the prefixes used by the E5 model family; see the model card.)
    """

    # Defaults are the values that used to be hard-coded in __init__, so
    # SemanticMatcher(entries) behaves exactly as before.
    DEFAULT_MODEL_NAME = "intfloat/multilingual-e5-large"
    DEFAULT_CACHE_FOLDER = "/Users/shou/Code/huggingface_models"
    DEFAULT_TEXT_FIELD = "identification"

    def __init__(
        self,
        entries,
        model_name=DEFAULT_MODEL_NAME,
        cache_folder=DEFAULT_CACHE_FOLDER,
        text_field=DEFAULT_TEXT_FIELD,
    ):
        """
        Load the embedding model and embed every entry.

        :param entries: Dict mapping each key to a dict that holds the text to
            embed under ``text_field``, e.g. {key: {"identification": "..."}}
        :param model_name: Name of the SentenceTransformer model to load
        :param cache_folder: Folder where the model files are stored; a
            leading "~" is expanded to the user's home directory
        :param text_field: Name of the field in each entry that is embedded
        :raises KeyError: If an entry does not contain ``text_field``
        """
        # local_files_only=False allows the model to be fetched when it is not
        # already present in the cache folder.
        self.model = SentenceTransformer(
            model_name,
            cache_folder=os.path.expanduser(cache_folder),
            local_files_only=False,
        )

        self.entries = entries

        # Embed all passages with a single encode() call so the model can
        # batch them instead of running one forward pass per entry.
        keys = list(entries)
        passages = [f"passage: {entries[key][text_field]}" for key in keys]
        embeddings = self.model.encode(passages) if passages else []
        self.entry_embeddings = dict(zip(keys, embeddings))

    def match(self, query, top_k=3, threshold=0.5):
        """
        Return the entries most similar to a query.

        :param query: Query keywords
        :param top_k: Maximum number of results to return
        :param threshold: Minimum cosine similarity a result must reach
        :return: List of (key, similarity) tuples, most similar first
        """
        # Generate an embedding vector for the query
        query_embedding = self.model.encode(f"query: {query}")
        # The query norm is the same for every entry, so compute it once
        query_norm = np.linalg.norm(query_embedding)

        matches = []
        for key, entry_embedding in self.entry_embeddings.items():
            denominator = query_norm * np.linalg.norm(entry_embedding)
            if denominator == 0:
                # Cosine similarity is undefined for a zero vector. Skipping
                # avoids a 0/0 NaN (and its RuntimeWarning); a NaN could never
                # pass the threshold anyway.
                continue
            similarity = np.dot(query_embedding, entry_embedding) / denominator
            # Filter before sorting: less to sort, and no NaN can reach the
            # sort and corrupt the ordering of the valid results.
            if similarity >= threshold:
                matches.append((key, similarity))

        # Stable sort, so entries with equal similarity keep insertion order
        matches.sort(key=lambda match: match[1], reverse=True)
        return matches[:top_k]

In [56]:
# Load the bird entries (keyed by name) from the local eBird JSON export
with open("ebird_data.json", "r", encoding="UTF-8") as f:
    entries = json.load(f)

# Build the matcher; this loads the model and embeds every entry's text
matcher = SemanticMatcher(entries)

In [60]:
# Run each sample query and print its best matches with the eBird details
test_queries = ["yellow"]
for query in test_queries:
    print(f"\nSearch: {query}")
    results = matcher.match(query)

    for key, similarity in results:
        # Look the matched entry up once and reuse it for both detail fields
        entry = entries[key]
        print(f"Matched: {key}, Similarity: {similarity:.4f}")
        print("Detail:", entry["binomialName"], entry["url"])


Search: yellow
Matched: Eastern Yellow Wagtail, Similarity: 0.8241
Detail: Motacilla tschutschensis https://ebird.org/species/eaywag/JP-13
Matched: Citrine Wagtail, Similarity: 0.8187
Detail: Motacilla citreola https://ebird.org/species/citwag/JP-13
Matched: Yellow-bellied Tit, Similarity: 0.8180
Detail: Periparus venustulus https://ebird.org/species/yebtit4/JP-13
