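Enter the connection settings for the [TPC-H Snowflake sample database](https://docs.snowflake.com/en/user-guide/sample-data-tpch). The password is read from the `SNOWSQL_PWD` environment variable rather than being stored in the notebook.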

In [4]:
import os

# Connection settings for the Snowflake account that hosts SNOWFLAKE_SAMPLE_DATA.
# The password already comes from SnowSQL's SNOWSQL_PWD variable, so the other
# account-specific settings honour the matching SNOWSQL_* variables too; the
# literal defaults are the values this notebook was originally run with.
account = os.environ.get("SNOWSQL_ACCOUNT", "hdb90888")
username = os.environ.get("SNOWSQL_USER", "cristiscu")
# Never stored in the notebook: fails fast with KeyError if the variable is missing.
password = os.environ['SNOWSQL_PWD']
database = "SNOWFLAKE_SAMPLE_DATA"
schema = "TPCH_SF1"
warehouse = os.environ.get("SNOWSQL_WAREHOUSE", "COMPUTE_WH")
# NOTE(review): LLM-generated SQL is later executed with this role. Reading the
# sample database should not need ACCOUNTADMIN; consider a read-only role instead.
role = os.environ.get("SNOWSQL_ROLE", "ACCOUNTADMIN")

Connect to the Snowflake sample database from LangChain.

In [5]:
from urllib.parse import quote

from langchain_community.utilities import SQLDatabase

# snowflake-sqlalchemy URL: snowflake://user:password@account/database/schema?params.
# The user name and password are percent-encoded (safe="" also encodes '/') so that
# characters such as '@', ':', '/', '#' or spaces in the password cannot break the URL.
# `quote` is used instead of `quote_plus` because SQLAlchemy decodes with `unquote`,
# which would not turn '+' back into a space.
url = (
    f"snowflake://{quote(username, safe='')}:{quote(password, safe='')}@{account}"
    f"/{database}/{schema}?warehouse={warehouse}&role={role}"
)
db = SQLDatabase.from_uri(url)
# Show the schema context LangChain passes to the LLM (CREATE TABLE statements plus sample rows).
print(db.table_info)


CREATE TABLE customer (
	c_custkey DECIMAL(38, 0) NOT NULL, 
	c_name VARCHAR(25) NOT NULL, 
	c_address VARCHAR(40) NOT NULL, 
	c_nationkey DECIMAL(38, 0) NOT NULL, 
	c_phone VARCHAR(15) NOT NULL, 
	c_acctbal DECIMAL(12, 2) NOT NULL, 
	c_mktsegment VARCHAR(10), 
	c_comment VARCHAR(117)
)

/*
3 rows from customer table:
c_custkey	c_name	c_address	c_nationkey	c_phone	c_acctbal	c_mktsegment	c_comment
60001	Customer#000060001	9Ii4zQn9cX	14	24-678-784-9652	9957.56	HOUSEHOLD	l theodolites boost slyly at the platelets: permanently ironic packages wake slyly pend
60002	Customer#000060002	ThGBMjDwKzkoOxhz	15	25-782-500-8435	742.46	BUILDING	 beans. fluffily regular packages
60003	Customer#000060003	Ed hbPtTXMTAsgGhCr4HuTzK,Md2	16	26-859-847-7640	2526.92	BUILDING	fully pending deposits sleep quickly. blithely unusual accounts across the blithely bold requests ar
*/


CREATE TABLE lineitem (
	l_orderkey DECIMAL(38, 0) NOT NULL, 
	l_partkey DECIMAL(38, 0) NOT NULL, 
	l_suppkey DECIMAL(38, 0) NOT NU

Connect to the LLM: an OpenAI completion model, authenticated through the `OPENAI_API_KEY` environment variable.

In [6]:
from langchain_openai import OpenAI

# OpenAI completion model used to write the SQL. temperature=0 (instead of the
# library default of 0.7) keeps the generated queries as repeatable as possible,
# so re-running the notebook produces the same SQL and outputs.
llm = OpenAI(openai_api_key=os.environ["OPENAI_API_KEY"], temperature=0)

Set up the LangChain chain that translates natural-language questions into SQL queries for this database.

In [7]:
from langchain.chains import create_sql_query_chain

# Question -> SQL text: the chain prompts `llm` with the table schema exposed by `db`.
database_chain = create_sql_query_chain(llm=llm, db=db)

Open a Snowpark session to Snowflake; it is used to execute the generated queries.

In [8]:
from snowflake.snowpark import Session

# Snowpark connection parameters, built from the settings in the credentials cell.
pars = dict(
    account=account,
    user=username,
    password=password,
    database=database,
    schema=schema,
    warehouse=warehouse,
    role=role,
)

# This session executes the SQL produced by the LLM chain.
session = Session.builder.configs(pars).create()

Create a generic helper that generates a SQL query from a natural-language prompt, prints it, and runs it.

In [9]:
def run(prompt, n=10):
    """Generate SQL for a natural-language prompt, print it, then execute it and show the result.

    The prompt goes through ``database_chain`` to obtain a SQL string. That string is
    printed and then executed with the Snowpark ``session``.

    Args:
        prompt: Natural-language question to translate into SQL.
        n: Maximum number of result rows to display (default 10). Snowpark's
            ``DataFrame.show`` prints only 10 rows unless told otherwise, which
            silently truncates larger results such as a GROUP BY with many groups.

    Security note: the generated SQL is executed verbatim with the session's role.
    Always review the printed query, and prefer a read-only role for this session.
    """
    sql_query = database_chain.invoke({"question": prompt})
    print(sql_query)
    session.sql(sql_query).show(n)

Run a first query.

In [11]:
# An earlier, vaguer prompt ("Select top 10 paying customers from Canada") is kept here for
# comparison; the prompt below names the C_ACCTBAL column explicitly as the ranking value.
run("Select top 10 customers from Canada with highest sum of C_ACCTBAL value")

SELECT c_name, c_address, c_phone, SUM(c_acctbal) as total_acctbal
FROM customer
WHERE c_nationkey = (SELECT n_nationkey FROM nation WHERE n_name = 'CANADA')
GROUP BY c_name, c_address, c_phone
ORDER BY total_acctbal DESC
LIMIT 10
--------------------------------------------------------------------------------------------------
|"C_NAME"            |"C_ADDRESS"                            |"C_PHONE"        |"TOTAL_ACCTBAL"  |
--------------------------------------------------------------------------------------------------
|Customer#000112796  |hC8PTDi,13j9N                          |13-428-273-3134  |9998.32          |
|Customer#000080736  |mqVUr4Voq5XPO2lswhTkC0sY5p7w           |13-870-874-2736  |9998.01          |
|Customer#000064147  |TfbulTzPrCK                            |13-282-212-7949  |9997.50          |
|Customer#000121024  |6bMiL9MlKgeeo                          |13-422-755-1125  |9996.90          |
|Customer#000138209  |uHUzB66c4u3cj4                         |13-856-253-614

Run a GROUP BY query.

In [12]:
# NOTE: run() displays results with Snowpark's DataFrame.show(), which prints at most
# 10 rows by default, so the per-nation list below may be truncated.
# In the run shown below, the generated query grouped by the numeric c_nationkey,
# not by the nation name.
run("Show me the total of Customers per Nation")

SELECT c_nationkey, COUNT(c_custkey) AS total_customers
FROM customer
GROUP BY c_nationkey
-------------------------------------
|"C_NATIONKEY"  |"TOTAL_CUSTOMERS"  |
-------------------------------------
|21             |6008               |
|23             |6011               |
|5              |5952               |
|14             |5992               |
|19             |6100               |
|17             |5975               |
|10             |6009               |
|7              |5908               |
|1              |5975               |
|4              |5995               |
-------------------------------------



Try to generate the more complex query described in the Snowflake TPC-H documentation linked above (Query 1, the Pricing Summary Report).

In [13]:
# This prompt paraphrases the description of TPC-H Query 1 (Pricing Summary Report)
# from the Snowflake sample-data documentation linked at the top of the notebook.
# NOTE(review): the documented Q1 SQL also filters lineitem on l_shipdate; this prompt
# does not mention that filter, so the generated query below aggregates every row —
# confirm against the documentation.
run("""Show me a query that lists totals for extended price, discounted extended price, 
    discounted extended price plus tax, average quantity, average extended price, 
    and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, 
    and listed in ascending order of RETURNFLAG and LINESTATUS. 
    A count of the number of line items in each group is included""")

SELECT 
    l_returnflag, 
    l_linestatus, 
    SUM(l_extendedprice) as total_extended_price,
    SUM(l_extendedprice * (1 - l_discount)) as discounted_extended_price,
    SUM(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as discounted_extended_price_plus_tax,
    AVG(l_quantity) as average_quantity,
    AVG(l_extendedprice) as average_extended_price,
    AVG(l_discount) as average_discount,
    COUNT(*) as line_item_count
FROM lineitem
GROUP BY l_returnflag, l_linestatus
ORDER BY l_returnflag ASC, l_linestatus ASC
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"L_RETURNFLAG"  |"L_LINESTATUS"  |"TOTAL_EXTENDED_PRICE"  |"DISCOUNTED_EXTENDED_PRICE"  |"DISCOUNTED_EXTENDED_PRICE_PLUS_TAX"  |"AVERAGE_QUANTITY"  |"AVERAGE_EXTENDED_PRICE"  |"AVERAGE_DISCOUNT"  |"LINE_ITEM_COUNT"  |
-----------------------------------