## imports

In [12]:
import csv
import logging
import os
import sys
import json

import boto3
from sqlalchemy import create_engine

from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts import Prompt, PromptTemplate
from llama_index.core.schema import TextNode
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.llms.bedrock import Bedrock

from prompt_templates import RESPONSE_TEMPLATE_STR, SQL_TEMPLATE_STR, TABLE_DETAILS

## logging

In [2]:
# Send INFO-and-above log records to stdout so they show up in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

## variables & connections

In [3]:
region = "us-east-1"
text2sql_database = "main"

# Bucket used as the Athena s3_staging_dir in the connection URL below.
# Fail fast when the variable is missing: os.getenv() would return None and the
# f-string would silently produce the staging directory "s3://None".
athena_bucket_name = os.getenv("ATHENA_BUCKET_NAME")
if not athena_bucket_name:
    raise EnvironmentError(
        "ATHENA_BUCKET_NAME is not set; it must name the S3 bucket used as the Athena staging directory."
    )

# athena connection url for pulling metadata and querying
athena_connection_url = f"awsathena+rest://athena.{region}.amazonaws.com/{text2sql_database}?s3_staging_dir=s3://{athena_bucket_name}"

# used for embeddings
# NOTE(review): bedrock_client is not passed to BedrockEmbedding here and is not
# referenced by any other visible cell — confirm whether it is still needed.
bedrock_client = boto3.client("bedrock-runtime", region_name=region)
embed_model = BedrockEmbedding(model_name="amazon.titan-embed-g1-text-02", region_name=region)

# CSV of few-shot examples read by get_few_shot_examples(), and the number of
# most-similar examples it retrieves by default.
fewshot_examples_path = "dynamic_examples.csv"
NUM_FEW_SHOT_EXAMPLES_TO_CONSIDER = 2

INFO:botocore.credentials:Found credentials in environment variables.
INFO:botocore.credentials:Found credentials in environment variables.


In [5]:
# embeddings = embed_model.get_text_embedding("hello world")
# print(embeddings)

## 1. few-shot examples retrieval

In [6]:
def get_few_shot_examples(embed_model, user_input, fewshot_examples_path, similarity_top_k=None, **kwargs):
    """Return the few-shot examples most similar to ``user_input`` as one string.

    The CSV at ``fewshot_examples_path`` is read with ``csv.DictReader``; every
    row must have an ``example_input_question`` column. The question texts are
    embedded into a ``VectorStoreIndex`` and the rows belonging to the closest
    questions are rendered as ``"<Column>: <value>"`` lines, with a blank line
    between examples.

    Args:
        embed_model: Embedding model handed to ``VectorStoreIndex``.
        user_input: Text the examples are matched against.
        fewshot_examples_path: Path of the CSV file holding the examples.
        similarity_top_k: Number of examples to retrieve. Defaults to
            ``NUM_FEW_SHOT_EXAMPLES_TO_CONSIDER`` when ``None``.
        **kwargs: Accepted and ignored, so the function can be called with
            whatever extra keyword arguments the caller forwards.

    Returns:
        The formatted examples, or ``""`` when the CSV contains no rows.

    Raises:
        KeyError: If a row lacks the ``example_input_question`` column.
    """
    with open(fewshot_examples_path, newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.DictReader(csvfile, skipinitialspace=True)
        # Keyed by question text; a later row with the same question replaces an earlier one.
        examples_by_question = {row["example_input_question"]: row for row in reader}

    if not examples_by_question:
        # Nothing to index: skip building the index and the embedding call for the query.
        logging.warning("No few-shot examples found in %s", fewshot_examples_path)
        return ""

    if similarity_top_k is None:
        similarity_top_k = NUM_FEW_SHOT_EXAMPLES_TO_CONSIDER

    # Each node holds the JSON-encoded question so that json.loads() on the
    # retrieved node text yields exactly the dictionary key used above.
    few_shot_nodes = [TextNode(text=json.dumps(question)) for question in examples_by_question]

    # Vector stores accept a list of Node objects and build an index from them;
    # the index creates an embedding for the text of every node.
    few_shot_index = VectorStoreIndex(few_shot_nodes, embed_model=embed_model)
    few_shot_retriever = few_shot_index.as_retriever(similarity_top_k=similarity_top_k)

    result_strs = []
    for node in few_shot_retriever.retrieve(user_input):
        question = json.loads(node.get_content())
        row = examples_by_question[question]
        result_strs.append("\n".join(f"{k.capitalize()}: {v}" for k, v in row.items()))

    example_set = "\n\n".join(result_strs)
    logging.info("Example set provided:\n%s", example_set)
    return example_set

## 2. database schema retrieval

In [7]:
# connect to the database through SQLAlchemy and wrap the engine for llama_index
sql_engine = create_engine(athena_connection_url)
sql_database = SQLDatabase(sql_engine)

# list of lightweight representations of the SQL tables: name & context string.
# get_usable_table_names() is the public accessor for the table list (instead of
# the private _all_tables set) and yields the names in a stable order.
table_schema_objs = [
    SQLTableSchema(table_name=table, context_str=TABLE_DETAILS[table])
    for table in sql_database.get_usable_table_names()
]

# takes in a SQLDatabase and produces a Node object for each SQLTableSchema object passed into the ObjectIndex constructor below
table_node_mapper = SQLTableNodeMapping(sql_database)

# gives us a VectorStoreIndex where each Node contains a table schema
obj_index = ObjectIndex.from_objects(
    table_schema_objs, table_node_mapper, VectorStoreIndex, embed_model=embed_model
)

INFO:botocore.credentials:Found credentials in environment variables.


  sql_engine = create_engine(athena_connection_url)


In [None]:
# object_retriever = obj_index.as_retriever(similarity_top_k=1)
# object_retriever.retrieve("llamaindex")

## 3. text-to-SQL retriever

In [8]:
# Natural-language question that is sent to the query engine in the cells below.
user_input = "What's the most ordered product?"

In [9]:
def _few_shot_examples_for_query(**kwargs):
    """Fill the ``few_shot_examples`` template variable for the question being asked.

    The function mapping is called with the keyword arguments used to format the
    prompt. When those include a non-empty ``query_str`` it is used for the
    similarity search, so the examples always match the question actually sent
    to the query engine; otherwise the notebook-level ``user_input`` is used.
    """
    question = kwargs.get("query_str") or user_input
    return get_few_shot_examples(embed_model, question, fewshot_examples_path, **kwargs)


prompt_template = PromptTemplate(
    template=SQL_TEMPLATE_STR,
    function_mappings={"few_shot_examples": _few_shot_examples_for_query},
)

In [10]:
# LLM used for both SQL generation and response synthesis.
llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name=region,
    max_tokens=1024,
    temperature=0,
)

# Query engine wiring: the table retriever returns up to 10 table schemas from
# obj_index, prompt_template drives text-to-SQL generation, and
# RESPONSE_TEMPLATE_STR drives the final natural-language answer.
# PromptTemplate is used for both prompts (instead of the legacy `Prompt` alias)
# so the notebook builds its prompts consistently.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=10),
    text_to_sql_prompt=prompt_template,
    response_synthesis_prompt=PromptTemplate(RESPONSE_TEMPLATE_STR),
    llm=llm,
)

INFO:botocore.credentials:Found credentials in environment variables.


In [13]:
# Ask the query engine the question; the result is inspected in the following cells.
response = query_engine.query(user_input)

INFO:llama_index.core.indices.struct_store.sql_retriever:> Table desc str: Table 'products' has columns: productid (BIGINT), productnumber (VARCHAR): 'Uniquely identifies a product within an ERP system', productname (VARCHAR), modelname (VARCHAR), makeflag (VARCHAR), standardcost (FLOAT), listprice (FLOAT), subcategoryid (BIGINT), . The table description is: Products are supplied by vendors. Products belong to subcategories

Table 'orders' has columns: salesorderid (BIGINT), salesorderdetailid (BIGINT), orderdate (VARCHAR), duedate (VARCHAR), shipdate (VARCHAR), employeeid (BIGINT), customerid (BIGINT), subtotal (FLOAT), taxamt (FLOAT), freight (FLOAT), totaldue (FLOAT), productid (BIGINT), orderqty (BIGINT), unitprice (FLOAT), unitpricediscount (FLOAT), linetotal (FLOAT), . The table description is: Events of customers purchasing from employees

Table 'customers' has columns: customerid (BIGINT), firstname (VARCHAR), lastname (VARCHAR), fullname (VARCHAR), . The table description is: 

In [14]:
# Full repr of the Response object: answer text, source nodes and metadata.
response

Response(response="According to the latest information, the most ordered product is the 'Long-Sleeve Logo Jersey, L' with a total order quantity of 6,140.", source_nodes=[NodeWithScore(node=TextNode(id_='4884f111-d969-4680-a7d7-23a95bdffa23', embedding=None, metadata={'sql_query': 'SELECT p.productname, SUM(o.orderqty) AS total_ordered\nFROM orders o\nJOIN products p ON o.productid = p.productid\nGROUP BY p.productname\nORDER BY total_ordered DESC\nLIMIT 1;', 'result': [('Long-Sleeve Logo Jersey, L', 6140)], 'col_keys': ['productname', 'total_ordered']}, excluded_embed_metadata_keys=['sql_query', 'result', 'col_keys'], excluded_llm_metadata_keys=['sql_query', 'result', 'col_keys'], relationships={}, text="[('Long-Sleeve Logo Jersey, L', 6140)]", mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)], metadata={'4884f111-d969-4680-a7d7-23a95bdffa23': {'sql_quer

In [15]:
# Printing the response shows only the synthesized natural-language answer.
print(response)

According to the latest information, the most ordered product is the 'Long-Sleeve Logo Jersey, L' with a total order quantity of 6,140.


In [16]:
# The SQL statement generated for the question is stored under the "sql_query" metadata key.
print(response.metadata["sql_query"])

SELECT p.productname, SUM(o.orderqty) AS total_ordered
FROM orders o
JOIN products p ON o.productid = p.productid
GROUP BY p.productname
ORDER BY total_ordered DESC
LIMIT 1;
