# Text-to-SQL Using LlamaIndex

In [10]:
import os
import openai

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import (
    create_engine, 
    MetaData, 
    Table, 
    Column, 
    Integer, 
    String, 
    select, 
    insert, 
    text)
from IPython.display import display, Markdown

In [2]:
import getpass
import os

# Ask for the OpenAI key only when it is not already configured, so the
# notebook can also run with the key supplied through the environment.
# The explicit prompt replaces getpass's default "Password: " text.
if not os.environ.get('OPENAI_API_KEY'):
    os.environ['OPENAI_API_KEY'] = getpass.getpass("OpenAI API key: ")

In [3]:
# In-memory SQLite engine plus a MetaData registry that collects the
# table definitions created in the next cell.
metadata_obj = MetaData()
engine = create_engine("sqlite:///:memory:")

In [4]:
# Demo table: one row per city, keyed by city name.
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

# Issue CREATE TABLE for everything registered on metadata_obj.
metadata_obj.create_all(engine)

### Define SQL Database

In [7]:
# Instantiate OpenAI LLM (low temperature, 0.1, to keep outputs focused)
llm = OpenAI(temperature=0.1, model="gpt-4")
# Wrap the engine for LlamaIndex, exposing only the city_stats table
sql_database = SQLDatabase(engine, include_tables=['city_stats'])

# Sample rows for the demo table
rows = [
    {
        "city_name": "Toronto",
        "population": 2731571,
        "country": "Canada"
    },
    {
        "city_name": "New York",
        "population": 8336817,
        "country": "USA"
    },
    {
        "city_name": "Los Angeles",
        "population": 3979576,
        "country": "USA"
    }
]

# Insert every row in one transaction (executemany): a single commit, and
# nothing is persisted if any row fails, instead of one transaction per row.
with engine.begin() as conn:
    conn.execute(insert(city_stats_table), rows)

In [8]:
# View Current Table: selecting the Table object pulls all of its columns
# (city_name, population, country) in definition order.
stmt = select(city_stats_table)

with engine.connect() as conn:
    results = conn.execute(stmt).fetchall()
    print(results)

[('Toronto', 2731571, 'Canada'), ('New York', 8336817, 'USA'), ('Los Angeles', 3979576, 'USA')]


In [11]:
# Same data, queried through a raw SQL string instead of a SQLAlchemy construct.
with engine.connect() as con:
    rows = con.execute(text("SELECT city_name from city_stats"))
    for row in rows:
        # Each result row holds only the city name (printed as a 1-tuple).
        print(row)

('Los Angeles',)
('New York',)
('Toronto',)


In [1]:
import pandas as pd
from pathlib import Path

data_dir = Path("learning-llamaindex/WikiTableQuestions/csv/200-csv")
# Path objects sort directly; no intermediate list comprehension needed.
csv_files = sorted(data_dir.glob("*.csv"))
if not csv_files:
    # glob() on a missing/empty directory is silent, so say so explicitly.
    print(f"Warning: no CSV files found in {data_dir}")

dfs = []
for csv_file in csv_files:
    print(f"Processing File: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        # Best effort: report and skip files pandas cannot parse.
        # NOTE: a skipped file shifts the positions of later entries in `dfs`
        # relative to `csv_files`.
        print(f"Error Parsing {csv_file}: {e}")
    else:
        dfs.append(df)

## Extract Table Name and Summary from each Table

In [3]:
import os

# Directory that caches one JSON TableInfo file per table.
tableinfo_dir = "WikiTableQuestions_TableInfo"
# Plain Python instead of the IPython `!mkdir` shell escape; exist_ok makes
# re-running the cell safe when the directory already exists.
os.makedirs(tableinfo_dir, exist_ok=True)

In [2]:
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI

class TableInfo(BaseModel):
    """
    Information regarding a structured table.
    """
    # NOTE(review): the class docstring and the Field descriptions become part of
    # the model's JSON schema, which is presumably what the LLM program sees as
    # the output format — confirm before rewording any of them.
    # Table identifier requested from the LLM; also used in the cache file
    # name and as the SQL table name later in the notebook.
    table_name: str = Field(
        ..., description="table name (must be underscore and NO spaces)"
    )
    # Short caption requested from the LLM; used as the table's context string
    # when building the schema index.
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )

# Prompt asking the LLM for a unique, descriptive table name plus a summary.
# {exclude_table_name_list} and {table_str} are filled in when `program`
# is called with those keyword arguments.
prompt_str = """\
Give me a summary of the table with the following JSON format.
- The table name must be unique to the table and describe it while being concise.
- Do NOT output a generic table name (e.g. table, my_table)

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

# Structured-output program: completes the prompt with GPT-4 and parses the
# reply into a TableInfo instance.
program = LLMTextCompletionProgram.from_defaults(
    prompt_template_str=prompt_str,
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-4"),
)

In [4]:
import json
from typing import Optional


def _get_tableinfo_with_index(idx: int) -> Optional["TableInfo"]:
    """Load the cached TableInfo for table number ``idx``, if present.

    Looks in ``tableinfo_dir`` for a file whose name starts with ``{idx}_``
    (cache files are written as ``{idx}_{table_name}.json``).

    Returns:
        The parsed TableInfo, or None when no cached file exists.

    Raises:
        ValueError: if more than one file matches the index.
    """
    # Materialize once: a glob generator can only be consumed a single time.
    matches = list(Path(tableinfo_dir).glob(f"{idx}_*"))
    if not matches:
        return None
    if len(matches) > 1:
        raise ValueError(f"More than one file matching index: {matches}")
    return TableInfo.parse_file(matches[0])

table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    # Reuse the summary cached on disk by a previous run, if any.
    table_info = _get_tableinfo_with_index(idx)
    if table_info is not None:
        print(f"Processed table: {table_info.table_name}")
        # Reserve cached names too, so newly generated names cannot collide.
        table_names.add(table_info.table_name)
    else:
        # The LLM sees only the first 10 rows; this does not change between
        # retries, so compute it once.
        df_str = df.head(10).to_csv()
        # Keep asking until the LLM proposes a name not used yet.
        while True:
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            print(f"Table name {table_name} already exists, trying again.")
        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        # Close the cache file deterministically instead of leaking the handle.
        with open(out_file, 'w') as f:
            json.dump(table_info.dict(), f)
    # Exactly one entry per DataFrame, whether cached or freshly generated.
    table_infos.append(table_info)

## Put Data in SQL Database

In [5]:
from sqlalchemy import (create_engine, MetaData, Table, Column, String, Integer)
import re

def sanitize_column_name(col_name):
    """Return ``col_name`` with every run of non-word characters replaced by "_".

    "Non-word" means anything other than a letter, digit or underscore, so
    spaces, punctuation and symbols all collapse to a single underscore,
    e.g. "Total pop." -> "Total_pop_".
    """
    # Splitting on the separator runs and re-joining with "_" is equivalent
    # to substituting each run with "_".
    return "_".join(re.split(r"\W+", col_name))

def _sqlalchemy_type_for(dtype):
    """Map a pandas dtype to the SQLAlchemy column type used to store it."""
    # Float is not in this cell's sqlalchemy import list.
    from sqlalchemy import Float

    if pd.api.types.is_bool_dtype(dtype) or pd.api.types.is_integer_dtype(dtype):
        return Integer
    if pd.api.types.is_float_dtype(dtype):
        return Float
    # object, string, datetime, categorical, ... are stored as text.
    return String


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame,
    table_name: str,
    engine,
    metadata_obj
):
    """Create ``table_name`` with ``df``'s columns and insert all of its rows.

    Column names are passed through ``sanitize_column_name``. Integer and
    boolean columns become ``Integer``, float columns ``Float``, and all
    other dtypes ``String``.

    Args:
        df: Source data.
        table_name: Name of the SQL table to create.
        engine: SQLAlchemy engine the table is created in.
        metadata_obj: MetaData registry the new Table is attached to.
    """
    # Sanitize Column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # One Column per DataFrame column, typed from its pandas dtype.
    columns = [
        Column(col, _sqlalchemy_type_for(dtype))
        for col, dtype in zip(df.columns, df.dtypes)
    ]
    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # to_dict(orient="records") yields native Python values, which the DB
    # driver can bind; insert them in a single executemany transaction.
    records = df.to_dict(orient="records")
    if records:  # an empty parameter list would not be an executemany
        with engine.begin() as conn:
            conn.execute(table.insert(), records)

# Fresh in-memory database holding one table per loaded CSV.
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    # Table names come from the TableInfo files cached in tableinfo_dir.
    tableinfo = _get_tableinfo_with_index(idx)
    if tableinfo is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            f"No cached table info for table index {idx} in {tableinfo_dir!r}; "
            "run the table-summary cell first."
        )
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

In [6]:
# Setup Arize Phoenix for logging/observability
import phoenix as px
import llama_index.core

# Start the local Phoenix app (the cell output reports its URL).
px.launch_app()
# Register "arize_phoenix" as LlamaIndex's global handler so subsequent
# LlamaIndex activity is reported to Phoenix.
llama_index.core.set_global_handler("arize_phoenix")

🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📺 To view the Phoenix app in a notebook, run `px.active_session().view()`
📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix


In [10]:
# Open a view of the running Phoenix session inside the notebook.
px.active_session().view()

📺 Opening a view to the Phoenix app. The app is running at http://localhost:6006/


## Text-to-SQL with Query-Time Table Retrieval

- Object index + retriever over the table schemas
- `SQLDatabase` object connected to the tables above, plus an `SQLRetriever`
- Text-to-SQL prompt
- Response-synthesis prompt
- LLM

In [9]:
from llama_index.core.objects import (SQLTableNodeMapping, ObjectIndex, SQLTableSchema)
from llama_index.core import SQLDatabase, VectorStoreIndex

# No include_tables filter: every table in the engine is exposed.
sql_database = SQLDatabase(engine)

# Node mapping ObjectIndex uses to store the table-schema objects.
table_node_mapping = SQLTableNodeMapping(sql_database)

# One schema object per table; the LLM-generated summary is its context.
table_schema_objs = [
    SQLTableSchema(table_name=info.table_name, context_str=info.table_summary)
    for info in table_infos
]

# Vector index over the schemas; the retriever returns the 3 closest tables.
obj_index = ObjectIndex.from_objects(
    table_schema_objs, table_node_mapping, VectorStoreIndex
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)

### SQLRetriever + Table Parser

In [None]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

# Bind the retriever to the SQLDatabase wrapping the tables created above
# (there is no `sql` name in scope in this notebook).
sql_retriever = SQLRetriever(sql_database)

def get_table_context_str(table_schema_objs: List[SQLTAble Schema])