# Settings

In [1]:
import os

# Step up one directory when the working directory path ends in
# "notebooks", then report the directory in effect.
if os.getcwd().endswith("notebooks"):
    os.chdir(os.pardir)
print(os.getcwd())

/Users/cmcoutosilva/Projects/github/nl2sql-agent


In [2]:
import yaml
from sqlalchemy import create_engine, inspect

## Database connection

In [3]:
# Connection string for the database to reflect. The NL2SQL_DB_URI
# environment variable takes precedence so that real credentials never have
# to be written into the notebook; the fallback is the local development URI.
db_uri = os.environ.get(
    "NL2SQL_DB_URI",
    "postgresql+psycopg://postgres:postgres@localhost:5432/olist_ecommerce",
)

# The inspector is the reflection entry point used by the cells below.
engine = create_engine(db_uri)
inspector = inspect(engine)

# Data Dictionary

In [4]:
from pydantic import BaseModel

In [5]:
# Which table to reflect, and the schema that holds it
table, schema = "orders", "ecommerce"

# Reflect the table's column definitions
columns = inspector.get_columns(table, schema=schema)

# Reflect its primary key constraint
primary_keys = inspector.get_pk_constraint(table, schema=schema)

# Reflect its foreign key constraints
foreign_keys = inspector.get_foreign_keys(table, schema=schema)

In [6]:
class ColumnInfo(BaseModel):
    """Metadata describing a single column of a database table."""

    name: str
    description: str
    type: str
    is_primary_key: bool
    is_nullable: bool
    foreign_keys: list[dict]

    @staticmethod
    def _extract_type(column: dict) -> str:
        """Return the column's type rendered as a string.

        The value stored under ``column["type"]`` is converted with ``str``.
        When that conversion yields ``"NULL"`` the placeholder
        ``"USER-DEFINED"`` is returned instead.
        """
        rendered = str(column["type"])
        if rendered == "NULL":
            return "USER-DEFINED"
        return rendered


class TableInfo(BaseModel):
    """Information about a database table.

    Instances are built from SQLAlchemy reflection data by ``from_inspector``
    and rendered as plain text by ``format_context``.
    """

    name: str
    schema_name: str
    description: str
    primary_keys: list[str]
    foreign_keys: list[dict]
    columns: list[ColumnInfo]

    @classmethod
    def from_inspector(cls, inspector, table, schema) -> "TableInfo":
        """Create TableInfo from SQLAlchemy inspector.

        Args:
            inspector: Object providing ``get_columns``, ``get_pk_constraint``,
                ``get_foreign_keys`` and ``get_table_comment``.
            table: Name of the table to reflect.
            schema: Name of the schema passed to every inspector call.

        Returns:
            A TableInfo holding the table comment, its primary key columns,
            its foreign keys and one ColumnInfo per reflected column.
        """
        columns = inspector.get_columns(table, schema=schema)
        primary_keys = inspector.get_pk_constraint(table, schema=schema)
        foreign_keys = inspector.get_foreign_keys(table, schema=schema)

        # Resolve the primary key columns once instead of once per column.
        pk_columns = list(primary_keys.get("constrained_columns") or [])

        # Normalise every foreign key up front so that the column lists are
        # always real lists, even when the reflected value is missing or None.
        fk_records = [
            {
                "constrained_columns": list(fk.get("constrained_columns") or []),
                "referred_table": fk.get("referred_table"),
                "referred_schema": fk.get("referred_schema"),
                "referred_columns": list(fk.get("referred_columns") or []),
                "name": fk.get("name"),
            }
            for fk in fk_records_source(foreign_keys)
        ] if False else [
            {
                "constrained_columns": list(fk.get("constrained_columns") or []),
                "referred_table": fk.get("referred_table"),
                "referred_schema": fk.get("referred_schema"),
                "referred_columns": list(fk.get("referred_columns") or []),
                "name": fk.get("name"),
            }
            for fk in foreign_keys
        ]

        column_info = [
            ColumnInfo(
                name=column["name"],
                description=column.get("comment", "") or "",
                type=ColumnInfo._extract_type(column),
                is_primary_key=column["name"] in pk_columns,
                is_nullable=column["nullable"],
                # Per-column view: only the foreign keys that constrain this
                # column, without repeating the constrained column list.
                foreign_keys=[
                    {
                        "referred_table": fk["referred_table"],
                        "referred_schema": fk["referred_schema"],
                        "referred_columns": fk["referred_columns"],
                        "name": fk["name"],
                    }
                    for fk in fk_records
                    if column["name"] in fk["constrained_columns"]
                ],
            )
            for column in columns
        ]

        # Dialects without comment support raise NotImplementedError here;
        # treat that the same as a table that simply has no comment.
        try:
            comment = inspector.get_table_comment(table, schema=schema)
        except NotImplementedError:
            comment = None
        description = (comment or {}).get("text") or ""

        return cls(
            name=table,
            schema_name=schema,
            description=description,
            primary_keys=pk_columns,
            foreign_keys=fk_records,
            columns=column_info,
        )

    def format_context(self) -> str:
        """Format table information as a string for context retrieval.

        Returns:
            Newline-terminated text with the table name, then the description,
            primary keys and foreign keys when present, then a ``COLUMNS:``
            section listing only the columns that have a description.
        """
        lines = [f"TABLE: {self.name}"]
        if self.description:
            lines.append(f"DESCRIPTION: {self.description}")

        if self.primary_keys:
            lines.append(f"PRIMARY KEYS: {', '.join(self.primary_keys)}")

        if self.foreign_keys:
            lines.append("FOREIGN KEYS:")
            for fk in self.foreign_keys:
                constrained = ", ".join(fk.get("constrained_columns") or [])
                referred = ", ".join(fk.get("referred_columns") or [])
                # The referred schema may be None; leave absent parts out
                # rather than printing the literal text "None".
                target = ".".join(
                    part
                    for part in (
                        fk.get("referred_schema"),
                        fk.get("referred_table"),
                        referred,
                    )
                    if part
                )
                lines.append(f"  - {constrained} -> {target}")

        lines.append("COLUMNS:")
        for column in self.columns:
            # Columns without a description are intentionally left out.
            if not column.description:
                continue
            nullability = "NULL" if column.is_nullable else "NOT NULL"
            lines.append(
                f"  - {column.name} ({column.type}, {nullability}): "
                f"{column.description}"
            )

        # Every line, including the last one, ends with a newline.
        return "\n".join(lines) + "\n"


# Reflect the target table and show the context text generated for it
table_info = TableInfo.from_inspector(inspector, table=table, schema=schema)
print(table_info.format_context())

TABLE: orders
DESCRIPTION: This is the core dataset. From each order you might find all other information.
PRIMARY KEYS: order_id
FOREIGN KEYS:
  - customer_id -> ecommerce.customers.customer_id
COLUMNS:
  - order_id (TEXT, NOT NULL): unique identifier of the order.
  - customer_id (TEXT, NOT NULL): key to the customer dataset. Each order has a unique customer_id.
  - order_status (TEXT, NULL): Reference to the order status (delivered, shipped, etc).
  - order_purchase_timestamp (TIMESTAMP, NULL): Shows the purchase timestamp.
  - order_approved_at (TIMESTAMP, NULL): Shows the payment approval timestamp.
  - order_delivered_carrier_date (TIMESTAMP, NULL): Shows the order posting timestamp. When it was handled to the logistic partner.
  - order_delivered_customer_date (TIMESTAMP, NULL): Shows the actual order delivery date to the customer.
  - order_estimated_delivery_date (TIMESTAMP, NULL): Shows the estimated delivery date that was informed to customer at the purchase moment.



In [7]:
class SchemaInfo(BaseModel):
    """A named database schema together with the tables it contains."""

    name: str
    tables: dict[str, TableInfo]

    def format_context(self) -> str:
        """Render the schema header followed by each table's context text."""
        header = f"SCHEMA: {self.name}\n\n"
        # Each table section is followed by two newlines.
        table_sections = [
            f"{entry.format_context()}\n\n" for entry in self.tables.values()
        ]
        return header + "".join(table_sections)


class DatabaseInfo(BaseModel):
    """A named database together with the schemas it contains."""

    name: str
    schemas: dict[str, SchemaInfo]

    def format_context(self) -> str:
        """Render the database header followed by each schema's context text."""
        parts = [f"DATABASE: {self.name}\n\n"]
        parts.extend(entry.format_context() for entry in self.schemas.values())
        return "".join(parts)

In [8]:
class DataDictionary(BaseModel):
    """Top-level container mapping database names to their DatabaseInfo."""

    databases: dict[str, DatabaseInfo]

    @classmethod
    def from_inspector(
        cls,
        inspector,
        database_schema,
    ) -> "DataDictionary":
        """Build a DataDictionary by reflecting every listed table.

        ``database_schema`` is walked as a mapping of database name to a
        mapping of schema name to an iterable of table names; each table is
        reflected through ``TableInfo.from_inspector`` with the given inspector.
        """

        def build_schema(schema_name, tables) -> SchemaInfo:
            # Reflect the schema's tables, keyed by table name.
            reflected = {
                table_name: TableInfo.from_inspector(
                    inspector, table_name, schema_name
                )
                for table_name in tables
            }
            return SchemaInfo(name=schema_name, tables=reflected)

        def build_database(database_name, schemas) -> DatabaseInfo:
            # Build every schema of the database, keyed by schema name.
            built = {
                schema_name: build_schema(schema_name, tables)
                for schema_name, tables in schemas.items()
            }
            return DatabaseInfo(name=database_name, schemas=built)

        return cls(
            databases={
                database_name: build_database(database_name, schemas)
                for database_name, schemas in database_schema.items()
            }
        )

    def format_context(self) -> str:
        """Concatenate the context text of every database, in stored order."""
        return "".join(
            entry.format_context() for entry in self.databases.values()
        )

In [9]:
# Load the database -> schema -> tables layout that drives reflection.
# The encoding is pinned so the file is decoded identically on every
# platform instead of depending on the locale default.
with open("configs/schema.yml", encoding="utf-8") as f:
    database_schema = yaml.safe_load(f)

# Reflect everything listed in the config and show the generated context.
data_dictionary = DataDictionary.from_inspector(inspector, database_schema)
print(data_dictionary.format_context())

DATABASE: olist_ecommerce

SCHEMA: ecommerce

TABLE: customers
DESCRIPTION: This dataset has information about the customer and its location. Use it to identify unique customers in the orders dataset and to find the orders delivery location. At our system each order is assigned to a unique customer_id. This means that the same customer will get different ids for different orders. The purpose of having a customer_unique_id on the dataset is to allow you to identify customers that made repurchases at the store. Otherwise you would find that each order had a different customer associated with.
PRIMARY KEYS: customer_id
COLUMNS:
  - customer_id (TEXT, NOT NULL): key to the orders dataset. Each order has a unique customer_id.
  - customer_unique_id (TEXT, NOT NULL): unique identifier of a customer.
  - customer_zip_code_prefix (TEXT, NULL): first five digits of customer zip code
  - customer_city (TEXT, NULL): customer city name
  - customer_state (TEXT, NULL): customer state


TABLE: geolo