In [1]:
import os
import openai
import pandas as pd
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy import text

In [2]:
# Read the OpenAI key from the environment so it never appears in the notebook.
# If OPENAI_API_KEY is unset, this assigns None.
openai.api_key = os.environ.get('OPENAI_API_KEY')

In [3]:
# Load the sample sales CSV; the path is relative to the working directory.
df = pd.read_csv(filepath_or_buffer='sales_data_sample.csv')

In [4]:
# Preview the first five rows to sanity-check the load.
df.head(n=5)

Unnamed: 0,ORDERNUMBER,QUANTITYORDERED,PRICEEACH,SALES,ORDERDATE,QTR_ID,MONTH_ID,YEAR_ID,PRODUCTLINE,PHONE,ADDRESSLINE1,CITY,STATE,POSTALCODE,COUNTRY,CONTACTLASTNAME,CONTACTFIRSTNAME
0,10107,30,95.7,2871.0,2/24/2003 0:00,1,2,2003,Motorcycles,2125557818,897 Long Airport Avenue,NYC,NY,10022.0,USA,Yu,Kwai
1,10121,34,81.35,2765.9,5/7/2003 0:00,2,5,2003,Motorcycles,26.47.1555,59 rue de l'Abbaye,Reims,,51100.0,France,Henriot,Paul
2,10134,41,94.74,3884.34,7/1/2003 0:00,3,7,2003,Motorcycles,+33 1 46 62 7555,27 rue du Colonel Pierre Avia,Paris,,75508.0,France,Da Cunha,Daniel
3,10145,45,83.26,3746.7,8/25/2003 0:00,3,8,2003,Motorcycles,6265557265,78934 Hillside Dr.,Pasadena,CA,90003.0,USA,Young,Julie
4,10159,49,100.0,5205.27,10/10/2003 0:00,4,10,2003,Motorcycles,6505551386,7734 Strong St.,San Francisco,CA,,USA,Brown,Julie


In [5]:
# Total SALES per quarter. Selecting the column before aggregating sums only
# SALES rather than every numeric column in the frame; the resulting Series
# (indexed by QTR_ID, named SALES) is the same.
df.groupby('QTR_ID')['SALES'].sum()

QTR_ID
1    2350817.73
2    2048120.30
3    1758910.81
4    3874780.01
Name: SALES, dtype: float64

In [6]:
# Throwaway SQLite database held in memory.
# echo=True makes the engine log every SQL statement it emits.
temp_db = sqlalchemy.create_engine('sqlite:///:memory:', echo=True)

In [7]:
# Push the DataFrame into the temporary DB as table "Sales".
# if_exists='replace' keeps this cell re-runnable: the default ('fail') raises
# ValueError when the table already exists on the same engine.
data = df.to_sql(name='Sales', con=temp_db, if_exists='replace')

2023-03-19 17:38:42,959 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("Sales")
2023-03-19 17:38:42,960 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-03-19 17:38:42,961 INFO sqlalchemy.engine.Engine PRAGMA temp.table_info("Sales")
2023-03-19 17:38:42,962 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-03-19 17:38:42,963 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-03-19 17:38:42,964 INFO sqlalchemy.engine.Engine 
CREATE TABLE "Sales" (
	"index" BIGINT, 
	"ORDERNUMBER" BIGINT, 
	"QUANTITYORDERED" BIGINT, 
	"PRICEEACH" FLOAT, 
	"SALES" FLOAT, 
	"ORDERDATE" TEXT, 
	"QTR_ID" BIGINT, 
	"MONTH_ID" BIGINT, 
	"YEAR_ID" BIGINT, 
	"PRODUCTLINE" TEXT, 
	"PHONE" TEXT, 
	"ADDRESSLINE1" TEXT, 
	"CITY" TEXT, 
	"STATE" TEXT, 
	"POSTALCODE" TEXT, 
	"COUNTRY" TEXT, 
	"CONTACTLASTNAME" TEXT, 
	"CONTACTFIRSTNAME" TEXT
)


2023-03-19 17:38:42,965 INFO sqlalchemy.engine.Engine [no key 0.00042s] ()
2023-03-19 17:38:42,966 INFO sqlalchemy.engine.Engine CREATE INDEX "ix_Sales_index" ON "Sales" (

In [8]:
# Run a raw SQL query against the temporary DB: the total of the SALES column.

# Open a connection; the `with` block closes it automatically on exit.
with temp_db.connect() as conn:
    # text() wraps the raw SQL string so the connection can execute it.
    result = conn.execute(text("SELECT SUM(SALES) from Sales"))
    # NOTE(review): `result` is read in a later cell, after this connection has
    # been closed — confirm this is reliable on the installed SQLAlchemy version.

2023-03-19 17:38:43,019 INFO sqlalchemy.engine.Engine SELECT SUM(SALES) from Sales
2023-03-19 17:38:43,020 INFO sqlalchemy.engine.Engine [generated in 0.00126s] ()


In [9]:
# Pull every row out of the result as a list.
result.fetchall()

[(10032628.85000001,)]

In [10]:
### sqlite SQL tables, with their properties:
#
# Employee(id, name, department_id)
# Department(id, name, address)
# Salary_Payments(id, employee_id, amount, date)
#
### A query to list the names of the departments which employed more than 10 employees in the last 3 months
# SELECT

In [11]:
def create_table_definition(df, table_name="Sales"):
    """Build the schema header of the text-to-SQL prompt.

    Args:
        df: DataFrame whose column labels are listed as the table's columns.
        table_name: Name of the SQL table shown in the prompt. Defaults to
            "Sales".

    Returns:
        A string of '#'-commented lines describing the table. It ends with a
        newline so the question section can be appended directly.
    """
    columns = ",".join(str(col) for col in df.columns)
    # Assemble the prompt line by line: a triple-quoted literal inside the
    # function body would carry the source indentation into the prompt text.
    lines = [
        "### sqlite SQL tables, with its properties:",
        "#",
        f"# {table_name}({columns})",
        "#",
        "",  # trailing newline after the last '#'
    ]
    return "\n".join(lines)

In [12]:
# print() renders the embedded newlines so the schema header can be inspected.
print(create_table_definition(df=df))

### sqlite SQL tables, with its properties:
    #
    # Sales(ORDERNUMBER,QUANTITYORDERED,PRICEEACH,SALES,ORDERDATE,QTR_ID,MONTH_ID,YEAR_ID,PRODUCTLINE,PHONE,ADDRESSLINE1,CITY,STATE,POSTALCODE,COUNTRY,CONTACTLASTNAME,CONTACTFIRSTNAME)
    #
    


In [13]:
def prompt_input():
    """Ask the user for a plain-English request and return exactly what they typed."""
    return input("Enter the info you want: ")

In [14]:
# Quick interactive check of the helper; the typed text is echoed as the cell output.
prompt_input()

Enter the info you want: return the sum of sales per postal code


'return the sum of sales per postal code'

In [15]:
def combine_prompts(df, query_prompt):
    """Assemble the full model prompt: table schema header plus the user's question.

    The prompt deliberately ends with the keyword SELECT so that the model's
    completion is the remainder of a SQL statement.
    """
    table_section = create_table_definition(df)
    question_section = "### A query to answer: {}\nSELECT".format(query_prompt)
    return "".join((table_section, question_section))

In [16]:
# Collect the user's request, then assemble the full prompt for the model.
nlp_text = prompt_input()
prompt = combine_prompts(df=df, query_prompt=nlp_text)

Enter the info you want: return the sum of sales per postal code


In [17]:
# print() renders the embedded newlines so the prompt layout can be inspected.
print(prompt)

### sqlite SQL tables, with its properties:
    #
    # Sales(ORDERNUMBER,QUANTITYORDERED,PRICEEACH,SALES,ORDERDATE,QTR_ID,MONTH_ID,YEAR_ID,PRODUCTLINE,PHONE,ADDRESSLINE1,CITY,STATE,POSTALCODE,COUNTRY,CONTACTLASTNAME,CONTACTFIRSTNAME)
    #
    ### A query to answer: return the sum of sales per postal code
SELECT


In [18]:
# Request a completion for the assembled prompt.
# NOTE(review): this relies on the legacy Completion endpoint and a Codex
# model — confirm both are still available for the installed openai package.
response = openai.Completion.create(
    model='code-davinci-002',
    prompt=combine_prompts(df, nlp_text),
    max_tokens=150,
    temperature=0,
    top_p=1.0,
    presence_penalty=0,
    frequency_penalty=0,
    # Cut generation at a comment marker, a statement terminator, or a line break.
    stop=['#', ';', '\n'],
)

In [19]:
# The generated text continues the prompt's trailing "SELECT", so it begins mid-statement.
response['choices'][0]['text']

' POSTALCODE, SUM(SALES) FROM Sales GROUP BY POSTALCODE'

In [20]:
def handle_response(response):
    """Turn the model's completion into a complete SQL statement.

    The prompt ends with the keyword SELECT, so the completion is normally the
    remainder of the statement. The keyword is prepended unless the completion
    already starts with it.

    Args:
        response: Completion response; the text is read from
            response['choices'][0]['text'].

    Returns:
        The SQL statement as a string, beginning with SELECT.
    """
    completion = response['choices'][0]['text']
    # Leading whitespace is not significant, and a completion may start with a
    # tab or with no whitespace at all, so normalise it before deciding.
    body = completion.lstrip()
    first_word = body.split(None, 1)[0] if body else ""
    if first_word.upper() == "SELECT":
        # The model repeated the keyword itself; do not double it.
        return body
    return "SELECT " + body

In [21]:
# Show the full statement rebuilt from the completion.
handle_response(response=response)

'SELECT POSTALCODE, SUM(SALES) FROM Sales GROUP BY POSTALCODE'

In [22]:
# Execute the model-generated statement against the temporary DB.
# NOTE(review): the SQL is run exactly as generated, with no validation; that is
# tolerable for a throwaway in-memory DB but should not be pointed at real data.
with temp_db.connect() as conn:
    result = conn.execute(text(handle_response(response)))

2023-03-19 17:39:33,096 INFO sqlalchemy.engine.Engine SELECT POSTALCODE, SUM(SALES) FROM Sales GROUP BY POSTALCODE
2023-03-19 17:39:33,097 INFO sqlalchemy.engine.Engine [generated in 0.00082s] ()


In [23]:
# Pull every row out of the result as a list.
result.fetchall()

[(None, 272407.14),
 ('10022', 560787.7699999998),
 ('10100', 94117.26000000002),
 ('106-0032', 120562.73999999996),
 ('1203', 117713.55999999998),
 ('1227 MM', 94015.73),
 ('13008', 74936.14),
 ('1734', 145041.6),
 ('2', 57756.43),
 ('2060', 153996.13000000003),
 ('2067', 151570.98000000004),
 ('21240', 111250.37999999996),
 ('24067', 85555.98999999998),
 ('24100', 137955.72000000003),
 ('28023', 170257.33000000005),
 ('28034', 912294.1100000002),
 ('3004', 200995.40999999997),
 ('31000', 70488.44),
 ('3150', 64591.46000000001),
 ('4101', 59469.11999999999),
 ('4110', 116599.19),
 ('41101', 54723.62),
 ('42100', 142601.33000000002),
 ('44000', 204304.86),
 ('5020', 149798.63),
 ('50553', 207874.86),
 ('50739', 100306.58),
 ('51003', 154069.65999999997),
 ('51100', 135042.94),
 ('51247', 139243.99999999994),
 ('530-0003', 67605.07),
 ('58339', 165255.20000000004),
 ('59000', 69052.41),
 ('60528', 85171.58999999998),
 ('62005', 131685.30000000002),
 ('67000', 80438.48),
 ('69004', 14287