In [1]:
import pandas as pd
from pathlib import Path

data_dir = Path("./CSV-data")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        df.rename(columns=lambda x: x.replace(".", "_"), inplace=True)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

processing file: CSV-data/chunk-data.csv
processing file: CSV-data/relation__cityGroup.csv


In [3]:
import openai
import os
api_key_path = "./../apiKey.txt"

with open(api_key_path, "r") as f:
  api_key = ' '.join(f.readlines())
  openai.api_key = api_key

os.environ['OPENAI_API_KEY'] = openai.api_key

In [8]:
tableinfo_dir = "LEI_TableInfo"
!mkdir {tableinfo_dir}

### Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers, quantifiers, registration information, and validation sources.\n\nFollowing is the Foreign key ralationship\n- FOREIGN KEY Relationship_StartNode_NodeID REFERENCES EntityInformation(LEI),\n- FOREIGN KEY Relationship_EndNode_NodeID REFERENCES EntityInformation(LEI).\n\nRelationship_RelationshipType defines the relationship between Relationship_StartNode_NodeID and Relationship_EndNode_NodeID.\nFollowing are the avilable relations type:\n- 'IS_ULTIMATELY_CONSOLIDATED_BY' defines ultimate child,\n- IS_DIRECTLY_CONSOLIDATED_BY defines direct child\n- IS_FUND-MANAGED_BY defines fund managed by

In [5]:
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-3.5-turbo"),
    prompt_template_str=prompt_str,
)

In [29]:
import json


def _get_tableinfo_with_index(idx: int) -> str:
    results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
    results_list = list(results_gen)
    print(results_list)
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        path = results_list[0]
        return TableInfo.parse_file(path)
    else:
        raise ValueError(
            f"More than one file matching index: {list(results_gen)}"
        )

In [30]:
table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    print(table_info)
    if table_info:
        table_infos.append(table_info)
    else:
        while True:
            df_str = df.head(10).to_csv()
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            else:
                # try again
                print(f"Table name {table_name} already exists, trying again.")
                pass

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        json.dump(table_info.dict(), open(out_file, "w"))
        table_infos.append(table_info)

[PosixPath('LEI_TableInfo/0_Entity_Legal_Info.json')]
table_name='Entity_Legal_Info' table_summary='Summary of legal information for various entities including legal names, addresses, registration details, entity status, creation dates, and conformity flags. LEI is the primary key in Entity_Legal_Info table.'
[PosixPath('LEI_TableInfo/1_Relationship_Info_Table.json')]
table_name='Relationship_Info_Table' table_summary="Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers, quantifiers, registration information, and validation sources.\n\nFollowing is the Foreign key ralationship\n- FOREIGN KEY Relationship_StartNode_NodeID REFERENCES EntityInformation(LEI),\n- FOREIGN KEY Relationship_EndNode_NodeID REFERENCES EntityInformation(LEI).\n\nRelationship_RelationshipType defines the relationship between Relationship_StartNode_NodeID and Relationship_EndNode_NodeID.\nFollowing are the avilable relations type:\n- 

In [31]:
table_infos

[TableInfo(table_name='Entity_Legal_Info', table_summary='Summary of legal information for various entities including legal names, addresses, registration details, entity status, creation dates, and conformity flags. LEI is the primary key in Entity_Legal_Info table.'),
 TableInfo(table_name='Relationship_Info_Table', table_summary="Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers, quantifiers, registration information, and validation sources.\n\nFollowing is the Foreign key ralationship\n- FOREIGN KEY Relationship_StartNode_NodeID REFERENCES EntityInformation(LEI),\n- FOREIGN KEY Relationship_EndNode_NodeID REFERENCES EntityInformation(LEI).\n\nRelationship_RelationshipType defines the relationship between Relationship_StartNode_NodeID and Relationship_EndNode_NodeID.\nFollowing are the avilable relations type:\n- 'IS_ULTIMATELY_CONSOLIDATED_BY' defines ultimate child,\n- IS_DIRECTLY_CONSOLIDATED_BY d

In [11]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re

In [12]:
# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)

In [32]:
# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()

In [33]:
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

[PosixPath('LEI_TableInfo/0_Entity_Legal_Info.json')]
Creating table: Entity_Legal_Info
[PosixPath('LEI_TableInfo/1_Relationship_Info_Table.json')]
Creating table: Relationship_Info_Table


In [34]:
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine)
table_schema_objs = []
table_node_mapping = SQLTableNodeMapping(sql_database)
    
# table_schema_objs = [
#     SQLTableSchema(table_name=t[tindex].table_name, context_str=_get_tableinfo_with_index(tindex).table_summary)
#     for tindex, t in table_infos
# ]  # add a SQLTableSchema for each table

for idx in range(len(table_infos)):
   table_data = _get_tableinfo_with_index(idx)
   table_schema_objs.append(SQLTableSchema(table_name=table_data.table_name, context_str=table_data.table_summary))


obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=2)

[PosixPath('LEI_TableInfo/0_Entity_Legal_Info.json')]
[PosixPath('LEI_TableInfo/1_Relationship_Info_Table.json')]


In [17]:
table_schema_objs

[SQLTableSchema(table_name='Entity_Legal_Info', context_str='Summary of legal information for various entities including legal names, addresses, registration details, entity status, creation dates, and conformity flags. LEI is the primary key in Entity_Legal_Info table.'),
 SQLTableSchema(table_name='Relationship_Info_Table', context_str="Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers, quantifiers, registration information, and validation sources.\n\nFollowing is the Foreign key ralationship\n- FOREIGN KEY Relationship_EndNode_NodeID REFERENCES EntityInformation(LEI).\n\nRelationship_RelationshipType defines the relationship between Relationship_StartNode_NodeID and Relationship_EndNode_NodeID.\nFollowing are the available relations type:\n- 'IS_ULTIMATELY_CONSOLIDATED_BY' defines ultimate child,\n- IS_DIRECTLY_CONSOLIDATED_BY defines direct child\n- IS_FUND-MANAGED_BY defines fund managed by.")]

In [35]:
custom_txt2sql_prompt = """Given an input question, construct a syntactically correct {dialect} SQL query to run, then look at the results of the query and return a comprehensive and detailed answer. Ensure that you:
            - Select only the relevant columns needed to answer the question.
            - Use correct column and table names as provided in the schema description. Avoid querying for columns that do not exist.
            - Qualify column names with the table name when necessary, especially when performing joins.
            - Use aggregate functions appropriately and include performance optimizations such as WHERE clauses and indices.
            - Add additional related information for the user.
            - Use background & definitions provided for more detailed answer. Follow the instructions.
            - Avoid hallucination. If you can't find an answer, say I'm not sure.
            
         
             Special Instructions:
            - Treat "province" and "region" as interchangeable terms in your queries.
            - Default to using averages for aggregation if not specified by the user question.
            - Recognize both short and full forms of Canadian provinces.
            - If the question involves a KPI not listed below, inform the user by showing the list of available KPIs.
            - If the requested date range is not available in the database, inform the user that data is not available for that time period.
            - Use bold and large fonts to highlight keywords in your answer.
            - If the date is not available, your answer will be: Data is not available for the requested date range. Please modify your query to include a valid date range.
            - Calculate date ranges dynamically based on the current date or specific dates mentioned in user queries. Use relative time expressions such as "last month" or "past year".
            - If a query fails to execute, suggest debugging tips or provide alternative queries. Ensure to handle common SQL errors gracefully."
            - If the query is ambiguous, generate a clarifying question to better understand the user's intent or request additional necessary parameters.
            - Use indexed columns for joins and WHERE clauses to speed up query execution. Use `EXPLAIN` plans for complex queries to ensure optimal performance.

            Additional Instructions:
            - Encourage users to provide specific date ranges or intervals for more accurate results.
            - Mention the importance of specifying provinces or regions for targeted analysis.
            - Provide examples of common SQL syntax errors and how to correct them.
            - Offer guidance on interpreting query results, including outliers or unexpected patterns.
            - Emphasize the significance of data integrity and potential implications of incomplete or inaccurate data.
            - Inform users that data exists from July 2023 for the list of KPIs.

            You are required to use the following format, each taking one line:
            Question: Question here
            SQLQuery: SQL Query to run
            SQLResult: Result of the SQLQuery
            Answer: Final answer here
            
            Use these real examples for complex queries:

            Example 1:
            Question: What market had the highest average VoLTE accessibility last month?
            SQLQuery: SELECT site_info.market as market, AVG(daily_kpis.volte_accessibility) AS avg_volte_accessibility FROM daily_kpis JOIN site_info ON daily_kpis.site = site_info.site WHERE daily_kpis.date_id >= (current_date - INTERVAL '1 month') GROUP BY market ORDER BY avg_volte_accessibility DESC LIMIT 1;
            SQLResult:
            market         | avg_volte_accessibility
            --------------------------
            quebec | 80.98
            Answer: The market with the highest average VoLTE accessibility last month was the Quebec with 80.98 volte accessibility.
			
            Example 2:
            Question: Provide the monthly average VoLTE delay over the past year for Quebec.
            SQLQuery: SELECT EXTRACT(YEAR FROM daily_kpis.date_id) AS year, EXTRACT(MONTH FROM daily_kpis.date_id) AS month, AVG(daily_kpis.volte_delay) AS avg_volte_delay FROM seed_data.daily_kpis JOIN seed_data.site_info ON daily_kpis.site = site_info.site WHERE site_info.province = 'QC' AND EXTRACT(YEAR FROM daily_kpis.date_id) = EXTRACT(YEAR FROM current_date) - 1 GROUP BY year, month ORDER BY year, month;
            SQLResult:
            year | month | avg_volte_delay
            -----|-------|---------
            2023	7	 x
			.....

            Answer: The average VoLTE delay in Quebec showed varying values throughout the year with July having an average of x ms and December an average of y ms.


            Only use tables listed below.
            {schema}

            Question: {query_str}
            SQLQuery:
        """

In [36]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_str)

In [37]:
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatResponse


def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL."""
    response = response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


sql_parser_component = FnComponent(fn=parse_response_to_sql)

text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name     # sqlite
)
text2sql_prompt.template = custom_txt2sql_prompt
print(text2sql_prompt.template)

Given an input question, construct a syntactically correct {dialect} SQL query to run, then look at the results of the query and return a comprehensive and detailed answer. Ensure that you:
            - Select only the relevant columns needed to answer the question.
            - Use correct column and table names as provided in the schema description. Avoid querying for columns that do not exist.
            - Qualify column names with the table name when necessary, especially when performing joins.
            - Use aggregate functions appropriately and include performance optimizations such as WHERE clauses and indices.
            - Add additional related information for the user.
            - Use background & definitions provided for more detailed answer. Follow the instructions.
            - Avoid hallucination. If you can't find an answer, say I'm not sure.
            
         
             Special Instructions:
            - Treat "province" and "region" as interchangeable

In [38]:
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

In [22]:
llm=OpenAI(model="gpt-3.5-turbo")

In [39]:
from llama_index.core import VectorStoreIndex, load_index_from_storage
from sqlalchemy import text
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext
import os
from pathlib import Path
from typing import Dict


def index_all_tables(
    sql_database: SQLDatabase, table_index_dir: str = "table_index_dir"
) -> Dict[str, VectorStoreIndex]:
    """Index all tables."""
    if not Path(table_index_dir).exists():
        os.makedirs(table_index_dir)

    vector_index_dict = {}
    engine = sql_database.engine
    for table_name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {table_name}")
        if not os.path.exists(f"{table_index_dir}/{table_name}"):
            # get all rows from table
            with engine.connect() as conn:
                cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
                result = cursor.fetchall()
                row_tups = []
                for row in result:
                    row_tups.append(tuple(row))

            # index each row, put into vector store index
            nodes = [TextNode(text=str(t)) for t in row_tups]

            # put into vector store index (use OpenAIEmbeddings by default)
            index = VectorStoreIndex(nodes)

            # save index
            index.set_index_id("vector_index")
            index.storage_context.persist(f"{table_index_dir}/{table_name}")
        else:
            # rebuild storage context
            storage_context = StorageContext.from_defaults(
                persist_dir=f"{table_index_dir}/{table_name}"
            )
            # load index
            index = load_index_from_storage(
                storage_context, index_id="vector_index"
            )
        vector_index_dict[table_name] = index

    return vector_index_dict


vector_index_dict = index_all_tables(sql_database)

Indexing rows in table: Entity_Legal_Info
Indexing rows in table: Relationship_Info_Table


In [40]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)


def get_table_context_and_rows_str(
    query_str: str, table_schema_objs: List[SQLTableSchema]
):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        # first append table info + additional context
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        # also lookup vector index to return relevant table rows
        vector_retriever = vector_index_dict[
            table_schema_obj.table_name
        ].as_retriever(similarity_top_k=2)
        relevant_nodes = vector_retriever.retrieve(query_str)
        if len(relevant_nodes) > 0:
            table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
            for node in relevant_nodes:
                table_row_context += str(node.get_content()) + "\n"
            table_info += table_row_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_and_rows_str)

In [41]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)

In [42]:
qp.add_link("input", "table_retriever")
qp.add_link("input", "table_output_parser", dest_key="query_str")
qp.add_link(
    "table_retriever", "table_output_parser", dest_key="table_schema_objs"
)
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [27]:
response = qp.run(
    query="What is the LEI for CITYGROUP?"
)
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: What is the LEI for CITYGROUP?

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: What is the LEI for CITYGROUP?

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
query_str: What is the LEI for CITYGROUP?
table_schema_objs: [SQLTableSchema(table_name='Entity_Legal_Info', context_str='Summary of legal information for various entities including legal names, addresses, registration details, entity status, creation dates, an...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: What is the LEI for CITYGROUP?
schema: Table 'Entity_Legal_Info' has columns: Unnamed_0 (INTEGER), LEI (VARCHAR), Entity_LegalName (VARCHAR), Entity_LegalName_xmllang (VARCHAR), Entity_OtherEntityNames_OtherEntityName_1 (VARCHAR), Entity_O...

[0m[1;3;38;2;155;135;227m> Running module text2sql_llm with input: 
messages: Given an input question, first create a

In [43]:
response = qp.run(
    query="How many relation are direct child of CITIGROUP INC?"
)
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: How many relation are direct child of CITIGROUP INC?

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: How many relation are direct child of CITIGROUP INC?

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
query_str: How many relation are direct child of CITIGROUP INC?
table_schema_objs: [SQLTableSchema(table_name='Relationship_Info_Table', context_str="Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: How many relation are direct child of CITIGROUP INC?
schema: Table 'Relationship_Info_Table' has columns: Unnamed_0 (INTEGER), Relationship_StartNode_NodeID (VARCHAR), Relationship_StartNode_NodeIDType (VARCHAR), Relationship_EndNode_NodeID (VARCHAR), Relations...

[0m[1;3;38;2;155;135;227m> Ru

In [44]:
response = qp.run(
    query="How many relation are ultimate child of CITIGROUP INC?"
)
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: How many relation are ultimate child of CITIGROUP INC?

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: How many relation are ultimate child of CITIGROUP INC?

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
query_str: How many relation are ultimate child of CITIGROUP INC?
table_schema_objs: [SQLTableSchema(table_name='Relationship_Info_Table', context_str="Table containing relationship details between nodes, including start and end node IDs, relationship type, status, periods, qualifiers...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: How many relation are ultimate child of CITIGROUP INC?
schema: Table 'Relationship_Info_Table' has columns: Unnamed_0 (INTEGER), Relationship_StartNode_NodeID (VARCHAR), Relationship_StartNode_NodeIDType (VARCHAR), Relationship_EndNode_NodeID (VARCHAR), Relations...

[0m[1;3;38;2;155;135;