In [1]:
from tqdm.autonotebook import tqdm, trange
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import json
import os

# Local cache for Hugging Face model weights. Override with the HF_MODELS_DIR
# environment variable; "~" is expanded, so the default points into the current
# user's home directory rather than a hard-coded /Users/<name> path.
cache_folder = os.path.expanduser(
    os.environ.get("HF_MODELS_DIR", "~/Code/huggingface_models")
)

# Sentence-embedding model used for both the species passages and the queries.
# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repository; consider pinning `revision=` to a reviewed commit.
model = SentenceTransformer(
    "dunzhang/stella_en_1.5B_v5",
    cache_folder=cache_folder,
    local_files_only=False,
    trust_remote_code=True,
)

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Load the eBird species data and embed one text passage per species.
# NOTE(review): each value is assumed to be a dict with at least the
# "binomialName", "identification", "url" and "macaulayID" keys read in this
# notebook — confirm against the data export.
with open("ebird_data.json", "r", encoding="UTF-8") as f:
    entries = json.load(f)

# Passage format: "<common name>, also called <binomial name>. <identification>".
# The space after "called" matters: without it the two names fuse into a single
# token in the text that gets embedded.
entryList = [
    f"{name}, also called {info['binomialName']}. {info['identification']}"
    for name, info in entries.items()
]

# Row i of entry_embeddings corresponds to the i-th key of `entries`;
# match() relies on this ordering to map result indices back to species names.
entry_embeddings = model.encode(entryList)

In [3]:
# Semantic search over the species embeddings
def match(query, topk=3):
    """Return the species whose descriptions are most similar to ``query``.

    Args:
        query: Free-text description of a bird (a single string).
        topk: Maximum number of matches to return (default 3). Clamped to the
            number of species in ``entries`` so ``torch.topk`` does not raise
            on small datasets; ``topk <= 0`` returns an empty dict.

    Returns:
        dict mapping species name (a key of ``entries``) to its similarity
        score as a float, ordered from most to least similar.

    Uses the notebook globals ``model``, ``entries`` and ``entry_embeddings``;
    row i of ``entry_embeddings`` must correspond to the i-th key of ``entries``.
    """
    k = min(topk, len(entries))
    if k <= 0:
        return {}

    query_prompt_name = "s2p_query"

    query_embeddings = model.encode(query, prompt_name=query_prompt_name)
    # Presumably (1, n_species) for a single query string.
    similarities = model.similarity(query_embeddings, entry_embeddings)

    # torch.topk returns values sorted largest-first by default.
    top_n_values, top_n_index = torch.topk(similarities, k)
    species_names = list(entries.keys())

    top_n_similarities = {}
    for value, index in zip(
        top_n_values.flatten().tolist(), top_n_index.flatten().tolist()
    ):
        top_n_similarities[species_names[index]] = value

    return top_n_similarities

In [4]:
# Quick sanity check: print the best-matching species for a one-word query,
# with each species' binomial name and eBird URL.
top_n_similarities = match("blue")

for species, score in top_n_similarities.items():
    info = entries[species]
    print(f"Matched: {species}, Similarity: {score:.4f}")
    print("Detail:", info["binomialName"], info["url"])

Matched: Blue-and-white Flycatcher, Similarity: 0.5152
Detail: Cyanoptila cyanomelana https://ebird.org/species/bawfly2/JP-13
Matched: Blue Rock-Thrush, Similarity: 0.5012
Detail: Monticola solitarius https://ebird.org/species/burthr/JP-13
Matched: Red-flanked Bluetail, Similarity: 0.4920
Detail: Tarsiger cyanurus https://ebird.org/species/refblu1/JP-13


In [7]:
from dash import Dash, dcc, html, Input, Output, State, callback

external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

app = Dash(__name__, title="RAG-ebird", external_stylesheets=external_stylesheets)

# Page layout: a query box, a Search button, and a results row that the
# update_output callback fills with one Macaulay Library iframe per match.
# Style keys use the camelCase names Dash documents for the `style` prop, and
# the results-row height carries an explicit unit: a unitless string such as
# "540" is not a valid CSS length, so the browser drops it.
app.layout = html.Div(
    [
        dcc.Input(
            id="input-text-state",
            type="text",
            value="A blue bird next to water.",
            style={"width": "1040px", "margin": "20px"},
        ),
        html.Button(
            id="submit-button-state",
            n_clicks=0,
            children="Search",
            style={"width": "200px", "marginBottom": "30px"},
        ),
        dcc.Loading(
            id="loading",
            type="default",
            children=html.Div(
                id="output-state", style={"display": "flex", "height": "540px"}
            ),
        ),
    ],
    style={"display": "flex", "flexDirection": "column", "alignItems": "center"},
)


def return_iframe(macaulayID, height=500, width=320):
    """Build an embedded Macaulay Library player for one media asset.

    Args:
        macaulayID: Macaulay Library asset ID. It is formatted into the URL, so
            both string and integer IDs from the JSON are accepted.
        height: Iframe height in pixels (default 500).
        width: Iframe width in pixels (default 320).

    Returns:
        An ``html.Iframe`` whose ``src`` is
        ``https://macaulaylibrary.org/asset/<macaulayID>/embed``, borderless and
        allowed to go fullscreen.
    """
    macaulayLink = f"https://macaulaylibrary.org/asset/{macaulayID}/embed"

    return html.Iframe(
        src=macaulayLink,
        height=height,
        width=width,
        style={"border": "none"},
        allow="fullscreen",
    )


@callback(
    Output("output-state", "children"),
    Input("submit-button-state", "n_clicks"),
    State("input-text-state", "value"),
)
def update_output(n_clicks, input_text):
    """Run a search for the text box contents and render the top matches.

    The Search button (``n_clicks``) is the only trigger. The query is read as
    State, so typing alone does not start a search. By default Dash also calls
    this once at page load, which shows results for the pre-filled query.

    Args:
        n_clicks: Button click count; used only as the trigger.
        input_text: Current text box value (may be None or blank).

    Returns:
        A list of ``html.Div`` cards, one per match returned by ``match``, each
        showing the similarity score above that species' Macaulay Library
        iframe. Returns an empty list when the query is blank.
    """
    # Nothing to search for: skip the model call rather than embedding "" / None.
    if not input_text or not input_text.strip():
        return []

    card_style = {
        "display": "flex",
        "flexDirection": "column",
        "margin": "20px",
        "fontFamily": "jura",
        "fontWeight": "bold",
    }

    return [
        html.Div(
            [
                f"Similarity: {similarity:.4f}",
                return_iframe(entries[key]["macaulayID"]),
            ],
            style=card_style,
        )
        for key, similarity in match(input_text).items()
    ]


# Serve the Dash app. Inside a Jupyter kernel __name__ is also "__main__",
# so executing this cell starts the app from the notebook session too.
_run_as_script = __name__ == "__main__"
if _run_as_script:
    app.run(debug=False)