In [1]:
from langchain.prompts import PromptTemplate,ChatPromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser,StrOutputParser

from langchain import hub
from typing_extensions import TypedDict
from typing import List
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from pprint import pprint
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.memory import ChatMessageHistory
from langchain_core.runnables import RunnablePassthrough

from sentence_transformers import SentenceTransformer
from langchain_community.embeddings import HuggingFaceEmbeddings
# Sentence-embedding model used to vectorise documents for the Chroma store.
hf = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")

import os
from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
import sqlite3
import os
import pandas as pd
import re



import sys
sys.path.insert(0, '../')

from xgb_process import shap_summary,xgb_model
from utils import generate_text_summary,create_context

#### Load the Telco Churn Data

In [3]:
### Load the churn table, train the model and build the SHAP summary table
# Build the path with os.path.join so the cell also runs on Linux/macOS
# (hard-coded "\\" separators only resolve on Windows).
df0 = pd.read_csv(os.path.join("..", "data", "telco_churn_data.csv"), index_col=False)
df0.columns = [x.lower() for x in df0.columns]
cat_cols = ['gender', 'seniorcitizen', 'partner', 'dependents', 'phoneservice', 'multiplelines', 'internetservice',
            'onlinesecurity', 'onlinebackup', 'deviceprotection', 'techsupport', 'streamingtv', 'streamingmovies',
            'contract', 'paperlessbilling', 'paymentmethod']
num_cols = ['tenure', 'monthlycharges', 'totalcharges']
target = 'churn'

# Ensure the target variable and all feature names in the lists are exactly as in the DataFrame
if set(cat_cols + num_cols + [target]).issubset(df0.columns):
    print("All required columns are in the DataFrame.")
else:
    missing_cols = set(cat_cols + num_cols + [target]) - set(df0.columns)
    print("Missing columns:", missing_cols)

# Encode the target as 1 for 'Yes' and 0 for anything else
df0[target] = df0[target].apply(lambda x: 1 if x == 'Yes' else 0)
# Convert categorical columns to 'category' dtype
df0[cat_cols] = df0[cat_cols].astype('category')

# Convert numeric columns to a numeric dtype; unparseable values become NaN
for col in num_cols:
    df0[col] = pd.to_numeric(df0[col], errors='coerce')
df0['customerid'] = df0['customerid'].astype('category')

# Initialize and use the model
model_instance = xgb_model.XGBoostModel(df=df0, cat_features=cat_cols, num_features=num_cols, target=target, mode='dev')
model, dtrain, X_train, dtest, X_test = model_instance.train_model()

analyzer = shap_summary.ShapAnalyzer(model=model,
                                     X_train=X_train,
                                     dtrain=dtrain,
                                     cat_features=cat_cols,
                                     num_features=num_cols)

result_df = analyzer.analyze_shap_values()
summary_df = analyzer.summarize_shap_df()

# Convert object types to category and numeric types to their respective numeric types
for col in summary_df.columns:
    if summary_df[col].dtype == 'object':
        summary_df[col] = summary_df[col].astype('category')
    elif pd.api.types.is_numeric_dtype(summary_df[col]):
        summary_df[col] = pd.to_numeric(summary_df[col])

All required columns are in the DataFrame.
ROC AUC Score: 0.84


#### Class to Add a Table to the Database and to Delete Tables

In [4]:
class AddTabletoDatabase:
    """Small helper that stores pandas DataFrames as tables in a SQLite file.

    The connection is opened in ``__init__`` and kept on ``self.conn`` for the
    lifetime of the object.
    """

    def __init__(self, db_name, db_path='.'):
        """Open (or create) the SQLite file ``db_name`` inside ``db_path``.

        Args:
            db_name (str): File name of the SQLite database.
            db_path (str): Directory that contains the database file.
        """
        self.db_path = os.path.join(db_path, db_name)
        self.conn = sqlite3.connect(self.db_path)
        print(f"Database {db_name} connected at {db_path}")

    def add_table(self, df, table_name):
        """Write ``df`` to ``table_name``, replacing any existing table.

        Columns with ``category`` or ``object`` dtype are declared as TEXT;
        every other column is declared as REAL.

        Args:
            df (pd.DataFrame): Data to store.
            table_name (str): Name of the destination table.
        """
        dtype_dict = {
            col: 'TEXT' if ((df[col].dtype == 'category') | (df[col].dtype == 'object')) else 'REAL'
            for col in df.columns
        }
        df.to_sql(table_name, self.conn, if_exists='replace', index=False, dtype=dtype_dict)
        print(f"Table {table_name} Added.")

    def delete_table(self, table_name):
        """Drop ``table_name`` if it exists.

        Args:
            table_name (str): Name of the table to remove.

        Returns:
            None when the table was dropped, otherwise the string
            ``"No such table: <table_name>"``.
        """
        cur = self.conn.cursor()
        # Bind the name as a parameter instead of formatting it into the SQL text.
        cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?;",
            (table_name,),
        )
        if cur.fetchone():
            # Identifiers cannot be bound as parameters, so quote the name
            # (doubling embedded quotes) to support names with spaces/quotes.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cur.execute(f"DROP TABLE {quoted_name};")
            self.conn.commit()
            print(f"Table {table_name} deleted.")
        else:
            return f"No such table: {table_name}"

In [5]:
# Connect to the project database and (re)publish the SHAP summary table.
db = AddTabletoDatabase(db_name='my_database.db', db_path='../sqldatabase')
db.add_table(df=summary_df, table_name='shap_summary')
# Optional maintenance calls, kept for reference:
# db.add_table(df0, 'telco_data')
# db.delete_table('contract_info')

Database my_database.db connected at ../sqldatabase
Table shap_summary Added.


#### Setting Up the Vector DataBase

We will use the vector database only for the data dictionary and the data summary. Any information about the model's SHAP values will be retrieved with the SQL retriever.

In [6]:
# Source documents that back the vector store.
text = [
    "..//documents//data_dictionary.txt",
    "..//documents//data_summary.txt",
]

# Load every file (one list of Documents per file), then flatten into one list.
docs = list(map(lambda path: TextLoader(path).load(), text))
docs_list = sum(docs, [])

# Token-based splitter: 400-token chunks with a 100-token overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=400, chunk_overlap=100)
doc_splits = text_splitter.split_documents(docs_list)

# Embed the chunks and add them to the Chroma collection.
vectorstore = Chroma.from_documents(
    embedding=hf,
    collection_name="churn-rag-chroma-3",
    documents=doc_splits,
)

#### Util Functions Needed for the agent 

Some utility functions are required to get the schema, execute queries, format documents, etc.
They are defined below.

In [7]:
def format_docs(docs):
    """
    Join the text of several documents into one string.

    Args:
        docs (list): Document-like objects, each exposing a 'page_content'
                     attribute.

    Returns:
        str: The 'page_content' of every document, in order, separated by a
             blank line (two newline characters).
    """
    contents = [document.page_content for document in docs]
    separator = "\n\n"
    return separator.join(contents)


def get_all_schemas(db_name="my_database.db", db_path="../sqldatabase"):
    """
    Retrieves all tables and their schemas from a SQLite database.

    For every table the output lists each column's name, declared type and
    nullability. For TEXT columns with between 1 and 19 distinct values the
    distinct values are listed; for REAL columns the mean value is listed.

    Args:
        db_name (str): The name of the SQLite database file.
        db_path (str): The directory path where the SQLite database file is located.

    Returns:
        str: A string representation of all tables and their schemas in the database.
    """

    def _quote(identifier):
        # Double-quote an identifier (doubling embedded quotes) so table or
        # column names containing spaces or quotes remain valid SQL.
        return '"' + identifier.replace('"', '""') + '"'

    db_path = os.path.join(db_path, db_name)
    conn = sqlite3.connect(db_path)
    # try/finally guarantees the connection is released even if a query fails.
    try:
        cur = conn.cursor()

        # Get all tables
        cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0] for table in cur.fetchall()]

        output = "Tables:\n"
        output += "\n".join(tables)
        output += "\n\nSchemas:\n"

        for table in tables:
            quoted_table = _quote(table)

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            cur.execute(f"PRAGMA table_info({quoted_table});")
            schema = [
                (column[1], column[2], 'NULLABLE' if column[3] == 0 else 'NOT NULL')
                for column in cur.fetchall()
            ]

            output += f"{table}:\n"
            for column in schema:
                output += f"  Name: {column[0]}, Type: {column[1]}, Mode: {column[2]}\n"

            output += "\nDistinct Values for Categorical Variables:\n"
            for column_name, column_type, _ in schema:
                quoted_column = _quote(column_name)
                if column_type == 'TEXT':
                    cur.execute(f'SELECT DISTINCT {quoted_column} FROM {quoted_table};')
                    values = [row[0] for row in cur.fetchall()]
                    # Only list low-cardinality columns to keep the prompt short.
                    if 0 < len(values) < 20:
                        output += f"  {column_name}: {values}\n"
                elif column_type == 'REAL':
                    cur.execute(f'SELECT AVG({quoted_column}) FROM {quoted_table};')
                    mean_value = cur.fetchone()[0]
                    output += f"  Mean value of `{column_name}`: {mean_value}\n"  # Enclose column name in backticks

        return output
    finally:
        conn.close()


# def get_schema(_):
#     """
#     Wrapper function for get_all_schemas with no arguments.

#     Returns:
#         str: A string representation of all tables and their schemas in the database.
#     """
#     return get_all_schemas()

# def get_messages(_):
#     """
#     Retrieves and concatenates all messages from the global 'hist' object.

#     Returns:
#         str: A string containing all messages, separated by three newline characters.
#     """
#     messages = [message.content for message in hist.messages]
#     concatenated_messages = '\n\n\n '.join(messages)
#     return concatenated_messages

# def execute_query_with_retries(my_query,database='my_database.db', max_attempts=5):
#     """
#     Attempts to execute a SQL query on a SQLite database with a specified number of retries.

#     Args:
#         my_query (str): The SQL query to execute.
#         database (str): The name of the SQLite database file.
#         max_attempts (int): The maximum number of attempts to execute the query.

#     Returns:
#         DataFrame, str: A pandas DataFrame containing the query results and the cleaned SQL query.
#         If the query fails, returns None and the cleaned SQL query.
#     """

#     conn = sqlite3.connect(database)
#     cur = conn.cursor()
#     attempts = 0
#     while attempts < max_attempts:
#         attempts += 1
#         print(f"Attempt {attempts} of {max_attempts}")
#         # Invoke the external SQL service
#         print("Generating the SQL")
#         # cur.execute(my_query)
#         #result = cur.fetchall()
#         # answer_generation=sql_generator_rag_chain.invoke({"question": my_query,"schema": agent.get_all_schemas(),"messages":hist.messages})
#         answer_generation=sql_response.invoke({"question": my_query})
#         clean_sql = extract_sql(answer_generation)
#         print(f"SQL Query: {clean_sql}")
#         if clean_sql=='None':
#             print("No query logic found for this question from database")
#         else:
#             try:
#                 print("Attempting to run the query and convert it to a DataFrame")
#                 dataframe = agent.execute_query(clean_sql)
#                 if dataframe.shape[0]==0:
#                     print("Empty DataFrame returned")
#                     hist.add_user_message(clean_sql + ': ' + "Empty DataFrame returned. If using any ID filter make sure it matches the users input")
#                     if attempts == max_attempts:
#                         print("Reached maximum attempt limit. Stopping retries.")
#                         return [],clean_sql  # Return None if all retries fail
#                 else:
#                     print("Query executed successfully.")
#                     return dataframe,clean_sql
#             except Exception as e:
#                 # Print or store the error message
#                 error_message = str(e)
#                 print("Query failed with the following error:")
#                 print(error_message)
#                 hist.add_user_message(clean_sql + ': ' + error_message)
#                 if attempts == max_attempts:
#                     print("Reached maximum attempt limit. Stopping retries.")
#                     return None  # Return None if all retries fail
#                     return [],clean_sql
#     conn.close()

def execute_query(query,database="my_database.db",db_path="../sqldatabase"):
    """
    Run a SQL statement against a SQLite database and return the rows as a DataFrame.

    Args:
        query (str): The SQL statement to execute.
        database (str): The name of the SQLite database file.
        db_path (str): The directory path where the SQLite database file is located.

    Returns:
        pd.DataFrame: One row per result row, with columns named after the
        result columns. An empty DataFrame is returned when the statement
        produces no result set (e.g. DDL), instead of failing on a missing
        cursor description.

    Raises:
        sqlite3.Error: If the statement cannot be executed.
    """
    db_path = os.path.join(db_path, database)
    conn = sqlite3.connect(db_path)
    # Always release the connection, including when the query raises.
    try:
        cur = conn.cursor()
        cur.execute(query)
        # description is None for statements that do not return rows.
        if cur.description is None:
            return pd.DataFrame()
        rows = cur.fetchall()
        column_names = [column[0] for column in cur.description]
        return pd.DataFrame(rows, columns=column_names)
    finally:
        conn.close()

def extract_sql(input_text):
    """
    Extracts SQL code from a string.

    The function looks for SQL code enclosed in triple backticks and returns
    the content of the first fenced block. An optional language tag directly
    after the opening fence (``sql``, ``sqlite`` or ``sqlite3``, any case) is
    dropped so the result is directly executable. If no complete fenced block
    is found, the input string is returned with surrounding whitespace removed.

    Args:
        input_text (str): The input string.

    Returns:
        str: The extracted SQL code, or the stripped input string if no SQL
        code block is found.
    """
    if '```' in input_text:
        # `\b` keeps identifiers such as `sqlite_master` from being mistaken
        # for a language tag; the capture group is the fenced content.
        pattern = re.compile(
            r'```\s*(?:(?:sqlite3?|sql)\b)?(.*?)```',
            re.DOTALL | re.IGNORECASE,
        )
        match = pattern.search(input_text)
        if match:
            return match.group(1).strip()  # Return the cleaned, extracted SQL
    # No (complete) fenced block: return the input as is
    return input_text.strip()

#### Defining Components for the SQL Retriever LangGraph Agents

In [8]:
##Agent 1- Retrieval Grader
# Prompt asking the model to grade whether a retrieved document is relevant
# to the user question; the answer is requested as JSON with a single
# 'score' key ('yes' / 'no').
retrieval_prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {document} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no premable or explaination.""",
        input_variables=["question", "document"],
    )

# JSON-format, temperature-0 model shared by the grader chains in this notebook.
llm = ChatOllama(model="llama3:latest", format="json", temperature=0)
# prompt -> model -> JSON parser
retrieval_grader = retrieval_prompt | llm | JsonOutputParser()
# Similarity retriever over the Chroma store, returning the top 10 chunks.
retriever=vectorstore.as_retriever(search_type='similarity',search_kwargs={'k': 10})

In [9]:
##Agent 2 -  Business Analyst Agent
# Prompt that answers the question from the supplied context only.
buisiness_analyst_prompt = PromptTemplate(
    template=""" You are a business analyst for a telecom company. 
            Your job is to create a report for senior executives on questions asked. 
            Answer ONLY with stats you got from the context provided.
            DO NOT add any erronous stats.

            Context: {context}

            Question: {question}""",
    input_variables=["question", "context"],
)

# Free-text (non-JSON) model at temperature 0.
generator_llm = ChatOllama(model="llama3:latest",temperature=0)
# Chain: prompt -> model -> plain string
business_analyst_rag_chain = buisiness_analyst_prompt | generator_llm | StrOutputParser()

In [10]:
##Agent 3 - Hallucination Grader 

# Re-creates `llm` with the same JSON-format settings used for the retrieval grader.
llm = ChatOllama(model="llama3:latest", format="json", temperature=0)
# Prompt: grade whether {generation} is supported by {documents}; the answer
# is requested as JSON with a single 'score' key ('yes' / 'no').
hallucination_prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n 
    Here are the facts:
    \n ------- \n
    {documents} 
    \n ------- \n
    Here is the answer: {generation}
    Give a binary score 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "documents"],
)

# prompt -> JSON-format model -> JSON parser
hallucination_grader = hallucination_prompt | llm | JsonOutputParser()

In [11]:
##Agent 4 - Answer Grader 
# Prompt: grade whether {generation} is useful for resolving {question}; the
# answer is requested as JSON with a single 'score' key ('yes' / 'no').
answer_grader_prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is useful to resolve a question. \n 
    Here is the answer:
    \n ------- \n
    {generation} 
    \n ------- \n
    Here is the question: {question}
    Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "question"],
)

# Uses whatever `llm` is bound to when this cell runs; in top-to-bottom order
# that is the JSON-format model created in the hallucination-grader cell.
answer_grader = answer_grader_prompt | llm | JsonOutputParser()

In [12]:
##Agent 5 - Question Re-writer
# Prompt: rewrite {question} into a form better suited to vector-store
# retrieval and return only the rewritten question.
re_write_prompt = PromptTemplate(
    template="""You are a question re-writer that converts an input question to a better version that is optimized \n 
     for vectorstore retrieval. Look at the initial and formulate an improved question. \n
     Here is the initial question: \n\n {question}.
     ONLY RETURN THE REWRITTEN QUESTION.\n
     DONOT ADD ANY TEXT OTHER THAN REWRITTEN QUESTION.\n
     Improved question with no preamble: \n """,
    input_variables=["question"],
)
# Free-text model at temperature 0; chain output is a plain string.
rewrite_llm = ChatOllama(model="llama3:latest", temperature=0)
question_rewriter = re_write_prompt | rewrite_llm | StrOutputParser()

In [13]:
##Agent 6 - SQL Question Reframer

# Prompt: restate the user's {question} in terms of the columns described in
# {schema} (feature_group, feature_effect, probability_contribution,
# Feature_Importance_Rank), guided by three worked examples.
sql_qprompt = PromptTemplate(
    template="""Given the user's question and the database schema, your task is to reframe the question to better suit the data available.

    Remember:
        - 'feature_group' refers to different groups within a feature.
        - 'feature_effect' indicates if a feature increases or decreases the contribution towards the target.
        - 'probability_contribution' measures the feature group's contribution towards the target.
        - 'Feature_Importance_Rank' shows the importance of the feature in model predictions; a lower rank means higher importance.
        - The target outcome to analyze is 'churn'.

            Schema: {schema}

            Sample User Question: "What are the top N reasons that *prevent* Churn?"
            Expected Reframed Question: "Select top N feature groups along with feature based on the highest probability contribution and feature effect as decreases."

            Sample Question: "What are the top N reasons *for* Churn?"
            Expected Reframed Question: "Select top N feature groups along with feature based on the highest probability contribution and feature effect as increases."

            Sample User Question: "What are the top N features?"
            Expected Reframed Question: "Select top N distict features in the increasing order of feature importance rank."

            Always add feature_effect,probability_contribution and Feature_Importance_Rank in the final table if it was used.

            Please return ONLY the reframed question.

            User Question : {question}
    """,input_variables=["question", "schema"],
)


# Free-text model at temperature 0; chain output is the reframed question string.
sql_q_llm = ChatOllama(model="llama3:latest",temperature=0)
sql_question_chain = sql_qprompt | sql_q_llm | StrOutputParser()

In [14]:
## Agent 7 - SQL Query Generator
# Prompt: write a SQLite query for {sql_question} using only fields from
# {schema}, or reply None when no query is possible.
# NOTE(review): the template mentions "the message history", but only
# `sql_question` and `schema` are declared as input variables, so no history
# is actually passed to the model.
sql_query_prompt = PromptTemplate(
        template="""Based on the Sqlite database schema below, and the message history, write a
        SQL query that answers the question/request.
        DONOT add sql before the query generated. Return the query only.
        Only return a query if a possible query logic exists.
        Reply None if there is no possible query for given question from schema


        Question: {sql_question}

        Remember to UNNEST repeated records and make sure only to use exisiting fields in the schema:

        schema:{schema}

        **If question asks for Top N remember to return top N records as specified in the order as instructed**
        **Never sort Feature_Importance_Rank in DESC order**


        Query should include all the columns needed to answer the question.
        All columns used in filter and groups should be added in final table

        SQL Query:""",input_variables=["sql_question", "schema"],
    
)

# Free-text model at temperature 0.
sql_generator_llm = ChatOllama(model="llama3:latest",temperature=0)
# Earlier variant that injected schema and message history automatically (disabled):
#sql_query_chain = RunnablePassthrough.assign(schema=get_schema, messages=get_messages)|sql_query_prompt|sql_generator_llm|StrOutputParser()
# Active chain: callers must supply both `sql_question` and `schema`.
sql_query_chain = sql_query_prompt | sql_generator_llm | StrOutputParser()

In [15]:
## Agent 8 - SQL Query Validator

# Prompt: judge whether {query} is an appropriate SQL query for {question}
# given {schema}. The scoring instruction now asks about the query itself
# (it previously repeated the hallucination grader's "answer is grounded in
# facts" wording, which does not describe this task).
query_evaluator_prompt=PromptTemplate(
    template="""Your task is to assess a sql query and see if it is appropriate to answer the question. You will be provided the question, query and the schema of the database query will be ran on.


        schema:{schema}

        question : {question}

        query : {query}

        Give a binary score 'yes' or 'no' to indicate whether the query is appropriate to answer the question and uses only tables and columns that exist in the schema. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
                                        input_variables=["schema", "question","query"],
                                        )

# JSON-format model at temperature 0; chain output is the parsed JSON dict.
query_evaluator_llm = ChatOllama(model="llama3:latest", format="json", temperature=0)
query_evaluator_chain = query_evaluator_prompt|query_evaluator_llm|JsonOutputParser()
             

In [16]:
## Agent 9 - Feature Description

# Prompt: describe the features named in {question} using only the retrieved
# {context}, following the numbered guidelines in the template.
feature_description_prompt = PromptTemplate(
    template=""" You are an expert document researcher. Your job is to analyze multiple sources to provide detailed descriptions for the question passed.
      Each description should be concise, informative, and directly related to the values in question. Use the following guidelines for your research:

    1. **Accuracy**: Ensure that the information you provide is accurate and sourced from credible documents.
    2. **Relation**: For each feature provide information about different groups present
    3. Donot add any extra information unrelated to the request
    4. Donot provide any extra summary
    5. Add note if the feature belongs to a static, action or dynamic context. Just mention which one it belongs to

    Here is the question:

            question: {question}

    Here is the context:

            context: {context}
""",
    input_variables=["question","context"],
)

# Re-creates `generator_llm` with the same settings as in the business analyst cell.
generator_llm = ChatOllama(model="llama3:latest",temperature=0)
# Chain: prompt -> model -> plain string
feature_description_rag_chain = feature_description_prompt | generator_llm | StrOutputParser()

In [17]:
## Agent 10 - Manager Agent
# Prompt: turn the analyst report passed as {context} into a formatted,
# non-technical report answering {question}, using only the supplied
# information.
manager_prompt = PromptTemplate(
    template=""" You are manager of data science team for a telecom company. 
                You get different reports on churn from your business analyst in your team.
                Your job is to format and proof read the report provided to you by the business analyst and create a enriched report for senior executive team.
                The end user might not be tech savvy, so keep the language easy to understand
                Answer should be properly formatted and user friendly and easy to understand.

                feature : The general feature
                feature_group : The subgroup of feature that provides churn contribution
                feature_effect : specifies if it increases or decreases churn
                probability_contribution : How much the feature group increases or decreases churn based on the baseline

            
            Question: {question}

            Use the below report from business analyst to create final report. Look athe feature groups carefully and explain based on each and every feature groups:

            Report: {context}

            DONOT add any extra information not available to you. Answer from the infomration provided ONLY.
            Interpret decimal points properly. DONOT read 85.12 as 8512.

            """,
    input_variables=["question", "context"],
)
# NOTE(review): this rebinds the module-level name `llm` to a model WITHOUT
# format="json". Chains built earlier keep the model object they were given,
# but re-running an earlier grader cell after this one would pick up this
# non-JSON model - consider a dedicated name for the manager model.
llm=ChatOllama(model="llama3:latest",temperature=0)
manager_rag_chain = manager_prompt | llm | StrOutputParser()

#### Creating the SQL Agent Graph Functions

In [20]:
class SqlGraphState(TypedDict):
    """
    State dictionary shared by the nodes of the SQL agent graph.

    Keys:
        question       -- the user's original question
        sql_question   -- the question reframed for SQL query generation
        sql_query      -- the SQL query generated by the agent
        schema         -- schema of the SQL database
        hist           -- history of SQL queries created
        sql_output     -- output of executing the SQL query
        feature_lookup -- description of each important feature
        output         -- final output from the SQL agent
        report         -- final report from the manager agent
    """

    # User input and its SQL-oriented rewrite
    question: str
    sql_question: str

    # Query generation context
    sql_query: str
    schema: str
    hist: str

    # Results produced along the pipeline
    sql_output: pd.DataFrame
    feature_lookup: str
    output: str
    report: str


### Nodes

def question_reframe(state):
    """Graph node: reframe the user's question against the database schema.

    Reads ``state["question"]`` and returns the original question together
    with the reframed one under ``"sql_question"``.
    """
    print("---Reframe Question for SQL Agent---")
    user_question = state["question"]

    chain_inputs = {"schema": get_all_schemas(), "question": user_question}
    reframed = sql_question_chain.invoke(chain_inputs)
    return {"question": user_question, "sql_question": reframed}

def query_creation(state):
    """Graph node: generate a SQL query for the reframed question.

    Reads ``state["sql_question"]`` and ``state["question"]`` and returns the
    original question together with the generated text under ``"sql_query"``.
    """
    print("---Creating the SQL Query---")
    reframed_question = state["sql_question"]
    user_question = state["question"]

    chain_inputs = {"schema": get_all_schemas(), "sql_question": reframed_question}
    generated_query = sql_query_chain.invoke(chain_inputs)
    return {"question": user_question, "sql_query": generated_query}


def execute_sql_query(state):
    """Graph node: run the generated SQL query and store its result.

    Reads ``state["sql_query"]``, extracts the SQL text with ``extract_sql``
    and runs it through ``execute_query``.

    Returns:
        dict: ``{"sql_output": <DataFrame>, "sql_query": <cleaned SQL>}`` on
        success. ``sql_output`` is an empty list when there is no query, the
        model answered ``None``, the query failed, or it returned no rows, so
        downstream nodes always find the ``sql_output`` key.
    """
    print("--------------Running SQL Query-----------")
    if 'sql_query' not in state:
        print("No SQL query provided in the state.")
        return {"sql_output": [], "sql_query": None}

    clean_sql = extract_sql(state['sql_query'])

    # The generator is instructed to reply "None" when no query is possible.
    # Return an explicit empty result instead of falling through with no
    # state update (which left `sql_output` unset for the next node).
    if clean_sql.strip().lower() in ('none', ''):
        print("No query logic found for this question from database")
        return {"sql_output": [], "sql_query": clean_sql}

    try:
        print("Attempting to run the query and convert it to a DataFrame")
        # execute_query opens and closes its own connection, so no separate
        # connection is created here.
        dataframe = execute_query(query=clean_sql)
        dataframe = dataframe.reset_index(drop=True)
    except Exception as e:
        # Best-effort node: report the failure and continue with an empty result.
        print("Query failed with the following error:")
        print(str(e))
        return {"sql_output": [], "sql_query": clean_sql}

    if dataframe.shape[0] == 0:
        print("Empty DataFrame returned")
        return {"sql_output": [], "sql_query": clean_sql}

    print("Query executed successfully.")
    return {"sql_output": dataframe, "sql_query": clean_sql}

def add_feature_description(state):
    """Graph node: look up descriptions for the features in the SQL output.

    When ``state["sql_output"]`` is a DataFrame with a non-empty ``feature``
    column, the distinct feature names are used to query the vector store and
    the feature-description chain.

    Returns:
        dict: ``{"feature_lookup": <description>}``; the description is an
        empty string when there is nothing to describe or the lookup fails.
    """
    print("--------------Adding Feature Description-----------")
    # .get() so a missing key yields the empty-result branch instead of a KeyError.
    sql_output = state.get("sql_output")

    try:
        if isinstance(sql_output, pd.DataFrame) and 'feature' in sql_output.columns:
            # Drop missing values and cast to str so the join below cannot
            # fail on NaN or non-string feature values.
            features = list(sql_output['feature'].dropna().astype(str).unique())
        else:
            features = []

        if features:
            f_query = "Provide description for " + ', '.join(features) + "?"
            f_retriever = vectorstore.as_retriever(search_type='similarity', search_kwargs={'k': 3})
            documents = f_retriever.invoke(f_query)
            retrived_doc = format_docs(docs=documents)
            feature_description = feature_description_rag_chain.invoke({"question": f_query, "context": retrived_doc})
            return {"feature_lookup": feature_description}

        print("Either the DataFrame 'sql_output' does not exist, the column 'feature' does not exist in it, or it does not have at least one value.")
        return {"feature_lookup": ""}
    except Exception as e:
        # Best-effort enrichment: the report can still be built without it.
        print(f"An error occurred: {e}")
        return {"feature_lookup": ""}

def sql_report(state):
    """Graph node: combine the SQL result table and the feature descriptions.

    Returns:
        dict: ``{"output": <text>}`` where the text is the DataFrame rendered
        with ``to_string()``, a newline, and the feature descriptions. When
        ``sql_output`` is not a DataFrame (the query step stores an empty list
        on failure or empty results) a short placeholder sentence is used
        instead of raising ``AttributeError``.
    """
    print("--------------Combine SQL Output and Feature Description------------")
    sql_output = state.get('sql_output')

    # Convert the DataFrame to a string (or explain that there is no data)
    if isinstance(sql_output, pd.DataFrame):
        dataframe_string = sql_output.to_string()
    else:
        dataframe_string = "No data was returned by the SQL query."

    # Feature descriptions are optional; treat a missing/None value as empty
    feature_lookup_string = state.get('feature_lookup') or ""

    # Combine the DataFrame string and the feature lookup string
    message = dataframe_string + "\n" + feature_lookup_string

    return {"output": message}


def manager_report(state):
    """Graph node: have the manager chain turn the SQL report into the final report.

    Reads ``state["question"]`` and ``state["output"]`` and returns the
    generated text under ``"report"``.
    """
    print("--------------Creating Final Report------------")

    chain_inputs = {"question": state["question"], "context": state["output"]}
    return {"report": manager_rag_chain.invoke(chain_inputs)}


This is a simple agent design. We haven't added the adaptive RAG framework here; the goal is to explain the SQL retriever component only. A hierarchical agent system that combines separate RAG and SQL retriever agents can be built on top of this.

In [21]:
sql_workflow = StateGraph(SqlGraphState)

# Register each node under the name of the function that implements it.
for _node_fn in (
    question_reframe,
    query_creation,
    execute_sql_query,
    add_feature_description,
    sql_report,
    manager_report,
):
    sql_workflow.add_node(_node_fn.__name__, _node_fn)

# The graph is a straight pipeline: each stage feeds the next, ending at END.
_pipeline = [
    "question_reframe",
    "query_creation",
    "execute_sql_query",
    "add_feature_description",
    "sql_report",
    "manager_report",
    END,
]
sql_workflow.set_entry_point(_pipeline[0])
for _source, _target in zip(_pipeline, _pipeline[1:]):
    sql_workflow.add_edge(_source, _target)

# Compile
sql_chain_agent = sql_workflow.compile()

In [23]:


# Run the agent and print the name of each node as it finishes
inputs = {"question": "What are the top 20 reasons for churn? Provide atleast 5 recommeded actions to reduce churn."}
final_report = None
for output in sql_chain_agent.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print the state update emitted by this node
        # pprint(value, indent=2, width=80, depth=None)
        # Remember the report explicitly instead of relying on the loop
        # variable still holding the last node's update after the loop.
        if isinstance(value, dict) and "report" in value:
            final_report = value["report"]
    pprint("\n---\n")

# Final generation
if final_report is None:
    print("The agent finished without producing a report.")
else:
    pprint(final_report)

---Reframe Question for SQL Agent---
"Node 'question_reframe':"
'\n---\n'
---Creating the SQL Query---
"Node 'query_creation':"
'\n---\n'
--------------Running SQL Query-----------
Attempting to run the query and convert it to a DataFrame
Query executed successfully.
"Node 'execute_sql_query':"
'\n---\n'
--------------Adding Feature Description-----------
"Node 'add_feature_description':"
'\n---\n'
--------------Combine SQL Output and Feature Description------------
"Node 'sql_report':"
'\n---\n'
--------------Creating Final Report------------
"Node 'manager_report':"
'\n---\n'
('**Top 20 Reasons for Churn and Recommended Actions**\n'
 '\n'
 'Based on the report provided by our business analyst, we have identified the '
 'top 20 features that contribute to churn. Here are the results:\n'
 '\n'
 '1. **Total Charges**: The total amount charged to the customer increases '
 'churn by 16.70%. This suggests that customers who experience high charges '
 'are more likely to leave the company.\