In [None]:
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import json
import os
import chromadb

# Initialize model
# Local Hugging Face cache; override with the HF_MODELS_DIR environment variable
# instead of editing a machine-specific absolute path. Only used if the
# `cache_folder=` argument below is uncommented.
cache_folder = os.environ.get("HF_MODELS_DIR", os.path.expanduser("~/Code/huggingface_models"))
# Directory where ChromaDB persists its data between kernel sessions.
persist_dir = "./chroma_db"

# NOTE: trust_remote_code=True runs Python code shipped in the model repository;
# keep it enabled only for repositories you trust.
model = SentenceTransformer(
    "dunzhang/stella_en_1.5B_v5",
    # cache_folder=cache_folder,  # uncomment to store/load weights under cache_folder
    # local_files_only=False,
    trust_remote_code=True,
)

# Create ChromaDB client backed by persist_dir
client = chromadb.PersistentClient(path=persist_dir)

# Fetch the collection, creating it on first run. A single get_or_create call avoids
# exception-driven control flow (a bare `except:` would also swallow KeyboardInterrupt
# and mask genuine store errors). "hnsw:space": "cosine" makes query distances cosine
# distances, which match() converts to similarity via `1 - distance`.
collection = client.get_or_create_collection(
    name="bird_entries",
    metadata={"description": "Bird identification embeddings", "hnsw:space": "cosine"},
)
print(f"Collection 'bird_entries' currently holds {collection.count()} entries.")

  from tqdm.autonotebook import tqdm, trange


model.safetensors:   0%|          | 10.5M/6.17G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# Read data: a JSON object mapping each entry key (e.g. a common name such as
# "Red Knot", also used as the ChromaDB ID) to a dict that must provide
# "binomialName", "identification", "macaulayID" and "url".
with open("./source/ebird_data.json", "r", encoding="UTF-8") as f:
    entries = json.load(f)

# Process data in batches to bound peak memory while embedding
batch_size = 100  # Adjust this value based on your available memory
total_batches = (len(entries) + batch_size - 1) // batch_size

# Materialize the items once; rebuilding this list inside the loop cost O(n) per batch.
entry_items = list(entries.items())

# One outer progress bar tracks batches; the per-batch encoder bar is disabled
# so it does not print a new nested bar for every batch.
with tqdm(total=total_batches, desc="Processing batches") as pbar:
    for i in range(0, len(entry_items), batch_size):
        batch_entries = entry_items[i : i + batch_size]

        # Build the document text, ID and metadata for every entry in this batch
        batch_texts = []
        batch_ids = []
        batch_metadata = []
        for key, entry_info in batch_entries:
            entry_text = f"{key} ({entry_info['binomialName']}), is {entry_info['identification']}"
            batch_texts.append(entry_text)
            batch_ids.append(key)
            batch_metadata.append(
                {
                    "binomialName": entry_info["binomialName"],
                    "macaulayID": entry_info["macaulayID"],
                    "url": entry_info["url"],
                }
            )

        # Documents are embedded without a prompt; only queries use one (see match()).
        batch_embeddings = model.encode(batch_texts, show_progress_bar=False)

        # upsert (rather than add) keeps the cell idempotent: re-running it updates
        # existing IDs instead of tripping add()'s duplicate-ID handling.
        collection.upsert(
            embeddings=batch_embeddings.tolist(),  # ChromaDB requires list format
            documents=batch_texts,
            ids=batch_ids,
            metadatas=batch_metadata,
        )

        pbar.update(1)

In [None]:
# Query function
def match(query, top_k=3, prompt_name="s2p_query"):
    """Return the ``top_k`` entries in ``collection`` closest to ``query``.

    The query text is embedded with the module-level ``model`` using
    ``prompt_name`` and searched against the module-level ``collection``.

    Args:
        query: Free-text description to search for.
        top_k: Maximum number of results to request from ChromaDB.
        prompt_name: Prompt name forwarded to ``model.encode`` for the query
            (defaults to ``"s2p_query"``, the prompt this notebook was built with).

    Returns:
        ``(similarities, macaulayIDs)``: two dicts keyed by entry ID, in the
        order ChromaDB returned the hits. ``similarities`` holds ``1 - distance``
        (cosine similarity, assuming the collection uses
        ``"hnsw:space": "cosine"``); ``macaulayIDs`` holds each hit's
        ``"macaulayID"`` metadata value.
    """
    query_embedding = model.encode(query, prompt_name=prompt_name)

    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=top_k,
    )

    # Exactly one query embedding was sent, so each result field has one inner list.
    hit_ids = results["ids"][0]
    hit_distances = results["distances"][0]
    hit_metadatas = results["metadatas"][0]

    similarities = {}
    macaulayIDs = {}
    for entry_id, distance, metadata in zip(hit_ids, hit_distances, hit_metadatas):
        similarities[entry_id] = 1 - distance
        macaulayIDs[entry_id] = metadata["macaulayID"]

    return similarities, macaulayIDs

# Test query: a one-word colour should surface species with that colour in their name
top_n_similarities, top_n_macaulayIDs = match("Red.")
# Distinct labels so the two printed dicts can be told apart in the output
print("\nQuery results (similarity):", top_n_similarities)
print("\nQuery results (Macaulay ID):", top_n_macaulayIDs)


Query results: {'Red Phalarope': 0.5217095475236465, 'Red Knot': 0.4956471920013428, 'Red-footed Booby': 0.4898524284362793}

Query results: {'Red Phalarope': '107267571', 'Red Knot': '27328091', 'Red-footed Booby': '243885681'}
