In [1]:
from llama_index.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent
)

from llama_index.query_engine.pandas import PandasInstructionParser
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate

In [3]:
import os
from dotenv import load_dotenv
load_dotenv()

True

In [4]:
import pandas as pd 

df = pd.read_csv("imdb.csv")
df

Unnamed: 0,Rank,Movie_name,Year,Certificate,Runtime_in_min,Genre,Rating_from_10
0,1,The Shawshank Redemption,1994,R,142,Drama,9.3
1,2,The Godfather,1972,R,175,"Crime, Drama",9.2
2,3,The Dark Knight,2008,PG-13,152,"Action, Crime, Drama",9.0
3,4,The Lord of the Rings: The Return of the King,2003,PG-13,201,"Action, Adventure, Drama",9.0
4,5,Schindler's List,1993,R,195,"Biography, Drama, History",9.0
...,...,...,...,...,...,...,...
995,996,Sabrina,1954,Passed,113,"Comedy, Drama, Romance",7.6
996,997,From Here to Eternity,1953,Passed,118,"Drama, Romance, War",7.6
997,998,Snow White and the Seven Dwarfs,1937,Approved,83,"Animation, Adventure, Family",7.6
998,999,The 39 Steps,1935,Approved,86,"Crime, Mystery, Thriller",7.6


In [5]:
instruction_str = (
    "1. Convert the query to executable Python code using Pandas.\n"
    "2. The final line of code should be a Python expression that can be called with the `eval()` function.\n"
    "3. The code should represent a solution to the query.\n"
    "4. PRINT ONLY THE EXPRESSION.\n"
    "5. Do not quote the expression.\n"
)

pandas_prompt_str = (
    "You are working with a pandas dataframe in Python.\n"
    "The name of the dataframe is `df`.\n"
    "This is the result of `print(df.head())`:\n"
    "{df_str}\n\n"
    "Follow these instructions:\n"
    "{instruction_str}\n"
    "Query: {query_str}\n\n"
    "Expression:"
)

response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n\n"
    "Pandas Instructions (optional):\n{pandas_instructions}\n\n"
    "Pandas Output: {pandas_output}\n\n"
    "Response: "
)

In [6]:
pandas_prompt = PromptTemplate(pandas_prompt_str).partial_format(
    instruction_str=instruction_str, df_str=df.head(5)
)

pandas_output_parser = PandasInstructionParser(df)
response_synthesis_prompt = PromptTemplate(response_synthesis_prompt_str)
llm = OpenAI(model="gpt-3.5-turbo")

#### Query Pipeline

In [7]:
qp = QP(
    modules={
        "input": InputComponent(),
        "pandas_prompt": pandas_prompt,
        "llm1": llm,
        "pandas_output_parser": pandas_output_parser,
        "response_synthesis_prompt": response_synthesis_prompt,
        "llm2": llm,
    },
    verbose=True,
)

qp.add_chain(["input", "pandas_prompt", "llm1", "pandas_output_parser"])
qp.add_links(
    [
        Link("input", "response_synthesis_prompt", dest_key="query_str"),
        Link("llm1", "response_synthesis_prompt", dest_key="pandas_instructions"),
        Link("pandas_output_parser", "response_synthesis_prompt", dest_key="pandas_output")
    ]
)
qp.add_link("response_synthesis_prompt", "llm2")

In [11]:
from pyvis.network import Network
import networkx as nx
from IPython.display import display, HTML


net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
html = net.generate_html()
with open("text2sql_dag.html", mode='w', encoding='utf-8') as fp:
        fp.write(html)

In [12]:
response = qp.run(
    query_str="What is the best rated Biography movie"
)

[1;3;38;2;155;135;227m> Running module input with input: 
query_str: What is the best rated Biography movie

[0m[1;3;38;2;155;135;227m> Running module pandas_prompt with input: 
query_str: What is the best rated Biography movie

[0m[1;3;38;2;155;135;227m> Running module llm1 with input: 
messages: You are working with a pandas dataframe in Python.
The name of the dataframe is `df`.
This is the result of `print(df.head())`:
  Rank                                     Movie_name  Year Certificate ...

[0m[1;3;38;2;155;135;227m> Running module pandas_output_parser with input: 
input: assistant: df[df['Genre'].str.contains('Biography')].nlargest(1, 'Rating_from_10')

[0m[1;3;38;2;155;135;227m> Running module response_synthesis_prompt with input: 
query_str: What is the best rated Biography movie
pandas_instructions: assistant: df[df['Genre'].str.contains('Biography')].nlargest(1, 'Rating_from_10')
pandas_output:   Rank        Movie_name  Year Certificate  Runtime_in_min  \
4    5  

In [13]:
print(response.message.content)

The best rated Biography movie is "Schindler's List" from 1993 with a rating of 9.0 out of 10.


#### Query Pipeline for Text-To-SQL

In [14]:
import pandas as pd
from pathlib import Path

data_dir = Path("")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []

for csv_file in csv_files:
    print(f"Processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

Processing file: documentaries.csv
Processing file: movies.csv
Processing file: sitcoms.csv


In [15]:
tableinfo_dir = "IMDb"
!mkdir {tableinfo_dir}

In [18]:
from llama_index.program import LLMTextCompletionProgram
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.llms import OpenAI


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise.
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-3.5-turbo"),
    prompt_template_str=prompt_str,
)

In [19]:
import json


def _get_tableinfo_with_index(idx: int) -> str:
    results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
    results_list = list(results_gen)
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        path = results_list[0]
        return TableInfo.parse_file(path)
    else:
        raise ValueError(
            f"More than one file matching index: {list(results_gen)}"
        )


table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        table_infos.append(table_info)
    else:
        while True:
            df_str = df.head(10).to_csv()
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            else:
                # try again
                print(f"Table name {table_name} already exists, trying again.")
                pass

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        json.dump(table_info.dict(), open(out_file, "w"))
    table_infos.append(table_info)

Processed table: Documentary_Movies
Processed table: Top_Rated_Movies
Processed table: Top_TV_Series


#### Putting data in sql database

In [20]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re

def sanitize_column_name(col_name):
    return re.sub(r"\W+", "_", col_name)

def create_table_from_dataframe(
        df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    table = Table(table_name, metadata_obj, *columns)

    metadata_obj.create_all(engine)

    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

Creating table: Documentary_Movies
Creating table: Top_Rated_Movies
Creating table: Top_TV_Series


#### Query-Time Table Retrieval

In [21]:
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema
)
from llama_index import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex
)

obj_retriever = obj_index.as_retriever(similarity_top_k=3)

#### SQLRetriever + Table Parser

In [22]:
from llama_index.retrievers import SQLRetriever
from typing import List
from llama_index.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)

def _get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string"""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context
        
        context_strs.append(table_info)
    return "\n\n".join(context_strs)

table_parser_component = FnComponent(fn=_get_table_context_str)

In [25]:
from llama_index.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.prompts import PromptTemplate
from llama_index.query_pipeline import FnComponent
from llama_index.llms import ChatResponse

def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL."""
    response = response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

sql_parser_component = FnComponent(fn=parse_response_to_sql)

text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)

print(text2sql_prompt.template)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 


In [26]:
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

In [27]:
llm = OpenAI(model="gpt-3.5-turbo")

In [28]:
from llama_index.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)

In [29]:
qp.add_chain(["input", "table_retriever", "table_output_parser"])
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [30]:
from pyvis.network import Network
import networkx as nx
from IPython.display import display, HTML


net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
html = net.generate_html()
with open("table_chain.html", mode='w', encoding='utf-8') as fp:
        fp.write(html)

In [32]:
response = qp.run(
    query="What is the best rated movie of the year 1994"
)
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: What is the best rated movie of the year 1994

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: What is the best rated movie of the year 1994

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='Top_Rated_Movies', context_str='Ranking, movie names, release year, certificate, runtime, genre, and ratings of top-rated movies.'), SQLTableSchema(table_name='Documentary_...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: What is the best rated movie of the year 1994
schema: Table 'Top_Rated_Movies' has columns: Rank (VARCHAR), Movie_name (VARCHAR), Year (VARCHAR), Certificate (VARCHAR), Runtime_in_min (INTEGER), Genre (VARCHAR), Rating_from_10 (INTEGER), and foreign keys...

[0m[1;3;38;2;155;135;227m> Running module text2sql_llm with input: 
messages: Given an input question, first creat