In [2]:
import mysql.connector
from mysql.connector import Error
from dotenv import load_dotenv
import os
import json
import os
from openai import OpenAI

# Load environment variables from the .env file two directories up *before*
# creating the OpenAI client: the client resolves its API key from the
# environment when it is constructed, so a key stored only in .env must
# already be loaded at that point. load_dotenv does not override variables
# that are already set, so this reordering is safe.
load_dotenv(dotenv_path='../../.env')

client = OpenAI()
EMBEDDING_MODEL = "text-embedding-3-small"

# Retrieve database credentials from environment variables
db_host = os.getenv('DB_HOST')
db_user = os.getenv('DB_USER')
db_password = os.getenv('DB_PASSWORD')

def create_embedding(data):
    """Return the embedding vector for ``data``.

    ``data`` is sent as a single-element batch to the embeddings endpoint of
    the module-level ``client`` using ``EMBEDDING_MODEL``; the embedding of
    the first (only) returned item is handed back.
    """
    result = client.embeddings.create(input=[data], model=EMBEDDING_MODEL)
    return result.data[0].embedding

class DatabaseConnection:
    """Thin wrapper around a single ``mysql.connector`` connection.

    The connection is opened on construction from the module-level
    ``db_host``, ``db_user`` and ``db_password`` values. Methods report
    ``mysql.connector.Error`` failures by printing them and returning an
    empty / ``None`` result instead of raising.
    """

    def __init__(self):
        self.connection = None
        self.connect()

    def connect(self):
        """Open the connection; on failure print the error and leave ``self.connection`` as ``None``."""
        try:
            self.connection = mysql.connector.connect(
                host=db_host,
                user=db_user,
                password=db_password
            )
            if self.connection.is_connected():
                print("Connected to MySQL server")
        except Error as e:
            print(f"Error: {e}")
            self.connection = None

    def close(self):
        """Close the connection if one is currently open."""
        if self.connection and self.connection.is_connected():
            self.connection.close()
            print("MySQL connection is closed")

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a MySQL identifier, doubling any embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    def _fetch_rows(self, statements):
        """Execute ``statements`` in order on one cursor and return the rows of the last.

        Returns ``None`` when there is no active connection or when a
        ``mysql.connector.Error`` is raised (the error is printed).
        """
        if not self.connection or not self.connection.is_connected():
            print("No active MySQL connection")
            return None

        # Initialised up front so the ``finally`` clause is safe even when
        # ``cursor()`` itself raises; otherwise an UnboundLocalError would
        # replace the handled failure.
        cursor = None
        try:
            cursor = self.connection.cursor()
            for statement in statements:
                cursor.execute(statement)
            return cursor.fetchall()
        except Error as e:
            print(f"Error: {e}")
            return None
        finally:
            if cursor is not None:
                cursor.close()

    def get_databases(self):
        """Return the names of all databases on the server, or ``[]`` on failure."""
        rows = self._fetch_rows(["SHOW DATABASES"])
        return [] if rows is None else [row[0] for row in rows]

    def get_tables(self, db_name):
        """Return the table names in ``db_name``, or ``[]`` on failure.

        Note: issues ``USE``, so the connection's default database changes.
        """
        rows = self._fetch_rows([
            f"USE {self._quote_identifier(db_name)}",
            "SHOW TABLES",
        ])
        return [] if rows is None else [row[0] for row in rows]

    def describe_table(self, db_name, table_name):
        """Return the ``DESCRIBE`` rows for ``table_name`` in ``db_name``, or ``None`` on failure.

        Note: issues ``USE``, so the connection's default database changes.
        Identifiers are backtick-quoted because they cannot be sent as query
        parameters; pass bare names, not pre-quoted or dotted ones.
        """
        return self._fetch_rows([
            f"USE {self._quote_identifier(db_name)}",
            f"DESCRIBE {self._quote_identifier(table_name)}",
        ])

# (table name, embedding vector) pairs, one per successfully described table.
embeddings = []

# Keys for the six values of each DESCRIBE row, in result order.
DESCRIBE_COLUMNS = ("Field", "Type", "Null", "Key", "Default", "Extra")

db_conn = DatabaseConnection()
tables = ['job', 'job_type_config', 'shedlock']
try:
    for table in tables:
        schema = db_conn.describe_table('hyperface_platform_dev', table)
        if not schema:
            # describe_table returns None on failure; nothing to embed.
            continue

        table_structure = {
            # NOTE(review): the schema is read from 'hyperface_platform_dev'
            # but labelled 'jetfire' in the embedded text - confirm this
            # mismatch is intentional. Left unchanged so that regenerated
            # embeddings match the ones already stored.
            "database": 'jetfire',
            "table": table,
            "schema": [dict(zip(DESCRIBE_COLUMNS, column)) for column in schema],
        }

        # The pretty-printed JSON description of the table is the text that gets embedded.
        table_structure_json = json.dumps(table_structure, indent=4)
        embedding = create_embedding(table_structure_json)
        embeddings.append((table, embedding))
finally:
    # Release the connection even if an embedding request raises.
    db_conn.close()

Connected to MySQL server
MySQL connection is closed


In [43]:
# saving to avoid API call
import pickle

# Open in write mode, not append: pickle.load() reads only the first object in
# a file, so appending on a re-run would leave the stale embeddings in front
# of the new ones. The context manager closes the file, which guarantees the
# data is flushed to disk before another cell reads it back.
with open('pickles/json_format.pickle', 'wb') as pickle_file:
    pickle.dump(embeddings, pickle_file)


In [6]:
# run this cell if file already exists
import pickle

from openai import OpenAI

client = OpenAI()
EMBEDDING_MODEL = "text-embedding-3-small"

# NOTE: unpickling can execute arbitrary code - only load a file this notebook
# wrote itself. The context manager closes the file handle after reading.
with open('pickles/json_format.pickle', 'rb') as pickle_file:
    embeds = pickle.load(pickle_file)

def create_embedding(data):
    """Return the embedding vector for ``data``.

    ``data`` is sent as a single-element batch to the embeddings endpoint of
    the module-level ``client`` using ``EMBEDDING_MODEL``; the embedding of
    the first (only) returned item is handed back.
    """
    result = client.embeddings.create(input=[data], model=EMBEDDING_MODEL)
    return result.data[0].embedding

In [45]:
# Natural-language question to compare against the table-schema embeddings.
query = 'What is the status of the most recently created job?'
# Embedded through the same create_embedding helper (and EMBEDDING_MODEL) as
# the schemas, so the resulting vectors are directly comparable.
query_embed = create_embedding(query)

In [50]:
from scipy import spatial


def relatedness_fn(x, y):
    """Cosine similarity of two vectors (1 minus SciPy's cosine distance)."""
    return 1 - spatial.distance.cosine(x, y)


# Score every stored table embedding against the query, highest similarity first.
score = sorted(
    [(table, relatedness_fn(query_embed, embedding)) for table, embedding in embeds],
    key=lambda pair: pair[1],
    reverse=True,
)
score

[('job', 0.3808369628796605),
 ('job_type_config', 0.32829081797517456),
 ('shedlock', 0.23842665254553452)]

In [5]:
from scipy import spatial


def relatedness_fn(x, y):
    """Cosine similarity of two vectors (1 minus SciPy's cosine distance)."""
    return 1 - spatial.distance.cosine(x, y)


def rank_tables(query_embedding, table_embeddings):
    """Rank tables by similarity to a query embedding.

    Args:
        query_embedding: embedding vector of the question.
        table_embeddings: iterable of ``(table_name, embedding)`` pairs.

    Returns:
        List of ``(table_name, similarity)`` tuples, highest similarity first.
    """
    ranked = [
        (table, relatedness_fn(query_embedding, embedding))
        for table, embedding in table_embeddings
    ]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked


query = 'What is the batch size of the most recently created job?'
query_embed = create_embedding(query)

score = rank_tables(query_embed, embeds)
score

[('job_type_config', 0.35014549987500754),
 ('job', 0.32311440244171874),
 ('shedlock', 0.18835632415872738)]