# 4-LangGraph

In [None]:
!pip install langgraph -q --disable-pip-version-check

In [None]:
import os
import boto3
import pandas as pd
# boto3 client for the 'redshift-data' service. It is not referenced directly
# in the cells of this notebook; presumably the utility scripts loaded via
# %run rely on it -- TODO confirm.
redshift_client = boto3.client('redshift-data')

In [None]:
%run ../utilities/utils.py
%run ../utilities/prompt_utils.py
%run ../utilities/bedrock_utils.py
%run ../utilities/database_utils.py

In [None]:
from IPython.display import Image, display
def visualize_graph(graph):
    """Display a graph inline as a Mermaid-rendered PNG (best effort).

    Rendering is optional: any failure while drawing is reported with a short
    printed notice instead of being raised, so the notebook keeps running.

    Args:
        graph: Object exposing ``get_graph().draw_mermaid_png()``, such as the
            compiled graphs built in this notebook.

    Returns:
        None.
    """
    try:
        display(Image(graph.get_graph().draw_mermaid_png()))
    except Exception as exc:
        # Drawing requires some extra dependencies and is optional, so do not
        # abort -- but say why no image was shown rather than failing silently.
        print(f"Graph visualization skipped: {exc!r}")

## Low-complexity graph

In [None]:
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

class BaseState(TypedDict):
    """State schema for the low-complexity graph: a single list of messages."""
    # `messages` is typed as a list. The `add_messages` function attached via
    # `Annotated` defines how this state key is updated: values returned by a
    # node are appended to the existing list rather than overwriting it.
    messages: Annotated[list, add_messages]

In [None]:
def create_simple_graph():
    """Build and compile a one-node graph that loads the database schema.

    The graph runs START -> "context" -> END. The "context" node obtains the
    schema text (from the ``db_schema.txt`` cache file when it exists,
    otherwise from ``get_all_table_schema()``) and adds it to ``messages``.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph_builder = StateGraph(BaseState)

    def get_context(state: BaseState):
        """Return the database schema text as a new entry for ``messages``."""
        # When the context has already been extracted to file, read it from
        # the file. If not, extract the context and cache it for later runs.
        context_file_path = "db_schema.txt"
        if os.path.exists(context_file_path):
            with open(context_file_path, "r") as f:
                db_schema = f.read()
        else:
            # NOTE(review): str.join iterates its argument; if the helper
            # returns a dict, only its keys end up in the text -- confirm the
            # return type of get_all_table_schema().
            schemas_dict = get_all_table_schema()
            db_schema = "\n\n".join(schemas_dict)

            with open(context_file_path, "w") as f:
                f.write(db_schema)

        # `messages` is merged through the `add_messages` reducer, so a node
        # returns only the new entry; re-sending the existing history
        # (state["messages"] + [...]) is redundant.
        return {"messages": [db_schema]}

    # The first argument is the unique node name.
    # The second argument is the function or object that will be called
    # whenever the node is used.
    graph_builder.add_node("context", get_context)

    # Edges tell the graph how to navigate between nodes.
    # The first argument is the source and the second argument is the sink.
    # Direction of graph: source --> sink.
    graph_builder.add_edge(START, "context")
    graph_builder.add_edge("context", END)
    graph = graph_builder.compile()

    return graph

In [None]:
# Compile the one-node graph and try to draw it inline.
simple_graph = create_simple_graph()
visualize_graph(simple_graph)

In [None]:
# Run the graph starting from an empty message list and print the final state.
result = simple_graph.invoke(BaseState(messages=[]))
print(f"Simple graph result: {result}")

## Medium-complexity graph

In [None]:
class SqlState(TypedDict):
    """State schema shared by the nodes of the medium-complexity SQL graph."""
    # None of these keys is annotated with a reducer function (unlike
    # `BaseState.messages`); each is declared with a plain type.
    # Natural-language question supplied by the caller.
    user_input: str
    # Attempt counter reported by the SQL generation helpers.
    num_sql_attempt: int
    # Flag for whether the generated SQL is valid.
    if_sql_valid: bool
    # Database schema text loaded by the context node.
    context: str
    # Most recently generated SQL query.
    sql_query: str
    # Feedback produced while fixing a failed query.
    feedback: str
    # Error text taken from a failed query response.
    error: str
    # Query result produced by a successful run.
    data: pd.DataFrame
    # Visualization outcome flag and generated visualization code.
    if_viz_success: bool
    viz_code: str
    # Natural-language answer produced from `data`.
    response: str
# Alternative (disabled): accumulate responses through a reducer instead.
#    response: Annotated[list, add_messages]
    
def create_sql_graph(max_attempts=10):
    """Build and compile the text-to-SQL graph.

    Flow:
        START -> get_context -> txt_to_sql
        txt_to_sql / fix_sql --valid--> sql_to_viz -> generate_response -> END
        txt_to_sql / fix_sql --not_valid--> fix_sql (retry)
        txt_to_sql / fix_sql --max_attempt_reached--> END

    When the graph ends through "max_attempt_reached", the "sql_to_viz" and
    "generate_response" nodes never run, so the final state carries no
    ``viz_code`` or ``response``.

    Args:
        max_attempts: Attempt count at which the retry loop gives up.
            Defaults to 10, the previously hard-coded limit.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph_builder = StateGraph(SqlState)

    def get_context(state: SqlState):
        """Load the database schema text into the ``context`` key."""
        # When the context has already been extracted to file, read it from
        # the file. If not, extract the context and cache it for later runs.
        context_file_path = "db_schema.txt"
        if os.path.exists(context_file_path):
            with open(context_file_path, "r") as f:
                db_schema = f.read()
        else:
            # NOTE(review): str.join iterates its argument; if the helper
            # returns a dict, only its keys end up in the text -- confirm the
            # return type of get_all_table_schema().
            schemas_dict = get_all_table_schema()
            db_schema = "\n\n".join(schemas_dict)

            with open(context_file_path, "w") as f:
                f.write(db_schema)
        return {"context": db_schema}

    def build_sql_update(if_passed, response, sql_query, current_attempt, feedback):
        """Turn the outcome of one SQL attempt into a state update."""
        update = {
            "num_sql_attempt": current_attempt,
            # Stored under the key declared in SqlState so that the routing
            # function can read it back from the state.
            "if_sql_valid": if_passed,
            "sql_query": sql_query,
            "feedback": feedback,
        }

        # On success the response is kept as the query data; on failure only
        # its "Error" entry is kept.
        if if_passed:
            update["data"] = response
        else:
            update["error"] = response["Error"]

        return update

    def run_query_no_feedback_wrapper(state: SqlState):
        """First SQL attempt: generate and run a query for the user question."""
        print("Running txt2sql...")
        if_passed, response, sql_query, current_attempt = run_query_no_feedback(
            question=state["user_input"],
            db_schema=state["context"],
            prompt_callback=None,
        )
        return build_sql_update(if_passed, response, sql_query, current_attempt, "")

    def run_query_with_feedback_wrapper(state: SqlState):
        """Retry: pass the previous query, error and feedback back to the helper."""
        if_passed, current_feedback, response, sql_query, current_attempt = run_query_with_feedback(
            user_query=state["user_input"],
            prev_feedback=state["feedback"],
            prev_sql_query=state["sql_query"],
            prev_response={"Error": state["error"]},
            num_attempt=state["num_sql_attempt"],
            db_schema=state["context"],
        )
        return build_sql_update(
            if_passed, response, sql_query, current_attempt, current_feedback
        )

    def check_for_valid_sql(state: SqlState):
        """Route after an SQL attempt: proceed, retry, or give up."""
        if state["if_sql_valid"]:
            return "valid"
        # `>=` rather than `==` so a counter that steps past the limit still
        # terminates the retry loop.
        if state["num_sql_attempt"] >= max_attempts:
            return "max_attempt_reached"
        return "not_valid"

    def generate_viz_wrapper(state: SqlState):
        """Generate visualization code for the SQL query without executing it."""
        if_viz_success, viz_code = generate_viz_v1(sql_query=state["sql_query"], execute=False)

        return {"if_viz_success": if_viz_success, "viz_code": viz_code}

    def generate_response(state: SqlState):
        """Ask the model for a natural-language answer based on the query data."""
        # The instruction line previously referred to "<quesiton>", which did
        # not match the <question> tag used below.
        prompt_template = """Generate a natural language response to the <question> below using the <data_table>.
        
<data_table>
{data_table}
</data_table>

<question>
{question}
</question>
        """
        prompt = prompt_template.format(
            question=state["user_input"],
            data_table=state["data"].to_string(),
        )
        response = invoke_model(prompt, SONNET35_MODEL_ID)

        return {"response": response}

    # The first argument is the unique node name.
    # The second argument is the function or object that will be called
    # whenever the node is used.
    graph_builder.add_node("get_context", get_context)
    graph_builder.add_node("txt_to_sql", run_query_no_feedback_wrapper)
    graph_builder.add_node("fix_sql", run_query_with_feedback_wrapper)
    graph_builder.add_node("sql_to_viz", generate_viz_wrapper)
    graph_builder.add_node("generate_response", generate_response)

    graph_builder.add_edge(START, "get_context")
    graph_builder.add_edge("get_context", "txt_to_sql")

    # Both SQL nodes share the same router, so both need a target for every
    # value it can return, including "max_attempt_reached".
    sql_routes = {
        "valid": "sql_to_viz",
        "not_valid": "fix_sql",
        "max_attempt_reached": END,
    }
    graph_builder.add_conditional_edges("txt_to_sql", check_for_valid_sql, dict(sql_routes))
    graph_builder.add_conditional_edges("fix_sql", check_for_valid_sql, dict(sql_routes))

    graph_builder.add_edge("sql_to_viz", "generate_response")
    graph_builder.add_edge("generate_response", END)

    graph = graph_builder.compile()

    return graph

In [None]:
# Compile the text-to-SQL graph and try to draw it inline.
sql_graph = create_sql_graph()
visualize_graph(sql_graph)

In [None]:
usr_query = "How many games were played each year?"
# usr_query = "How has the average height of the players changed over the decades?"
result = sql_graph.invoke(SqlState(user_input=usr_query, num_sql_attempt=0))
# The graph routes "max_attempt_reached" straight to END, in which case the
# "generate_response" node never ran and no "response" key was produced.
if 'response' in result:
    print(result['response'])
else:
    print(f"No response generated. Last SQL error: {result.get('error')}")

In [None]:
# Show the query result stored under "data" (set only when a SQL attempt passed).
result['data']

In [None]:
# Inspect the visualization code before running it in the next cell.
print(result['viz_code'])

In [None]:
# SECURITY NOTE(review): exec() runs the code string returned by
# generate_viz_v1 with full notebook privileges. That string is presumably
# model-generated -- review the printed code above before executing it.
exec(result['viz_code'])

## Future extensions

* Agentic workflows:
    * Opportunity: Use functions as tools and use agents to orchestrate the workflow
* Attempts:
    * Challenge: The retry loop can go down a rabbit hole, repeatedly patching the same failing query (similar to an optimizer getting stuck in a local minimum/maximum).
    * Solution1: Instead of fixing the failing SQL, try generating a new SQL query from scratch.
    * Solution2: Change temperature to introduce randomness/change.
    * Solution3: Use LLM to rewrite the user query