In [4]:
import mysql.connector
from mysql.connector import Error
from dotenv import load_dotenv
import os
import json
import os
from openai import OpenAI

client = OpenAI()
EMBEDDING_MODEL = "text-embedding-3-small"

# Load environment variables from .env file in parent directory
load_dotenv(dotenv_path='../../.env')

# Retrieve database credentials from environment variables
db_host = os.getenv('DB_HOST')
db_user = os.getenv('DB_USER')
db_password = os.getenv('DB_PASSWORD')

def create_embedding(data):

    response = client.embeddings.create(model=EMBEDDING_MODEL, input=[data])
    embedding = [e.embedding for e in response.data][0]

    return embedding

class DatabaseConnection:
    def __init__(self):
        self.connection = None
        self.connect()
    
    def connect(self):
        try:
            self.connection = mysql.connector.connect(
                host=db_host,
                user=db_user,
                password=db_password
            )
            if self.connection.is_connected():
                print("Connected to MySQL server")
        except Error as e:
            print(f"Error: {e}")
            self.connection = None
    
    def close(self):
        if self.connection and self.connection.is_connected():
            self.connection.close()
            print("MySQL connection is closed")

    def get_databases(self):
        if not self.connection or not self.connection.is_connected():
            print("No active MySQL connection")
            return []

        try:
            cursor = self.connection.cursor()
            cursor.execute("SHOW DATABASES")
            databases = cursor.fetchall()
            return [db[0] for db in databases]
        
        except Error as e:
            print(f"Error: {e}")
            return []
        
        finally:
            cursor.close()

    def get_tables(self, db_name):
        if not self.connection or not self.connection.is_connected():
            print("No active MySQL connection")
            return []

        try:
            cursor = self.connection.cursor()
            cursor.execute(f"USE {db_name}")
            cursor.execute("SHOW TABLES")
            tables = cursor.fetchall()
            return [table[0] for table in tables]
        
        except Error as e:
            print(f"Error: {e}")
            return []
        
        finally:
            cursor.close()

    def describe_table(self, db_name, table_name):
        if not self.connection or not self.connection.is_connected():
            print("No active MySQL connection")
            return None

        try:
            cursor = self.connection.cursor()
            cursor.execute(f"USE {db_name}")
            cursor.execute(f"DESCRIBE {table_name}")
            schema = cursor.fetchall()
            return schema
        
        except Error as e:
            print(f"Error: {e}")
            return None
        
        finally:
            cursor.close()

embeddings = []

db_conn = DatabaseConnection()
tables = ['job', 'job_type_config', 'shedlock']
for table in tables:
    schema = db_conn.describe_table('hyperface_platform_dev', table)

    for row in schema:
        # convert row into text description
        row_description = f"The table {table} has a column {row[0]} of type {row[1]}"
        if row[2] == 'YES':
            row_description += " that can be null"
        else:
            row_description += " that cannot be null"
        if row[3] == 'PRI':
            row_description += " and is a primary key"
        if row[4]:
            row_description += f" with a default value of {row[4]}"
        if row[5]:
            row_description += f" and has the extra attribute {row[5]}"

        print('create_embedding')
        embedding = create_embedding(row_description)
        embeddings.append((table, embedding))

db_conn.close()

Connected to MySQL server
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
create_embedding
MySQL connection is closed


In [10]:
# saving to avoid API call
import pickle
pickle.dump(embeddings, open('pickles/per_field_chunking.pickle', 'ab'))


In [15]:
# run this cell if file already exists
from openai import OpenAI

client = OpenAI()
EMBEDDING_MODEL = "text-embedding-3-small"

import pickle
embeds = pickle.load(open('pickles/per_field_chunking.pickle', 'rb'))

def create_embedding(data):

    response = client.embeddings.create(model=EMBEDDING_MODEL, input=[data])
    embedding = [e.embedding for e in response.data][0]

    return embedding

In [12]:
query = 'What is the status of the most recently created job?'
query_embed = create_embedding(query)

In [19]:
from scipy import spatial

relatedness_fn = lambda x, y: 1 - spatial.distance.cosine(x, y)
score = []

for table, embedding in embeds:
    score.append((table, relatedness_fn(query_embed, embedding)))

# group by table
max_score = {}
for table, s in score:
    if table not in max_score or s > max_score[table]:
        max_score[table] = s

score = [(k, v) for k, v in max_score.items()]

score.sort(key=lambda x: x[1], reverse=True)
score

[('job_type_config', 0.5175739728939045),
 ('job', 0.4127968242304896),
 ('shedlock', 0.3518127326604361)]

In [17]:
query = 'What is the batch size of the most recently created job?'
query_embed = create_embedding(query)

from scipy import spatial

relatedness_fn = lambda x, y: 1 - spatial.distance.cosine(x, y)
score = []

for table, embedding in embeds:
    score.append((table, relatedness_fn(query_embed, embedding)))

# group by table
max_score = {}
for table, s in score:
    if table not in max_score or s > max_score[table]:
        max_score[table] = s

score = [(k, v) for k, v in max_score.items()]

score.sort(key=lambda x: x[1], reverse=True)
score

[('job_type_config', 0.5779927323198352),
 ('job', 0.3325697380943399),
 ('shedlock', 0.18369545564856493)]

In [18]:
query = 'Is the most recently modified job encrypted?'
query_embed = create_embedding(query)

from scipy import spatial

relatedness_fn = lambda x, y: 1 - spatial.distance.cosine(x, y)
score = []

for table, embedding in embeds:
    score.append((table, relatedness_fn(query_embed, embedding)))

# group by table
max_score = {}
for table, s in score:
    if table not in max_score or s > max_score[table]:
        max_score[table] = s

score = [(k, v) for k, v in max_score.items()]

score.sort(key=lambda x: x[1], reverse=True)
score

[('job_type_config', 0.5175739728939045),
 ('job', 0.4127968242304896),
 ('shedlock', 0.3518127326604361)]