The purpose of this notebook is to evaluate the static embedding model `minishlab/potion-retrieval-32M`.

In [1]:
from model2vec import StaticModel
from sentence_transformers import SentenceTransformer
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [2]:
# Persona to evaluate; it is looked up in PROFILE_FILENAMES below.
persona_name = "creative_hobyist"
# persona_name = "chidam"

# Embedding model under test; swap the comment to run the other model.
# MODEL_NAME = "Xenova/all-MiniLM-L6-v2"
MODEL_NAME = "minishlab/potion-retrieval-32M"

# Columns selected from the golden-query CSV.
GOLDEN_QUERY_COLS = ["search_query", "url"]
# Columns selected from the profile tables ("url_hash" is currently left out).
PROFILE_COLS = [
    "url",
    "title",
    "description",
    "frecency",
    "last_visit_date",
]

# Per-persona input files: the golden-query CSV and the profile CSV.
PROFILE_FILENAMES = {
    persona: {"golden_query_data": golden_csv, "profile_data": profile_csv}
    for persona, golden_csv, profile_csv in (
        (
            "creative_hobyist",
            "../data/creative_hobyist_profile_golden_queries.csv",
            "../data/profiles/creative_hobbyist.csv",
        ),
        (
            "chidam",
            "../data/chidam_golden_query.csv",
            "../data/history_output_file.csv",
        ),
    )
}
# PROFILE_FILENAME = "../data/chidam_golden_query.csv"


In [3]:
def get_model(MODEL_NAME):
    """Load the embedding model identified by ``MODEL_NAME``.

    Args:
        MODEL_NAME: Either ``"Xenova/all-MiniLM-L6-v2"`` (loaded through
            ``SentenceTransformer`` as ``"all-MiniLM-L6-v2"``) or
            ``"minishlab/potion-retrieval-32M"`` (loaded through
            ``StaticModel.from_pretrained``).

    Returns:
        The loaded model object.

    Raises:
        ValueError: If ``MODEL_NAME`` is not one of the supported names.
    """
    if MODEL_NAME == "Xenova/all-MiniLM-L6-v2":
        return SentenceTransformer("all-MiniLM-L6-v2")
    if MODEL_NAME == "minishlab/potion-retrieval-32M":
        return StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
    # Fail loudly here instead of returning None, which would only surface
    # later as an AttributeError on `model.encode(...)`.
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected "
        "'Xenova/all-MiniLM-L6-v2' or 'minishlab/potion-retrieval-32M'"
    )

def get_profile_filenames(persona_name):
    """Load the golden-query and profile tables for a persona.

    Args:
        persona_name: Key into the module-level ``PROFILE_FILENAMES`` mapping.

    Returns:
        A ``(golden_queries_df, profile_df)`` tuple. The golden-query frame is
        restricted to ``GOLDEN_QUERY_COLS``; the profile frame keeps every
        column of its CSV.

    Raises:
        ValueError: If ``persona_name`` is not a key of ``PROFILE_FILENAMES``.
    """
    if persona_name not in PROFILE_FILENAMES:
        # Must be raised, not returned: a returned exception object would be
        # tuple-unpacked by the caller and fail with an unrelated error.
        raise ValueError(f"persona_name entered not in {list(PROFILE_FILENAMES.keys())}")
    persona_files = PROFILE_FILENAMES[persona_name]
    golden_queries_df = pd.read_csv(persona_files['golden_query_data'])[GOLDEN_QUERY_COLS]
    profile_df = pd.read_csv(persona_files["profile_data"])
    #[PROFILE_COLS]
    return golden_queries_df, profile_df

In [4]:
# Instantiate the embedding model selected by MODEL_NAME.
model = get_model(MODEL_NAME)

In [5]:
# Load the golden (search_query, url) pairs and the profile table for the selected persona.
golden_queries_df, profile_df = get_profile_filenames(persona_name)
golden_queries_df

Unnamed: 0,search_query,url
0,Best mosquito repellent outdoor,http://www.wikihow.com/Get-Rid-of-Mosquitoes-i...
1,How to get rid of heat stains on wood,http://tipnut.com/diy-how-to-remove-white-heat...
2,Zucchini recipes,http://startcooking.com/how-to-zucchini
3,Planting grass,http://www.wikihow.com/Grow-Pampas-Grass
4,Hamstring remedies,http://www.runnersworld.com/for-beginners-only...
5,Saving crickets,https://www.wikihow.com/Keep-Crickets-Alive
6,Organic fertilizers,http://www.sunset.com/garden/garden-basics/cra...
7,Whats a choir,https://simple.wikipedia.org/wiki/Choir_(music)
8,Make money from music,http://hiphopmakers.com/how-to-sell-beats-onli...
9,Chlorine for pool,https://www.inyopools.com/blog/what-type-of-po...


In [6]:
# Display the profile table that the golden queries are searched against.
profile_df

Unnamed: 0.1,Unnamed: 0,url,title,description,frecency,url_hash,last_visit_date
0,0,http://www.ehow.com/how_5409995_cook-fennel.html,How to Cook Fennel,"A decidedly odd-looking vegetable, fennel rese...",5000,41483952323940,1738272723614003
1,1,http://www.phschool.com/language_arts/,Language Arts,Language Arts Student Resources Textbook Compa...,5000,38606764313101,1738272733950727
2,2,https://www.8notes.com/biographies/adams.asp,Items to buy by Bryan Adams,Items to buy by Bryan Adams (Everything I Do) ...,4999,30122553772281,1738272724722098
3,3,http://www.wikihow.com/Get-Rid-of-Mosquitoes-i...,How to Get Rid of Mosquitoes in Your Yard,1 Drain any areas with standing water. Mosquit...,4998,54238596538326,1738272735386823
4,4,http://www.goddessgift.com/goddess-myths/godde...,Goddess Symbols: Hestia,Goddess Symbols: Hestia Goddess Symbols and Sa...,4997,79098488697285,1738272735282816
...,...,...,...,...,...,...,...
6247,6247,http://www.cnn.com/2008/TRAVEL/traveltips/06/2...,How to get airport lounge discounts,By Andrea Bennett (Travel + Leisure) -- The fo...,107,69832714863365,1738272729266412
6248,6248,https://www.diabetesselfmanagement.com/managin...,What Makes Blood Glucose Go Up or Down?,What Makes Blood Glucose Go Up or Down? Update...,106,31807037351241,1738272724750099
6249,6249,https://fit4less.ca/facts,Fit4Less Facts,Fit4Less Facts Membership Card You must have y...,105,25780086107407,1738272735062801
6250,6250,https://blog.onestoppoppyshoppe.com/articles/p...,One Stop Poppy Shoppe Blog,Poppy Flower Seeds – Germinating and Growing P...,103,39631050424751,1738272734902791


In [7]:
%%time
# Embed the lower-cased page titles only (descriptions are not used here); timed via the cell magic above.
profile_embeddings = model.encode(profile_df['title'].str.lower().values.tolist())

CPU times: user 232 ms, sys: 32.6 ms, total: 265 ms
Wall time: 95.1 ms


In [8]:


def search(query, model, profile_embeddings, profile_df, top_k=2, threshold=0.6):
    """Return the URLs of the profile rows most similar to ``query``.

    Args:
        query: Query text; it is embedded with ``model.encode([query])``.
        model: Object exposing ``encode``.
        profile_embeddings: Embeddings compared against the query embedding
            with ``cosine_similarity``; row ``i`` is looked up as
            ``profile_df.iloc[i]``.
        profile_df: DataFrame with a ``'url'`` column.
        top_k: Maximum number of URLs to return.
        threshold: Minimum cosine similarity a row needs to be considered.

    Returns:
        Up to ``top_k`` URLs ordered by decreasing similarity, or ``[]`` when
        no row reaches ``threshold``.
    """
    query_vec = model.encode([query])
    scores = cosine_similarity(query_vec, profile_embeddings)[0]  # one score per profile row

    # Row positions whose similarity clears the threshold.
    candidates = np.flatnonzero(scores >= threshold)
    if candidates.size > 0:
        # Sort the surviving rows by descending score and keep the best top_k.
        best_first = np.argsort(-scores[candidates])[:top_k]
        return profile_df.iloc[candidates[best_first]]['url'].tolist()
    return []


In [9]:
# search("Best mosquito repellent outdoor", model, profile_embeddings, profile_df, top_k=5)
# search("Zucchini recipes", model, profile_embeddings, profile_df, top_k=2)

In [10]:
# Count the golden queries whose expected URL is among the top-3 retrieved URLs
# (similarity threshold 0.5).
correct = 0
for idx, row in golden_queries_df.iterrows():
    query_text = row['search_query'].lower()
    retrieved_urls = search(query_text, model, profile_embeddings, profile_df, top_k=3, threshold=0.5)
    correct += int(row['url'] in retrieved_urls)
    # Uncomment to inspect the misses:
    # if row['url'] not in retrieved_urls:
    #     print(idx, row['search_query'], row['url'], retrieved_urls)
        


In [11]:
# Number of golden queries whose expected URL was retrieved.
correct

30

In [12]:
# Fraction of golden queries whose expected URL was retrieved.
correct/len(golden_queries_df)

0.7692307692307693

In [13]:
## Results for creative_hobyist (39 queries)
## Xenova/all-MiniLM-L6-v2: 33 correct (0.85)
## minishlab/potion-retrieval-32M: 30 correct (0.77)

In [14]:
## Results for chidam (48 queries)
## Xenova/all-MiniLM-L6-v2: 17 correct (0.35)
## minishlab/potion-retrieval-32M: 20 correct (0.40)

#### Using long & short Q datasets and profiles

In [15]:
# Long, sentence-style test queries.
longQ = [
    "alternative medicine to manage ringing in my ear",
    "solutions for noisy neighbors",
    "how to train a cat to use the litter box",
    "how to make your dog loss weight",
    "what are the best foods for dogs",
    "how to keep your kids safe online",
    "some recommendations for vegan recipes for kids",
    "best places to visit in new york in summer",
    "how much does it cost to go to niagara falls",
    "how to take screenshot on windows 11",
    "what are some of the symptoms of high blood pressure",
    "what are some treatment options for stroke",
    "organic dog food for sensitive stomachs",
    "best budget-friendly family vacation spots in California",
    "what are some popular open source projects",
]

# Short, keyword-style test queries.
shortQ = [
    "healthy rice",
    "endodontist",
    "arsenal",
    "gavaskar border trophy",
    "grand slam champion",
    "capital gain",
    "credit limit",
    "annuity",
    "physiology",
    "myopia",
    "amyloidosis",
    "alopecia",
    "dermatitis",
    "coffee shop",
    "ball drop",
]

In [16]:
# Sanity check: number of queries in each list.
len(longQ), len(shortQ)

(15, 15)

In [17]:
# JSON profiles used with the long queries (longD / shortD variants).
lonQ_longD_file_path, lonQ_shortD_file_path = (
    "../data/profiles/test/profile_longQ_longD_df.json",
    "../data/profiles/test/profile_longQ_shortD_df.json",
)

# JSON profiles used with the short queries (longD / shortD variants).
shortQ_longD_file_path, shortQ_shortD_file_path = (
    "../data/profiles/test/profile_shortQ_longD_df.json",
    "../data/profiles/test/profile_shortQ_shortD_df.json",
)

In [18]:
# Read the four test profiles, keeping only PROFILE_COLS from each file.
lonQ_longD_df, lonQ_shortD_df, shortQ_longD_df, shortQ_shortD_df = (
    pd.read_json(profile_path)[PROFILE_COLS]
    for profile_path in (
        lonQ_longD_file_path,
        lonQ_shortD_file_path,
        shortQ_longD_file_path,
        shortQ_shortD_file_path,
    )
)

# Show the (rows, columns) shape of each profile.
lonQ_longD_df.shape, lonQ_shortD_df.shape, shortQ_longD_df.shape, shortQ_shortD_df.shape

((8743, 5), (7411, 5), (8717, 5), (7444, 5))

In [19]:
def search_long_short_queries(query, model, profile_embeddings, profile_df, top_k=2, threshold=0.6, max_relevant_idx=30):
    """Report whether ``query`` matches any of the leading profile rows.

    The query counts as a hit when at least one profile row whose positional
    index is ``<= max_relevant_idx`` has cosine similarity ``>= threshold``
    with the query embedding. Ranking is not considered.

    NOTE(review): this presumably relies on the relevant documents being
    stored in the first rows of each test profile -- confirm against the code
    that builds those profiles.

    Args:
        query: Query text; it is embedded with ``model.encode([query])``.
        model: Object exposing ``encode``.
        profile_embeddings: Embeddings compared against the query embedding
            with ``cosine_similarity``.
        profile_df: Unused; accepted so the signature mirrors ``search``.
        top_k: Unused; accepted so the signature mirrors ``search``.
        threshold: Minimum cosine similarity for a row to count as a match.
        max_relevant_idx: Largest row position (inclusive) regarded as a
            relevant document. Defaults to the previously hard-coded 30.

    Returns:
        ``True`` on a hit, ``False`` otherwise. Always a plain ``bool``; the
        no-match case used to return ``[]``, which is equally falsy.
    """
    query_embeddings = model.encode([query])
    sim_scores = cosine_similarity(query_embeddings, profile_embeddings)[0]

    # Positions of every row that clears the similarity threshold.
    valid_idx = np.flatnonzero(sim_scores >= threshold)
    # np.any over an empty array is False, which covers the no-match case.
    return bool(np.any(valid_idx <= max_relevant_idx))


In [20]:
# Embed the lower-cased "title description" text of every row in the longQ/longD profile.
lonQ_longD_texts = lonQ_longD_df['title'].str.lower() + " " + lonQ_longD_df['description'].str.lower()
lonQ_longD_profile_embeddings = model.encode(lonQ_longD_texts)
print(lonQ_longD_profile_embeddings.shape)

# A long query is a hit when search_long_short_queries returns a truthy result.
correct_longQ_longD = sum(
    1
    for query in longQ
    if search_long_short_queries(query, model, lonQ_longD_profile_embeddings, lonQ_longD_df, top_k=2, threshold=0.6)
)

print("correct_longQ_longD", correct_longQ_longD)
print("correct percentage", round(correct_longQ_longD/len(longQ), 3))

(8743, 512)
correct_longQ_longD 10
correct percentage 0.667


In [21]:
# Embed the lower-cased "title description" text of every row in the longQ/shortD profile.
lonQ_shortD_texts = lonQ_shortD_df['title'].str.lower() + " " + lonQ_shortD_df['description'].str.lower()
lonQ_shortD_profile_embeddings = model.encode(lonQ_shortD_texts)
print(lonQ_shortD_profile_embeddings.shape)

# A long query is a hit when search_long_short_queries returns a truthy result.
correct_longQ_shortD = sum(
    1
    for query in longQ
    if search_long_short_queries(query, model, lonQ_shortD_profile_embeddings, lonQ_shortD_df, top_k=2, threshold=0.6)
)

print("correct_longQ_shortD", correct_longQ_shortD)
print("correct percentage", round(correct_longQ_shortD/len(longQ), 3))

(7411, 512)
correct_longQ_shortD 14
correct percentage 0.933


In [22]:
# Embed the lower-cased "title description" text of every row in the shortQ/longD profile.
shortQ_longD_texts = shortQ_longD_df['title'].str.lower() + " " + shortQ_longD_df['description'].str.lower()
shortQ_longD_profile_embeddings = model.encode(shortQ_longD_texts)
print(shortQ_longD_profile_embeddings.shape)

# A short query is a hit when search_long_short_queries returns a truthy result.
correct_shortQ_longD = sum(
    1
    for query in shortQ
    if search_long_short_queries(query, model, shortQ_longD_profile_embeddings, shortQ_longD_df, top_k=2, threshold=0.6)
)

print("correct_shortQ_longD", correct_shortQ_longD)
print("correct percentage", round(correct_shortQ_longD/len(shortQ), 3))

(8717, 512)
correct_shortQ_longD 10
correct percentage 0.667


In [23]:
# Embed the lower-cased "title description" text of every row in the shortQ/shortD profile.
shortQ_shortD_profile_embeddings = model.encode(shortQ_shortD_df['title'].str.lower() + " " + shortQ_shortD_df['description'].str.lower())
print(shortQ_shortD_profile_embeddings.shape)

# A short query is a hit when search_long_short_queries returns a truthy result.
correct_shortQ_shortD = 0
for query in shortQ:
    res = search_long_short_queries(query, model, shortQ_shortD_profile_embeddings, shortQ_shortD_df, top_k=2, threshold=0.6)
    if res:
        correct_shortQ_shortD += 1

# Label fixed: it used to print "correct_shortQ_longD" for the shortQ/shortD result.
print("correct_shortQ_shortD", correct_shortQ_shortD)
print("correct percentage", round(correct_shortQ_shortD/len(shortQ), 3))

(7444, 512)
correct_shortQ_longD 14
correct percentage 0.933


In [24]:
## store these in sqlite-vec DB

In [25]:
import numpy as np
import sqlite3
import sqlite_vec

from typing import List
import struct

In [26]:
def serialize_f32(vector: List[float]) -> bytes:
    """Pack a sequence of floats into raw bytes, one 4-byte ``struct`` "f" item per value."""
    layout = f"{len(vector)}f"
    return struct.pack(layout, *vector)

In [27]:
# Open an in-memory SQLite database; nothing is written to disk.
db = sqlite3.connect(":memory:")
# Extension loading is enabled only around the sqlite_vec.load call and
# switched off again immediately afterwards.
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# Sanity check: vec_version() only resolves if the extension was loaded.
sqlite_version, vec_version = db.execute(
    "select sqlite_version(), vec_version()"
).fetchone()
print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")


sqlite_version=3.50.1, vec_version=v0.1.6


In [28]:
# Shape of the embeddings about to be stored: (rows, embedding dimension).
lonQ_shortD_profile_embeddings.shape

(7411, 512)

In [29]:
# Table-name suffix for the profile being indexed.
persona = "lonQ_shortD"
EMBEDDING_SIZE = lonQ_shortD_profile_embeddings.shape[1]

# vec0 virtual table holding the float vector plus a bit-vector copy of the same size.
db.execute(f"CREATE VIRTUAL TABLE vec_items_{persona}_2 USING vec0(embedding float[{EMBEDDING_SIZE}], embedding_coarse bit[{EMBEDDING_SIZE}])")

with db:
    # rowid is the row position in the embeddings array; the same serialized
    # vector is stored as-is and, through vec_quantize_binary, in binary form.
    db.executemany(
        f"INSERT INTO vec_items_{persona}_2(rowid, embedding, embedding_coarse) VALUES (?, ?, vec_quantize_binary(?))",
        (
            (row_position, blob, blob)
            for row_position, blob in enumerate(map(serialize_f32, lonQ_shortD_profile_embeddings))
        ),
    )


In [30]:
# Show the long queries available for the lookups below.
longQ

['alternative medicine to manage ringing in my ear',
 'solutions for noisy neighbors',
 'how to train a cat to use the litter box',
 'how to make your dog loss weight',
 'what are the best foods for dogs',
 'how to keep your kids safe online',
 'some recommendations for vegan recipes for kids',
 'best places to visit in new york in summer',
 'how much does it cost to go to niagara falls',
 'how to take screenshot on windows 11',
 'what are some of the symptoms of high blood pressure',
 'what are some treatment options for stroke',
 'organic dog food for sensitive stomachs',
 'best budget-friendly family vacation spots in California',
 'what are some popular open source projects']

In [31]:
def predict_coarse(query, lonQ_shortD_df, top_k=2, coarse_k=200):
    """Retrieve profile rows for ``query`` with a coarse-then-rerank lookup.

    Stage 1 selects ``coarse_k`` candidates by matching the binary-quantized
    query against the ``embedding_coarse`` column. Stage 2 re-orders those
    candidates by ``vec_distance_cosine`` on the float ``embedding`` column
    and keeps the ``top_k`` closest.

    Uses the module-level ``model``, ``db``, ``persona`` and ``serialize_f32``.

    Args:
        query: Query text, embedded with ``model.encode(query)``.
        lonQ_shortD_df: DataFrame whose row positions equal the rowids stored
            in ``vec_items_{persona}_2``; results are looked up with ``.iloc``.
        top_k: Number of rows returned after reranking (default 2, the
            previously hard-coded limit).
        coarse_k: Number of candidates kept by the coarse stage (default 200,
            the previously hard-coded limit).

    Returns:
        The matching rows of ``lonQ_shortD_df``, closest first.
    """
    query_serialized_vec = serialize_f32(model.encode(query))

    # The limits are coerced to int before being formatted into the SQL text,
    # so only integer literals can end up in the statement.
    retrieved_results = db.execute(f"""
    with coarse_matches as (
      select
        rowid,
        embedding
      from vec_items_{persona}_2
      where embedding_coarse match vec_quantize_binary(:query_serialized_vec)
      order by distance
      limit {int(coarse_k)}
    )
    select
      rowid,
      vec_distance_cosine(embedding, :query_serialized_vec)
    from coarse_matches
    order by 2
    limit {int(top_k)};
    """, {"query_serialized_vec": query_serialized_vec}).fetchall()
    # Prints the raw (rowid, cosine distance) pairs, as before.
    print(retrieved_results)
    return lonQ_shortD_df.iloc[[rowid for rowid, _ in retrieved_results]]

In [32]:
# Try the coarse-then-rerank lookup on the first long query.
query = longQ[0]
print(query)
predict_coarse(query, lonQ_shortD_df)

alternative medicine to manage ringing in my ear
[(1, 0.33442482352256775), (0, 0.33442482352256775)]


Unnamed: 0,url,title,description,frecency,last_visit_date
1,http://www.rightdiagnosis.com/sym/ringing_in_e...,Ringing in ears,,281700,1736189753947959
0,http://symptomchecker.webmd.com/single-symptom...,Ringing in ears,,1421560,1736196807185413
