In [None]:
# Install the notebook's dependencies into the active kernel's environment.
%pip install -r requirements.txt

In [None]:
from pathlib import Path

# Clone: https://huggingface.co/datasets/lmsys/lmsys-chat-1m

data = Path("lmsys-chat-1m/data")
# Only pick up the parquet shards, and sort them: `Path.glob` yields entries in
# filesystem-dependent order, and the seeded `df.sample(...)` further down is
# only reproducible across machines if the input row order is fixed.
data_files = sorted(data.glob("*.parquet"))
data_files

In [None]:
import polars as pl

# Load every shard, keeping only conversations whose `language` is "English".
english_only = pl.col.language == "English"
df = pl.read_parquet(data_files).filter(english_only)
display(df.shape)
df.head(3)

In [None]:
# Work with a fixed random subset so everything downstream stays small and in-memory.
SAMPLE_SIZE = 100_000
SAMPLE_SEED = 42
# Clamp to the available rows: sampling more rows than exist (without
# replacement) raises, e.g. when only some data shards were downloaded.
df_sampled = df.sample(min(SAMPLE_SIZE, df.height), seed=SAMPLE_SEED)
df_sampled.shape

In [None]:
# Conversations in which any message mentions machine learning.
# `(?i)` makes the regex case-insensitive, so "Machine Learning" / "Machine learning"
# (sentence starts, titles) are matched as well as the lower-case form.
ML_PATTERN = r"(?i)machine learning"
df_exp_ml = (
    df_sampled.explode("conversation")  # one row per message
    .unnest("conversation")  # message struct fields (incl. `content`) -> columns
    .filter(pl.col.content.str.contains(ML_PATTERN))
)
print(df_exp_ml.shape)
# A conversation can match in several messages; de-duplicate IDs before the lookup.
ml_conversation_ids = df_exp_ml["conversation_id"].unique()
df_sampled_ml = df_sampled.filter(pl.col.conversation_id.is_in(ml_conversation_ids))
df_sampled_ml.head(3)

In [None]:
# Serialise each conversation to JSON text, used as the input to the embedding model.
# Per polars' `list.to_struct` docs, a list of field names is assigned by index, so
# only the first two messages become "query" and "response"; later turns are not
# included. NOTE(review): confirm that dropping later turns is intended.
opening_exchange = pl.col.conversation.list.to_struct(fields=["query", "response"])
df_sampled_ml_embed = df_sampled_ml.with_columns(
    opening_exchange.struct.json_encode().alias("conversation_json")
)
df_sampled_ml_embed.head(3)

In [None]:
# Set up the OpenAI client used for embedding the conversation JSONs.

from dotenv import load_dotenv
from openai import OpenAI

# Load variables from a local .env file into the process environment —
# presumably this supplies the OpenAI credentials (TODO confirm .env contents).
load_dotenv()

# Client instance for OpenAI API calls made from this notebook.
client = OpenAI()

In [None]:
from joblib import Parallel, delayed
from tqdm import tqdm


EMBEDDING_MODEL = "text-embedding-ada-002"


def get_embedding(text, model=EMBEDDING_MODEL, embed_client=None):
    """Embed ``text`` with the OpenAI embeddings endpoint.

    Args:
        text: Input string to embed (here: a JSON-encoded conversation).
        model: Embedding model name; defaults to ``text-embedding-ada-002``.
        embed_client: Optional ``OpenAI`` client to reuse. When omitted, a new
            client is built per call; that is required if this runs in a
            separate worker *process*, because the client is not picklable.

    Returns:
        The raw embeddings API response; the vector is at
        ``response.data[0].embedding``.
    """
    if embed_client is None:
        embed_client = OpenAI()
    return embed_client.embeddings.create(input=text, model=model)


conv_jsons = df_sampled_ml_embed["conversation_json"].to_list()
# The requests are network-bound, so a thread pool is enough: it avoids starting
# 50 worker processes, and because threads share memory the `client` created
# earlier can be reused directly (nothing has to be pickled).
# joblib returns results in input order, so they line up with `conv_jsons`.
# NOTE(review): one failed request aborts the whole run and discards completed
# responses; consider caching results to disk before re-running at scale.
embedding_responses = Parallel(n_jobs=50, prefer="threads")(
    delayed(get_embedding)(c, embed_client=client) for c in tqdm(conv_jsons)
)
embeddings = [e.data[0].embedding for e in embedding_responses]

In [None]:
# Attach the embedding vectors as a new column (one vector per row, in row order).
embedding_series = pl.Series("conv_embedding", embeddings)
df_sampled_ml_embedded = df_sampled_ml_embed.with_columns(embedding_series)
df_sampled_ml_embedded.head(1)

In [None]:
# Sanity check: dimensionality of the first stored embedding vector.
df_sampled_ml_embedded.get_column("conv_embedding")[0].shape

In [None]:
import numpy as np

import hashlib

# Prevent confusing integer overflow errors in polars/usearch downstream
i64_safe_max = np.iinfo(np.int64).max // 10


def hash_string(s):
    """Map a string to a stable, non-negative integer key below ``i64_safe_max``.

    Uses an 8-byte BLAKE2b digest instead of the built-in ``hash``:
    ``hash(str)`` is salted per interpreter process (PYTHONHASHSEED), so its
    values change on every kernel restart and keys written to disk could not
    be reproduced by a later run.

    Args:
        s: The string to hash (here: a conversation ID).

    Returns:
        An ``int`` in ``[0, i64_safe_max)``, identical across processes.
    """
    digest = hashlib.blake2b(s.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big") % i64_safe_max


# Derive an integer key per conversation; the search index below adds vectors by integer key.
# An explicit return dtype avoids polars' map_elements dtype-inference warning.
df_app = df_sampled_ml_embedded.with_columns(
    conv_id_hash=pl.col.conversation_id.map_elements(hash_string, return_dtype=pl.Int64)
)
# Keys must be unique, otherwise looking rows up by index key becomes ambiguous.
assert df_app["conv_id_hash"].n_unique() == df_app.height, "conv_id_hash collision"
df_app.head(1)

In [None]:
# Create a search index
import numpy as np
from usearch.index import Index

# Smoke-test the embeddings with a small usearch index: add the first 20
# conversations, then check that querying with row 0's own vector returns row 0.
index = Index(ndim=len(df_app["conv_embedding"][0]))

index_rows = df_app.select("conv_id_hash", "conv_embedding").head(20)
for key, vector in index_rows.iter_rows():
    index.add(key, np.array(vector))

example_query_row = df_app[0]
query_vector = example_query_row["conv_embedding"][0].to_numpy()
matches = index.search(query_vector, 10)

top_match = matches[0]
assert top_match.key == example_query_row["conv_id_hash"][0]
assert top_match.distance <= 0.001
top_match

In [None]:
# List the available columns before choosing which ones to export for the app.
df_app.columns

In [None]:
# Write out file for use by the app.
df_app.select(
    "conv_id_hash", "conversation_id", "conversation", "model", "conv_embedding"
).write_parquet("data/app_embeds.parquet")
!du -sh data/*.parquet