In [None]:
import os
from urllib.parse import quote

import numpy as np
import pandas as pd
import vecs

In [None]:
# Create the Supabase (Postgres) connection used by the `vecs` client.
# SUPABASE_URL / SUPABASE_KEY are read here but not used in this cell.
url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
db_pass = os.environ.get("DATABASE_PASSWORD")
if not db_pass:
    # Fail with a clear message instead of a TypeError from str.replace(..., None).
    raise RuntimeError("DATABASE_PASSWORD environment variable is not set or empty")

# Percent-encode the password so characters such as '@', ':', '/' or '#'
# cannot be misread as URL delimiters in the connection string.
db_connection = (
    f"postgresql://postgres:{quote(db_pass, safe='')}"
    "@db.pedbaridbklowihouaqa.supabase.co:5432/postgres"
)

vx = vecs.create_client(db_connection)
# dimension=384 -- presumably matches the stored embeddings and the
# all-MiniLM-L6-v2 query model used later; confirm if either changes.
docs = vx.get_or_create_collection(name="movies", dimension=384)
docs

In [None]:
# Load the precomputed plot embeddings from the sibling `data` directory
# and display their shape as the cell output.
EMBEDDINGS_PATH = os.path.join(os.pardir, "data", "embeddings.npy")
embeddings = np.load(file=EMBEDDINGS_PATH)
np.shape(embeddings)

In [None]:
# Read the Wikipedia movie dataset (relative to the working directory) and
# keep only the columns used downstream: title, page link, and plot text.
DATA_PATH = "wiki_movies.csv"
data = pd.read_csv(filepath_or_buffer=DATA_PATH)
df = data.loc[:, ["Title", "Wiki Page", "Plot"]]

In [None]:
# Number of rows whose title repeats an already-seen title
# (total rows minus distinct titles).
titles = df["Title"]
val = titles.size
val2 = len(set(titles))
val - val2

In [None]:
# Create one vecs record per movie row: (id, embedding vector, metadata).
# Assumes embeddings.npy rows follow the CSV row order -- TODO confirm against
# whatever produced the embeddings file.
if len(embeddings) != len(df):
    raise ValueError(
        f"embeddings has {len(embeddings)} rows but df has {len(df)}; "
        "they must align 1:1"
    )

records = [
    (
        str(idx),         # row label doubles as the record id
        embeddings[pos],  # positional lookup, independent of the index labels
        {"title": title, "wiki_page": wiki_page},
    )
    for pos, (idx, title, wiki_page) in enumerate(
        zip(df.index, df["Title"], df["Wiki Page"])
    )
]
len(records)

In [None]:
# Insert the records into the "movies" collection, updating any that share an id.
docs.upsert(records=records)

In [None]:
from sentence_transformers import SentenceTransformer


def query_db(queries, top_k):
    """Run a semantic search against the ``movies`` collection for each query.

    Args:
        queries: iterable of query strings; a single ``str`` is treated as
            one query rather than being iterated character by character.
        top_k: maximum number of matches to return per query, capped at
            ``len(df)``.

    Returns:
        dict mapping each query string to the list returned by ``docs.query``
        (called with ``include_metadata=True``).

    Side effects:
        Builds the collection index, then calls ``vx.disconnect()`` (even if
        a query fails), so the shared client cannot be reused afterwards --
        re-run the connection cell before calling again.
    """
    if isinstance(queries, str):
        queries = [queries]

    # Honour the caller's top_k (it was previously overwritten with 5).
    # TODO: cap by the collection's size instead of len(df).
    top_k = min(top_k, len(df))

    results = {}
    try:
        # index the collection for fast search performance
        docs.create_index()

        # Load the embedding model once rather than once per query. Presumably
        # this is the model that produced the stored embeddings -- confirm.
        embedder = SentenceTransformer('all-MiniLM-L6-v2')

        for query in queries:
            query_embedding = embedder.encode(query)
            results[query] = docs.query(
                data=query_embedding,            # embedding to search
                limit=top_k,                     # number of records to return
                filters={},                      # metadata filters -- none right now
                include_metadata=True,           # include metadata in results -- {title, wiki page}
            )
    finally:
        # disconnect from the database even if indexing or a query raised
        vx.disconnect()
    return results

In [None]:
# Example free-text plot descriptions to search the movie collection with.
queries = [
    "jack gets a beanstalk and a giant golden egg",
    "a guy shoots 100 guys",
    "child goes on magical adventure",
    "animated pirates fight over treasure",
]
top_k = 3  # matches to request per query
results = query_db(queries, top_k=top_k)

In [None]:
# Print the titles of the closest movies for every query.
for query in queries:
    query_results = results[query]

    # Report how many matches actually came back rather than assuming top_k,
    # since the search limit may be capped (e.g. at the dataset size).
    print(f"Top {len(query_results)} most similar movies in corpus:\n")
    print("===================================================")
    print("Query:", query, "\n")

    # Element 1 of each result is the metadata dict stored at upsert time.
    for result in query_results:
        print(result[1]["title"])
    print()  # blank line so consecutive queries don't run together