In [18]:
import pandas as pd
from openai import OpenAI
from scipy import spatial

# OpenAI API client; strings_ranked_by_relatedness uses it to embed each query.
# NOTE(review): constructed with no arguments, so credentials are presumably
# taken from the environment — confirm before sharing this notebook.
client = OpenAI()
# Model name sent with every embeddings request.
# NOTE(review): presumably the same model that produced the stored embeddings
# loaded below; relatedness scores are only meaningful if it is — confirm.
EMBEDDING_MODEL = "text-embedding-3-small"

In [19]:
# CSV file holding one row per text chunk together with its embedding.
SAVE_PATH = "embeddings/embedding_1.csv"
df = pd.read_csv(SAVE_PATH)
# NOTE(review): the rendered frame shows an "Unnamed: 0" column, which suggests
# the CSV was written together with its index; reading with index_col=0 would
# presumably drop it — confirm.
# The "embedding" column appears to be loaded as text (a JSON-style list per
# row) rather than as lists of floats; the ranking function parses it per row.
df

Unnamed: 0,text,embedding
0,mysql> desc job_type_config;\n+---------------...,"[-0.030408846214413643, 0.03918106481432915, 0..."
1,mysql> desc job;\n+----------------------+----...,"[-0.011904806829988956, 0.046049900352954865, ..."
2,The system offers the following services:\n- B...,"[-0.03950067237019539, 0.02617255598306656, 0...."


In [44]:
import json

def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100
) -> tuple[tuple[str, ...], tuple[float, ...]]:
    """Rank the texts in ``df`` by relatedness to ``query``, most related first.

    The query is embedded with ``EMBEDDING_MODEL`` through the module-level
    ``client``; every row's stored embedding is then scored against it.

    Args:
        query: Text to embed and compare against the stored embeddings.
        df: Frame with a ``"text"`` column and an ``"embedding"`` column. Each
            embedding may be a JSON-encoded string (as read back from CSV) or
            an already-parsed sequence of floats.
        relatedness_fn: Callable ``(query_embedding, row_embedding) -> float``
            where larger means more related. The default is one minus the
            cosine distance.
        top_n: Maximum number of results to return.

    Returns:
        A pair ``(strings, relatednesses)`` of equal-length tuples, ordered
        from most to least related and truncated to ``top_n`` entries. Both
        tuples are empty when ``df`` has no rows.
    """
    # With no rows, zip(*[]) yields nothing and the two-target unpacking below
    # would raise ValueError; return early, which also avoids a pointless
    # embeddings request.
    if len(df) == 0:
        return (), ()

    query_embedding_response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=query,
    )
    query_embedding = query_embedding_response.data[0].embedding

    # Walk the two needed columns directly instead of df.iterrows(), which
    # builds a Series per row and yields an index that was never used.
    strings_and_relatednesses = [
        (
            text,
            relatedness_fn(
                query_embedding,
                # Embeddings read from CSV are JSON text; pass parsed ones through.
                json.loads(embedding) if isinstance(embedding, str) else embedding,
            ),
        )
        for text, embedding in zip(df["text"], df["embedding"])
    ]

    # Stable descending sort: rows with equal scores keep their frame order.
    strings_and_relatednesses.sort(key=lambda pair: pair[1], reverse=True)
    strings, relatednesses = zip(*strings_and_relatednesses)
    return strings[:top_n], relatednesses[:top_n]

In [24]:
# Example question used to exercise the relatedness ranking.
query = 'What is the batch size for job type config with id job_config_a3vafgaw?'

In [45]:
# Rank the stored texts against the example question defined in `query`,
# rather than repeating the literal here where the two copies could drift apart.
strings, relatednesses = strings_ranked_by_relatedness(query, df)

In [46]:
# Texts returned by the ranking call, ordered from most to least related.
strings

('mysql> desc job_type_config;\n+--------------------------+-----------------+------+-----+---------+-------+\n| Field                    | Type            | Null | Key | Default | Extra |\n+--------------------------+-----------------+------+-----+---------+-------+\n| id                       | varchar(255)    | NO   | PRI | NULL    |       |\n| batch_size               | int             | YES  |     | NULL    |       |\n| created_at               | datetime        | NO   |     | NULL    |       |\n| delimiter                | varchar(255)    | YES  |     | NULL    |       |\n| file_decryption_password | varchar(255)    | YES  |     | NULL    |       |\n| header_list              | varbinary(2048) | YES  |     | NULL    |       |\n| header_record_number     | int             | YES  |     | NULL    |       |\n| is_file_encrypted        | bit(1)          | YES  |     | NULL    |       |\n| is_single_file_report    | bit(1)          | YES  |     | NULL    |       |\n| modified_at       

In [47]:
# Relatedness scores aligned position-by-position with `strings` (highest first).
relatednesses

(0.6069754620268071, 0.4106450860148787, 0.39569149214101906)