# Settings

In [1]:
import os

if os.getcwd().endswith("notebooks"):
    os.chdir("..")
print(os.getcwd())

/Users/cmcoutosilva/Projects/github/nl2sql-agent


In [2]:
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_postgres import PGVector

from nl2sql.config import load_schema_config
from nl2sql.database.postgresql import PostgreSQLConnector
from nl2sql.knowledge_base.data_dictionary import DataDictionary
from nl2sql.knowledge_base.sql_examples import SQLExample

In [3]:
# Set up database connector
db_connector = PostgreSQLConnector(config_path="configs/database.yml")

# Creates a vector store for the database
vector_store = PGVector(
    connection=db_connector.engine,
    embeddings=OpenAIEmbeddings(),
    collection_name="nl2sql_embeddings",
)

# Vector Store

In [4]:
def get_documents_from_data_dictionary(
    data_dictionary: DataDictionary,
) -> list[Document]:
    """Get documents from data dictionary."""
    documents = []

    for database_name, database_info in data_dictionary.databases.items():
        for schema_name, schema_info in database_info.schemas.items():
            for table_name, table_info in schema_info.tables.items():
                # Format table information
                content = table_info.format_context()

                # Create metadata
                metadata = {
                    "type": "schema",
                    "database": database_name,
                    "schema": schema_name,
                    "table": table_name,
                    "primary_keys": ", ".join(table_info.primary_keys),
                    "foreign_keys": ", ".join(
                        col
                        for fk in table_info.foreign_keys
                        for col in fk["constrained_columns"]
                    ),
                }

                # Create Document object
                doc = Document(page_content=content, metadata=metadata)
                documents.append(doc)

    return documents


# Load data dictionary
data_dictionary = DataDictionary.from_inspector(
    inspector=db_connector.inspector, database_schema=load_schema_config()
)

In [5]:
# Get documents from data dictionary
shema_documents = get_documents_from_data_dictionary(data_dictionary)
display(shema_documents)

# Show formatted content for the first document
print(shema_documents[0].page_content)

[Document(metadata={'type': 'schema', 'database': 'olist_ecommerce', 'schema': 'ecommerce', 'table': 'customers', 'primary_keys': 'customer_id', 'foreign_keys': ''}, page_content='TABLE: customers\nDESCRIPTION: This dataset has information about the customer and its location. Use it to identify unique customers in the orders dataset and to find the orders delivery location. At our system each order is assigned to a unique customer_id. This means that the same customer will get different ids for different orders. The purpose of having a customer_unique_id on the dataset is to allow you to identify customers that made repurchases at the store. Otherwise you would find that each order had a different customer associated with.\nPRIMARY KEYS: customer_id\nCOLUMNS:\n  - customer_id (TEXT, NOT NULL): key to the orders dataset. Each order has a unique customer_id.\n  - customer_unique_id (TEXT, NOT NULL): unique identifier of a customer.\n  - customer_zip_code_prefix (TEXT, NULL): first five d

TABLE: customers
DESCRIPTION: This dataset has information about the customer and its location. Use it to identify unique customers in the orders dataset and to find the orders delivery location. At our system each order is assigned to a unique customer_id. This means that the same customer will get different ids for different orders. The purpose of having a customer_unique_id on the dataset is to allow you to identify customers that made repurchases at the store. Otherwise you would find that each order had a different customer associated with.
PRIMARY KEYS: customer_id
COLUMNS:
  - customer_id (TEXT, NOT NULL): key to the orders dataset. Each order has a unique customer_id.
  - customer_unique_id (TEXT, NOT NULL): unique identifier of a customer.
  - customer_zip_code_prefix (TEXT, NULL): first five digits of customer zip code
  - customer_city (TEXT, NULL): customer city name
  - customer_state (TEXT, NULL): customer state



In [6]:
def get_documents_from_sql_examples(
    sql_examples: dict[str, SQLExample],
) -> list[Document]:
    """Get documents from SQL examples."""
    documents = []

    for title, example in sql_examples.items():
        doc_content = example.format_context()
        doc_metadata = {"type": "example", "title": title}
        documents.append(Document(page_content=doc_content, metadata=doc_metadata))

    return documents


# Load SQL examples
sql_examples = SQLExample.from_yaml("knowledge/sql_examples.yml")

In [7]:
# Get documents from SQL examples
sql_documents = get_documents_from_sql_examples(sql_examples)
display(sql_documents)

# Show formatted content for the first document
print(sql_documents[0].page_content)

[Document(metadata={'type': 'example', 'title': 'total_orders'}, page_content='Question: How many orders are there in total?\n```sql\nSELECT COUNT(*)\nFROM "ecommerce"."orders"\n```'),
 Document(metadata={'type': 'example', 'title': 'orders_by_status'}, page_content='Question: What is the distribution of orders by status?\n```sql\nSELECT "order_status",\n       COUNT(*)\nFROM "ecommerce"."orders"\nGROUP BY "order_status"\n```'),
 Document(metadata={'type': 'example', 'title': 'top_cities'}, page_content='Question: Which cities have the most customers?\n```sql\nSELECT "customer_city",\n       COUNT(*)\nFROM "ecommerce"."customers"\nGROUP BY "customer_city"\nORDER BY COUNT(*) DESC\nLIMIT 10\n```'),
 Document(metadata={'type': 'example', 'title': 'orders_with_customer'}, page_content='Question: Show me orders with customer information\n```sql\nSELECT o."order_id",\n       o."order_status",\n       c."customer_city"\nFROM "ecommerce"."orders" o\nJOIN "ecommerce"."customers" c ON o."custome

Question: How many orders are there in total?
```sql
SELECT COUNT(*)
FROM "ecommerce"."orders"
```


In [8]:
# Add documents to vector store
vector_store.add_documents(shema_documents + sql_documents)

['3ad88a11-1fa3-4f9b-9cce-8b64e3a387b0',
 '89872387-9524-48fc-9bd6-d2f89742dcaf',
 '451d3384-0e7a-4f2e-8d1f-19914cc70bd6',
 '5fbce216-28f6-43a2-9e9a-4d0dae8938d3',
 '754814c0-93db-432a-833f-41440cd7c021',
 '1893c923-e3b8-4bc7-87c4-c60d3de5e9a0',
 'cbd82500-a16d-443f-9a74-29f70daea121',
 '9845b4a4-bcda-47cb-beac-b27152f99233',
 '27f38c76-833a-4d98-a292-30e84c42e885',
 '7062ea57-6e10-4f4e-ae76-4581ff1ea3f2',
 '6ec62464-dba4-42bb-ab50-e7568b35888c',
 '6b5431cf-bd1e-4ed1-baa5-514cd635be64',
 '69f965ec-2b87-474e-9e3f-0155e60c6390',
 'ee129d6b-50d3-4dfd-96a5-6b2d533eb12a',
 'a511eefc-9c35-422c-8f8d-c441ff6479ff',
 'fc64b075-6acb-4020-ab82-9be2301cc86e',
 'b4a07638-841e-4d29-a6ef-113e26b8150b']

In [9]:
# Search similar documents
vector_store.similarity_search("How many orders with review_score > 4?")

[Document(id='6b5431cf-bd1e-4ed1-baa5-514cd635be64', metadata={'type': 'example', 'title': 'total_orders'}, page_content='Question: How many orders are there in total?\n```sql\nSELECT COUNT(*)\nFROM "ecommerce"."orders"\n```'),
 Document(id='f4f036a7-d32f-4e04-838b-5fd819400315', metadata={'type': 'example', 'title': 'total_orders'}, page_content='Question: How many orders are there in total?\n```sql\nSELECT COUNT(*)\nFROM "ecommerce"."orders"\n```'),
 Document(id='b1fdf0d1-0c4f-49df-90c4-263a1399a986', metadata={'type': 'example', 'title': 'orders_by_status'}, page_content='Question: What is the distribution of orders by status?\n```sql\nSELECT "order_status",\n       COUNT(*)\nFROM "ecommerce"."orders"\nGROUP BY "order_status"\n```'),
 Document(id='69f965ec-2b87-474e-9e3f-0155e60c6390', metadata={'type': 'example', 'title': 'orders_by_status'}, page_content='Question: What is the distribution of orders by status?\n```sql\nSELECT "order_status",\n       COUNT(*)\nFROM "ecommerce"."ord

In [10]:
# Search similar documents with score
vector_store.similarity_search_with_score(
    "How many orders with review_score > 4?", k=10
)

[(Document(id='f4f036a7-d32f-4e04-838b-5fd819400315', metadata={'type': 'example', 'title': 'total_orders'}, page_content='Question: How many orders are there in total?\n```sql\nSELECT COUNT(*)\nFROM "ecommerce"."orders"\n```'),
  0.2258285649367736),
 (Document(id='6b5431cf-bd1e-4ed1-baa5-514cd635be64', metadata={'type': 'example', 'title': 'total_orders'}, page_content='Question: How many orders are there in total?\n```sql\nSELECT COUNT(*)\nFROM "ecommerce"."orders"\n```'),
  0.2258285649367736),
 (Document(id='b1fdf0d1-0c4f-49df-90c4-263a1399a986', metadata={'type': 'example', 'title': 'orders_by_status'}, page_content='Question: What is the distribution of orders by status?\n```sql\nSELECT "order_status",\n       COUNT(*)\nFROM "ecommerce"."orders"\nGROUP BY "order_status"\n```'),
  0.23915316099841222),
 (Document(id='69f965ec-2b87-474e-9e3f-0155e60c6390', metadata={'type': 'example', 'title': 'orders_by_status'}, page_content='Question: What is the distribution of orders by statu