https://docs.llamaindex.ai/en/stable/examples/pipeline/query_pipeline_sql.html

In [1]:
import pandas as pd
from pathlib import Path

data_dir = Path("./WikiTableQuestions/csv/204-csv")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")


processing file: WikiTableQuestions/csv/204-csv/0.csv
processing file: WikiTableQuestions/csv/204-csv/1.csv
processing file: WikiTableQuestions/csv/204-csv/10.csv
processing file: WikiTableQuestions/csv/204-csv/100.csv
processing file: WikiTableQuestions/csv/204-csv/101.csv
processing file: WikiTableQuestions/csv/204-csv/102.csv
processing file: WikiTableQuestions/csv/204-csv/103.csv
processing file: WikiTableQuestions/csv/204-csv/104.csv
processing file: WikiTableQuestions/csv/204-csv/105.csv
processing file: WikiTableQuestions/csv/204-csv/106.csv
processing file: WikiTableQuestions/csv/204-csv/107.csv
processing file: WikiTableQuestions/csv/204-csv/108.csv
processing file: WikiTableQuestions/csv/204-csv/109.csv
processing file: WikiTableQuestions/csv/204-csv/11.csv
processing file: WikiTableQuestions/csv/204-csv/110.csv
processing file: WikiTableQuestions/csv/204-csv/111.csv
processing file: WikiTableQuestions/csv/204-csv/112.csv
processing file: WikiTableQuestions/csv/204-csv/113.cs

In [2]:
tableinfo_dir = "WikiTableQuestions_TableInfo"
!mkdir {tableinfo_dir}

mkdir: WikiTableQuestions_TableInfo: File exists


In [3]:
print(df.head(5))

                                          Contestant Original\nTribe  \
0                        Branka Čudanov\n28, Kikinda        Ga 'dang   
1                       Gordana Berger\n38, Belgrade          Manobo   
2                           Ana Mitrić\n23, Belgrade        Ga 'dang   
3                      Milena Vitanović\n21, Paraćin        Ga 'dang   
4  Nikola Kovačević\nReturned to game from Ghost ...        Ga 'dang   

  First\nSwitch Second\nSwitch Merged\nTribe                 Finish  \
0           NaN            NaN           NaN   2nd Voted Out\nDay 7   
1           NaN            NaN           NaN   1st Voted Out\nDay 4   
2           NaN            NaN           NaN  3rd Voted Out\nDay 10   
3           NaN            NaN           NaN  4th Voted Out\nDay 13   
4           NaN            NaN           NaN  5th Voted Out\nDay 16   

                 Ghost\nIsland  Total\nVotes  
0        1st Eliminated\nDay 9            10  
1       2nd Eliminated\nDay 12             9  

In [4]:
from typing import List
from pydantic import BaseModel, Field
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
#from llama_index.llms.openai import OpenAI
from llama_index.llms.ollama import Ollama
from llama_index.core import Settings
Settings.llm = Ollama(model="llama2", request_timeout=30.0)

class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=Ollama(model="llama2", request_timeout=30.0),
    prompt_template_str=prompt_str,
)


In [5]:
print(tableinfo_dir)
import os
os.getcwd()
results_gen = Path(tableinfo_dir).glob(f"{0}_*")
results_list = list(results_gen)
print(results_list)

WikiTableQuestions_TableInfo
[PosixPath('WikiTableQuestions_TableInfo/0_Polls.json')]


In [6]:
import json


def _get_tableinfo_with_index(idx: int) -> str:
    results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
    results_list = list(results_gen)
    if len(results_list) == 0:
        return None  # type: ignore
    elif len(results_list) == 1:
        path = results_list[0]
        return TableInfo.parse_file(path) # type: ignore
    else:
        raise ValueError(
            f"More than one file matching index: {list(results_gen)}"
        )


table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        table_infos.append(table_info)
    else:
        while True:
            df_str = df.head(10).to_csv()
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            #table_name = TableInfo.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            else:
                # try again
                print(f"Table name {table_name} already exists, trying again.")
                pass

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        json.dump(table_info.dict(), open(out_file, "w"))
    table_infos.append(table_info)

ValidationError: 2 validation errors for TableInfo
table_name
  field required (type=value_error.missing)
table_summary
  field required (type=value_error.missing)

In [None]:
# put data into sqlite db
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re


# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()


engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

Advanced Capability 1: Text-to-SQL with Query-Time Table Retrieval.
We now show you how to setup an e2e text-to-SQL with table retrieval.

Define Modules
Here we define the core modules.

Object index + retriever to store table schemas

SQLDatabase object to connect to the above tables + SQLRetriever.

Text-to-SQL Prompt

Response synthesis Prompt

LLM

Object index, retriever, SQLDatabase

In [None]:
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)

SQLRetriever + Table Parser

In [None]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_str)

Text-to-SQL Prompt + Output Parser

In [None]:
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.query_pipeline import FnComponent
from llama_index.core.llms import ChatResponse


def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL."""
    response = response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        # TODO: move to removeprefix after Python 3.9+
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


sql_parser_component = FnComponent(fn=parse_response_to_sql)

text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)

Response Synthesis Prompt

In [None]:
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

In [None]:
llm = OpenAI(model="gpt-3.5-turbo")

Define Query Pipeline
Now that the components are in place, let’s define the query pipeline!

In [None]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)

In [None]:
qp.add_chain(["input", "table_retriever", "table_output_parser"])
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

Visualize Query Pipeline
A really nice property of the query pipeline syntax is you can easily visualize it in a graph via networkx.

In [None]:
from pyvis.network import Network

net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
net.show("text2sql_dag.html")

Run Some Queries!
Now we’re ready to run some queries across this entire pipeline

In [None]:
response = qp.run(
    query="What was the year that The Notorious B.I.G was signed to Bad Boy?"
)
print(str(response))

In [None]:
response = qp.run(query="Who won best director in the 1972 academy awards")
print(str(response))

In [None]:
response = qp.run(query="What was the term of Pasquale Preziosa?")
print(str(response))