In [39]:
import pandas as pd
from pathlib import Path
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI

In [40]:
import os

base_path = "/Users/gourav/Documents/ProjectsRepos/AI/Metadata/Docs"
filepath = f"{base_path}/CHATGPT_API_KEY.txt"
with open(filepath, "r") as f:
  api_key = ' '.join(f.readlines())
  os.environ["OPENAI_API_KEY"] = api_key
  os.environ["GOOGLE_API_KEY"] = api_key

In [46]:
# import pandas as pd
# from pathlib import Path

# data_dir = Path("/Users/gourav/Documents/ProjectsRepos/AI/Metadata/Docs")
# csv_files = sorted([f for f in data_dir.glob("*.csv")])
# dfs = []
# for csv_file in csv_files:
#     print(f"processing file: {csv_file}")
#     try:
#         df = pd.read_csv(csv_file)
#         dfs.append(df)
#     except Exception as e:
#         print(f"Error parsing {csv_file}: {str(e)}")

# --------------

import pandas as pd
from pathlib import Path

data_dir = Path("/Users/gourav/Documents/ProjectsRepos/AI/Metadata/Docs")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []

# Define a function to read and process chunks
def process_chunks(chunk):
    # Append the chunk to the list of dataframes
    dfs.append(chunk)

# Read the first 20 entries from each CSV file in chunks
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        # Read the CSV file in chunks of 1000 rows
        chunk_iter = pd.read_csv(csv_file, chunksize=20)
        # Process each chunk
        for chunk in chunk_iter:
            process_chunks(chunk.head(20))  # Extract the first 20 entries from each chunk
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

# Concatenate all dataframes
final_df = pd.concat(dfs, ignore_index=True)

# Print the first few rows of the concatenated dataframe
# print(final_df.head())

processing file: /Users/gourav/Documents/ProjectsRepos/AI/Metadata/Docs/lei.csv
processing file: /Users/gourav/Documents/ProjectsRepos/AI/Metadata/Docs/relation__cityGroup.csv


In [47]:
tableinfo_dir = "/Users/gourav/Documents/ProjectsRepos/AI/Metadata/CSVTableInfo"
!mkdir {tableinfo_dir}

mkdir: /Users/gourav/Documents/ProjectsRepos/AI/Metadata/CSVTableInfo: File exists


In [48]:
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI


# class TableInfo(BaseModel):
#     """Information regarding a structured table."""

#     table_name: str = Field(
#         ..., description="table name (must be underscores and NO spaces)"
#     )
#     table_summary: str = Field(
#         ..., description="short, concise summary/caption of the table"
#     )
#     foreign_key: str = Field(
#         ..., description="Use the column that is common in tables and create foreign key mapping between two tables"
#     )

class TableInfo(BaseModel):
    """Information regarding a structured table."""
    
    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
    


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-3.5-turbo"),
    prompt_template_str=prompt_str,
)

In [49]:
import json

# Define a function to get table information with a given index
def _get_tableinfo_with_index(idx: int) -> str:
    # Use Pathlib to find all files with the given index pattern
    results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
    # Convert the generator to a list
    results_list = list(results_gen)
    # Check if there are any matching files
    if len(results_list) == 0:
        # Return None if no matching files are found
        return None
    elif len(results_list) == 1:
        # Return the table information if only one file is found
        path = results_list[0]
        return TableInfo.parse_file(path)
    else:
        # Raise an error if more than one file is found
        raise ValueError(
            f"More than one file matching index: {list(results_gen)}"
        )

# Initialize an empty set to store table names
table_names = set()
# Initialize an empty list to store table information
table_infos = []

# Iterate through each dataframe in the dfs list
for idx, df in enumerate(dfs):
    # Get the table information with the current index
    table_info = _get_tableinfo_with_index(idx)
    # Check if the table information is not None
    if table_info:
        # Append the table information to the list
        table_infos.append(table_info)
    else:
        # If the table information is None, enter an infinite loop
        while True:
            # Convert the first 10 rows of the dataframe to a CSV string
            df_str = df.head(10).to_csv()
            # Call the program function with the CSV string and a list of excluded table names
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            # Get the table name from the table information
            table_name = table_info.table_name
            # Print the processed table name
            print(f"Processed table: {table_name}")
            # Check if the table name is already in the set of table names
            if table_name not in table_names:
                # Add the table name to the set of table names
                table_names.add(table_name)
                # Break out of the infinite loop
                break
            else:
                # If the table name already exists, print a message and try again
                print(f"Table name {table_name} already exists, trying again.")
                pass

        # Create a file name with the index and table name
        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        # Dump the table information to a JSON file
        json.dump(table_info.dict(), open(out_file, "w"))
    # Append the table information to the list
    table_infos.append(table_info)

Processed table: Entity_Information_Table
Processed table: Relationship_Information_Table
Processed table: Ultimate_Consolidation_Relationship_Table
Processed table: Ultimate_Consolidation_Relationship_Table_Summary
Processed table: Ultimate_Consolidation_Relationship_Summary_Table
Processed table: Ultimate_Consolidation_Relationship_Summary
Processed table: Ultimate_Consolidation_Relationship_Summary
Table name Ultimate_Consolidation_Relationship_Summary already exists, trying again.
Processed table: Ultimate_Consolidation_Relationship_Summary_Table
Table name Ultimate_Consolidation_Relationship_Summary_Table already exists, trying again.
Processed table: Ultimate_Consolidation_Relationship_Summary_Table_2
Processed table: Ultimate_Consolidation_Relationship_Summary_Table_3
Processed table: Ultimate_Consolidation_Relationship_Summary_Table_4
Processed table: Ultimate_Consolidation_Relationship_Summary_Table_5
Processed table: Ultimate_Consolidation_Relationship_Summary_Table_6
Process

KeyboardInterrupt: 

In [38]:
# Import necessary modules from sqlalchemy
from sqlalchemy import (
    create_engine,  # Create a database engine
    MetaData,  # Create a metadata object
    Table,  # Create a table object
    Column,  # Create a column object
    String,  # Define a string data type
    Integer,  # Define an integer data type
)

import re  # Import the re module for regular expressions
import pandas as pd  # Import pandas for data manipulation
from tqdm import tqdm  # Import tqdm for progress bars

# Define a function to sanitize column names by removing special characters and replacing spaces with underscores
def sanitize_column_name(col_name):
    # Use regular expression to replace one or more non-alphanumeric characters with an underscore
    return re.sub(r"\W+", "_", col_name)

# Define a function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame,  # Input DataFrame
    table_name: str,  # Table name
    engine,  # Database engine
    metadata_obj  # Metadata object
):
    # Sanitize column names by replacing special characters and spaces with underscores
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        total_rows = len(df)
        with tqdm(total=total_rows, desc=f"Inserting data into {table_name}") as pbar:
            for _, row in df.iterrows():
                insert_stmt = table.insert().values(**row.to_dict())
                conn.execute(insert_stmt)
                pbar.update(1)

        conn.commit()

# Create a SQLite engine in memory
engine = create_engine("sqlite:///:memory:")

# Create a metadata object
metadata_obj = MetaData()

# Assuming dfs is a list of DataFrames
for idx, df in enumerate(dfs):
    # Get table information with the current index
    tableinfo = _get_tableinfo_with_index(idx)
    # Get the table name from the table information, or use a default name if tableinfo is None
    table_name = tableinfo.table_name if tableinfo else f"Table_{idx}"
    print(f"Creating table: {table_name}")
    # Create a table from the DataFrame
    create_table_from_dataframe(df, table_name, engine, metadata_obj)
    # Print the shape of the DataFrame
    print(f"DataFrame shape for {table_name}: {df.shape}")

Creating table: LegalEntityInfo


Inserting data into LegalEntityInfo: 100%|████████| 1/1 [00:00<00:00,  4.26it/s]


DataFrame shape for LegalEntityInfo: (1, 338)
Creating table: RelationshipInfo


Inserting data into RelationshipInfo: 100%|███| 15/15 [00:00<00:00, 1400.68it/s]

DataFrame shape for RelationshipInfo: (15, 54)





In [26]:
# Import necessary modules from llama_index
from llama_index.core.objects import (
    SQLTableNodeMapping,  # Import SQLTableNodeMapping class
    ObjectIndex,  # Import ObjectIndex class
    SQLTableSchema,  # Import SQLTableSchema class
)
from llama_index.core import SQLDatabase, VectorStoreIndex  # Import SQLDatabase and VectorStoreIndex classes

# Create a SQLDatabase object with the given engine
sql_database = SQLDatabase(engine)

# Create a SQLTableNodeMapping object with the SQLDatabase object
table_node_mapping = SQLTableNodeMapping(sql_database)

# Create the schema of the table from table infos
table_schema_objs = []

# Create a list of SQLTableSchema objects, one for each table in table_infos
# table_schema_objs = [
#     SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
#     for t in table_infos
# ]  # add a SQLTableSchema for each table

for idx in range(len(table_infos)):
    table_data = _get_tableinfo_with_index(idx)
    table_schema_objs.append(SQLTableSchema(table_name=table_data.table_name, context_str=table_data.table_summary))

# Create an ObjectIndex object from the list of SQLTableSchema objects, table_node_mapping, and VectorStoreIndex
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

# Create a retriever object from the ObjectIndex object
obj_retriever = obj_index.as_retriever()

In [36]:
table_schema_objs

[SQLTableSchema(table_name='LegalEntityInfo', context_str='Information about a legal entity including legal name, addresses, registration details, entity status, and associated entities.'),
 SQLTableSchema(table_name='LegalEntityInfo', context_str='Information about a legal entity including legal name, addresses, registration details, entity status, and associated entities.'),
 SQLTableSchema(table_name='RelationshipInfo', context_str='Information about relationships between nodes with various attributes.'),
 SQLTableSchema(table_name='RelationshipInfo', context_str='Information about relationships between nodes with various attributes.')]

In [27]:
# Import necessary modules from llama_index
from llama_index.core.retrievers import SQLRetriever  # Import SQLRetriever class
from typing import List  # Import List type hint
from llama_index.core.query_pipeline import FnComponent  # Import FnComponent class

# Create a SQLRetriever object with the sql_database object
sql_retriever = SQLRetriever(sql_database)

# Define a function to get table context strings
def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    # Initialize an empty list to store context strings
    context_strs = []
    # Iterate over each table schema object
    for table_schema_obj in table_schema_objs:
        # Get the table information from the sql_database
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        # Check if the table schema object has a context string
        if table_schema_obj.context_str:
            # If it does, append the context string to the table information
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        # Append the table information to the list of context strings
        context_strs.append(table_info)
    # Join the context strings with newline characters and return the result
    return "\n\n".join(context_strs)

# Create a FnComponent object with the get_table_context_str function
table_parser_component = FnComponent(fn=get_table_context_str)

In [14]:
# Import necessary modules from llama_index
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT  # Import DEFAULT_TEXT_TO_SQL_PROMPT constant
from llama_index.core import PromptTemplate  # Import PromptTemplate class
from llama_index.core.query_pipeline import FnComponent  # Import FnComponent class
from llama_index.core.llms import ChatResponse  # Import ChatResponse class

# Define a function to parse a response to SQL
def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL."""
    # Get the message content from the response
    response = response.message.content
    # Find the index of the substring "SQLQuery:"
    sql_query_start = response.find("SQLQuery:")
    # If the substring is found
    if sql_query_start!= -1:
        # Remove the substring from the response
        response = response[sql_query_start:]
        # Remove the prefix "SQLQuery:" from the response
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    # Find the index of the substring "SQLResult:"
    sql_result_start = response.find("SQLResult:")
    # If the substring is found
    if sql_result_start!= -1:
        # Remove the substring from the response
        response = response[:sql_result_start]
    # Remove any leading or trailing whitespace from the response
    # Remove any leading or trailing "```" characters from the response
    return response.strip().strip("```").strip()

# Create a FnComponent object with the parse_response_to_sql function
sql_parser_component = FnComponent(fn=parse_response_to_sql)

# Create a PromptTemplate object with the DEFAULT_TEXT_TO_SQL_PROMPT constant and the engine dialect
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
# Print the prompt template
print(text2sql_prompt.template)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 


In [15]:
# Define a string for the response synthesis prompt
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)

# Create a PromptTemplate object with the response_synthesis_prompt_str
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

In [16]:
llm = OpenAI(model="gpt-3.5-turbo")

In [17]:
# Import necessary modules from llama_index
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,  # Import QueryPipeline class
    Link,  # Import Link class
    InputComponent,  # Import InputComponent class
    CustomQueryComponent,  # Import CustomQueryComponent class
)

# Create a QueryPipeline object with the following modules:
qp = QP(
    modules={
        "input": InputComponent(),  # InputComponent module
        "table_retriever": obj_retriever,  # Object retriever module
        "table_output_parser": table_parser_component,  # Table output parser module
        "text2sql_prompt": text2sql_prompt,  # Text-to-SQL prompt module
        "text2sql_llm": llm,  # Text-to-SQL language model module
        "sql_output_parser": sql_parser_component,  # SQL output parser module
        "sql_retriever": sql_retriever,  # SQL retriever module
        "response_synthesis_prompt": response_synthesis_prompt,  # Response synthesis prompt module
        "response_synthesis_llm": llm,  # Response synthesis language model module
    },
    verbose=True,  # Enable verbose output
)

In [18]:
# Add a chain to the query pipeline with the following modules:
qp.add_chain(["input", "table_retriever", "table_output_parser"])
# Add a link from the input module to the text2sql_prompt module with the key "query_str"
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
# Add a link from the table_output_parser module to the text2sql_prompt module with the key "schema"
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
# Add a chain to the query pipeline with the following modules:
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
# Add a link from the sql_output_parser module to the response_synthesis_prompt module with the key "sql_query"
qp.add_link("sql_output_parser", "response_synthesis_prompt", dest_key="sql_query")
# Add a link from the sql_retriever module to the response_synthesis_prompt module with the key "context_str"
qp.add_link("sql_retriever", "response_synthesis_prompt", dest_key="context_str")
# Add a link from the input module to the response_synthesis_prompt module with the key "query_str"
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
# Add a link from the response_synthesis_prompt module to the response_synthesis_llm module
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [17]:
from pyvis.network import Network

net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)

In [18]:
# Save the network as "text2sql_dag.html"
net.write_html("text2sql_dag.html")

In [21]:
# Run the query pipeline with the following query:
response = qp.run(
    query="What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and it's Direct children and Ultimate children?"
)
# Print the response
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and it's Direct children and Ultimate children?

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and it's Direct children and Ultimate children?

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='LegalEntityInfo', context_str='Information about a legal entity including legal name, addresses, registration details, entity status, and associated entities.'), SQLTableSc...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and it's Direct children and Ultimate children?
schema: Table 'LegalEntityInfo' has columns: LEI (VARCHAR), Entity_LegalName (VARCHAR), Entity_LegalName_xmllang (VARCHAR), Entity_OtherEntityNames_OtherEntityName_

NotImplementedError: Statement "SELECT LEI, Entity_LegalName, Entity_LegalAddress_City, Entity_LegalAddress_Country, Entity_EntityStatus\nFROM LegalEntityInfo\nWHERE LEI = '6SHGI4ZSSLCXXQSBB395'\nORDER BY LEI;\nSELECT LEI, Entity_LegalName, Entity_LegalAddress_City, Entity_LegalAddress_Country, Entity_EntityStatus\nFROM LegalEntityInfo\nWHERE LEI IN (\n    SELECT Relationship_EndNode_NodeID\n    FROM RelationshipInfo\n    WHERE Relationship_StartNode_NodeID = '6SHGI4ZSSLCXXQSBB395'\n    AND Relationship_RelationshipType = 'Direct'\n)\nORDER BY LEI;\nSELECT LEI, Entity_LegalName, Entity_LegalAddress_City, Entity_LegalAddress_Country, Entity_EntityStatus\nFROM LegalEntityInfo\nWHERE LEI IN (\n    SELECT Relationship_EndNode_NodeID\n    FROM RelationshipInfo\n    WHERE Relationship_StartNode_NodeID = '6SHGI4ZSSLCXXQSBB395'\n    AND Relationship_RelationshipType = 'Ultimate'\n)\nORDER BY LEI;" is invalid SQL.

In [21]:
from llama_index.core import VectorStoreIndex, load_index_from_storage
from sqlalchemy import text
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext
import os
from pathlib import Path
from typing import Dict


def index_all_tables(
    sql_database: SQLDatabase, table_index_dir: str = "table_index_dir"
) -> Dict[str, VectorStoreIndex]:
    """Index all tables."""
    if not Path(table_index_dir).exists():
        os.makedirs(table_index_dir)

    vector_index_dict = {}
    engine = sql_database.engine
    for table_name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {table_name}")
        if not os.path.exists(f"{table_index_dir}/{table_name}"):
            # get all rows from table
            with engine.connect() as conn:
                cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
                result = cursor.fetchall()
                row_tups = []
                for row in result:
                    row_tups.append(tuple(row))

            # index each row, put into vector store index
            nodes = [TextNode(text=str(t)) for t in row_tups]

            # put into vector store index (use OpenAIEmbeddings by default)
            index = VectorStoreIndex(nodes)

            # save index
            index.set_index_id("vector_index")
            index.storage_context.persist(f"{table_index_dir}/{table_name}")
        else:
            # rebuild storage context
            storage_context = StorageContext.from_defaults(
                persist_dir=f"{table_index_dir}/{table_name}"
            )
            # load index
            index = load_index_from_storage(
                storage_context, index_id="vector_index"
            )
        vector_index_dict[table_name] = index

    return vector_index_dict


vector_index_dict = index_all_tables(sql_database)

Indexing rows in table: LegalEntityInfo
Indexing rows in table: RelationshipInfo


In [22]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)


def get_table_context_and_rows_str(
    query_str: str, table_schema_objs: List[SQLTableSchema]
):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        # first append table info + additional context
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        # also lookup vector index to return relevant table rows
        vector_retriever = vector_index_dict[
            table_schema_obj.table_name
        ].as_retriever(similarity_top_k=2)
        relevant_nodes = vector_retriever.retrieve(query_str)
        if len(relevant_nodes) > 0:
            table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
            for node in relevant_nodes:
                table_row_context += str(node.get_content()) + "\n"
            table_info += table_row_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_and_rows_str)

In [23]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)


In [24]:
qp.add_link("input", "table_retriever")
qp.add_link("input", "table_output_parser", dest_key="query_str")
qp.add_link(
    "table_retriever", "table_output_parser", dest_key="table_schema_objs"
)
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [25]:
response = qp.run(
    query="What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and its related legal entity?"
)
print(str(response))

[1;3;38;2;155;135;227m> Running module input with input: 
query: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and its related legal entity?

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and its related legal entity?

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
query_str: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and its related legal entity?
table_schema_objs: [SQLTableSchema(table_name='LegalEntityInfo', context_str='Information about a legal entity including legal name, addresses, registration details, entity status, and associated entities.'), SQLTableSc...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: What are the details of the LEI 6SHGI4ZSSLCXXQSBB395 and its related legal entity?
schema: Table 'LegalEntityInfo' has columns: LEI (VARCHAR), Entity_LegalName (VARCHAR), Entity_LegalName_xmllang (VARCHAR), 

NotImplementedError: Statement "SELECT LEI, Entity_LegalName, Entity_LegalJurisdiction, Entity_EntityCategory, Entity_EntityStatus\nFROM LegalEntityInfo\nWHERE LEI = '6SHGI4ZSSLCXXQSBB395'\nUNION\nSELECT Relationship_EndNode_NodeID, Relationship_RelationshipType, Relationship_RelationshipStatus\nFROM RelationshipInfo\nWHERE Relationship_StartNode_NodeID = '6SHGI4ZSSLCXXQSBB395'\nORDER BY Relationship_RelationshipType;" is invalid SQL.