In [45]:
import sys
import os
sys.path.append(os.path.abspath('../version_3'))
from embedding import update_embeddings, get_embeddings, DatabaseConnection

update_embeddings()
embeds = get_embeddings()

Connected to MySQL server
MySQL connection is closed


In [46]:
from openai import OpenAI

client = OpenAI()
EMBEDDING_MODEL = "text-embedding-3-small"

def create_embedding(data):

    response = client.embeddings.create(model=EMBEDDING_MODEL, input=[data])
    embedding = [e.embedding for e in response.data][0]

    return embedding

In [47]:
from scipy import spatial

def get_top_N_related_tables(query, N=8):
    query_embed = create_embedding(query)
    relatedness_fn = lambda x, y: 1 - spatial.distance.cosine(x, y)
    score = []

    for table, embedding in embeds.items():
        score.append((table, relatedness_fn(query_embed, embedding)))

    score.sort(key=lambda x: x[1], reverse=True)
    return score[:N]

In [48]:
def generate_selection_prompt(candidates, query):
    db_conn = DatabaseConnection()
    
    prompt = 'Here are the tables available:\n'
    for table, _ in candidates:
        schema = db_conn.describe_table(*table.split('.'))
        fields = [column[0] for column in schema]
        prompt += f'{table}: {", ".join(fields)}\n'

    prompt += "Please find which tables are relevant to solve queries. You must answer ONLY a valid json array and NOTHING ELSE.\n\n"
    prompt += f"Query: In the {table.split('.')[1]} table of {table.split('.')[0]} database, is the {schema[0][0]} column a primary key?\n"
    prompt += f"Relevant Tables: [\"{table}\"]\n"
    prompt += f"Query: In the {table.split('.')[1]} table of {table.split('.')[0]} database, Are there multiple values of {schema[0][0]} which are same?\n"
    prompt += f"Relevant Tables: [\"{table}\"]\n"
    prompt += f"Query: {query}\n"
    prompt += f"Relevant Tables: "
    return prompt

In [49]:
candidates = get_top_N_related_tables('What is the email id of the customer with customer number 987654?')

In [50]:
candidates

[('grimlock_dev_db.customer', 0.44377467502977974),
 ('hyperface_dev_db.issuer_customer', 0.40736577825273956),
 ('hyperface_dev_db.customer', 0.3886162116864449),
 ('hyperface_dev_db.customer_kyc_detail_kyc_proofs', 0.3754239888906663),
 ('grimlock_dev_db.customer_dump', 0.3729263876696144),
 ('hyperface_dev_db.email_template', 0.36282547742174276),
 ('grimlock_dev_db.customer_number_change_tracker', 0.36249814882874554),
 ('grimlock_dev_db.customer_dump_retry', 0.3578643197879774)]

In [51]:
k = generate_selection_prompt(candidates, 'What is the email id of the customer with customer number 987654?')

Connected to MySQL server


In [52]:
print(k)

Here are the tables available:
grimlock_dev_db.customer: customer_number, created_on, customer_id, email_id, last_updated, mobile
hyperface_dev_db.issuer_customer: id, country_code, created_on, current_address, date_of_birth, email, first_name, gender, last_name, last_updated_on, middle_name, mobile, mobile_country_code, mobile_hash, mobile_masked, nationality, pancard, permanent_address, switch_metadata, title, issuer_id, schedule_ofcreated_on, switch_customer_number, kyc_status
hyperface_dev_db.customer: id, country_code, created_on, current_address, date_of_birth, email, first_name, gender, last_name, last_updated_on, middle_name, mobile, mobile_country_code, mobile_hash, mobile_masked, nationality, pancard, permanent_address, preferred_name, switch_metadata, title, client_id, schedule_ofcreated_on, switch_customer_number, kyc_status, office_address
hyperface_dev_db.customer_kyc_detail_kyc_proofs: customer_kyc_detail_id, kyc_proofs_id
grimlock_dev_db.customer_dump: id, account_creat

In [53]:
from openai import OpenAI
client = OpenAI()

response = client.chat.completions.create(
  model="gpt-3.5-turbo-instruct",
  prompt=k
)

TypeError: Missing required arguments; Expected either ('messages' and 'model') or ('messages', 'model' and 'stream') arguments to be given

In [None]:
from openai import OpenAI
client = OpenAI()

response = client.completions.create(
  model="gpt-3.5-turbo-instruct",
  prompt="Write a tagline for an ice cream shop."
)

In [None]:
from openai import OpenAI
client = OpenAI()

response = client.completions.create(
  model="gpt-3.5-turbo-instruct",
  prompt=k
)

In [None]:
print(response.choices[0].text.strip())

['hyperface_dev_db.customer', 'grimlock_dev_db.customer']


In [None]:
import ast
x = ast.literal_eval(response.choices[0].text.strip())

In [None]:
x

['hyperface_dev_db.customer', 'grimlock_dev_db.customer']

In [None]:
import json
json.loads(response.choices[0].text.strip())
response.choices[0].text.strip()

JSONDecodeError: Expecting value: line 1 column 2 (char 1)

In [None]:
response.choices[0].text.strip()

'"Scoops of happiness in every cone!"'

In [None]:
k

"Here are the tables available:\ngrimlock_dev_db.customer: customer_number, created_on, customer_id, email_id, last_updated, mobile\nhyperface_dev_db.issuer_customer: id, country_code, created_on, current_address, date_of_birth, email, first_name, gender, last_name, last_updated_on, middle_name, mobile, mobile_country_code, mobile_hash, mobile_masked, nationality, pancard, permanent_address, switch_metadata, title, issuer_id, schedule_ofcreated_on, switch_customer_number, kyc_status\nhyperface_dev_db.customer: id, country_code, created_on, current_address, date_of_birth, email, first_name, gender, last_name, last_updated_on, middle_name, mobile, mobile_country_code, mobile_hash, mobile_masked, nationality, pancard, permanent_address, preferred_name, switch_metadata, title, client_id, schedule_ofcreated_on, switch_customer_number, kyc_status, office_address\nhyperface_dev_db.customer_kyc_detail_kyc_proofs: customer_kyc_detail_id, kyc_proofs_id\ngrimlock_dev_db.customer_dump: id, account

In [61]:
import ast

def select_relevant_tables(query):
    candidates = get_top_N_related_tables(query)
    prompt = generate_selection_prompt(candidates, query)
    response = client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=prompt
    )
    return ast.literal_eval(response.choices[0].text.strip())

In [60]:
o = select_relevant_tables('What is the billing period for the credit card with switch card id 123456?')

Connected to MySQL server


SyntaxError: unexpected EOF while parsing (<unknown>, line 0)

In [58]:
ast.literal_eval(o)

['hyperface_dev_db.credit_card_program']