In [None]:
"""
4.1.3 LangChain/LangGraph Orchestration
We will use LangChain and LangGraph (from Modules 4, 9, 10) to manage the workflow.
LangGraph will handle the conversation flow by checking question safety, retrieving
relevant passages, generating a response, validating response safety, and sending it to the
user. If something fails the safety checks, the system can try again or use a backup
response.
"""

In [None]:
# import statements
from dataclasses import dataclass, field

from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition

### Creating a Message State Class Tailored to Our Use Case

In [None]:
# in recitation, we saw an example with Messages State,
# but that does not suit the specific needs of our project
# so we define our own below:

@dataclass
class ChildMessagesState:
  """State carried through the LibrAIrian graph for one conversation.

  Declared as a dataclass so that every field carries a type annotation and
  the class can be built from keyword arguments; a graph state schema needs
  both in order to know which keys exist and to rebuild the state between
  nodes. Calling ``ChildMessagesState()`` with no arguments still produces
  the same empty defaults as before, and each instance gets its own lists.
  """
  # to store message history, as {"role": ..., "content": ...} dicts
  # TODO: only store history up to a certain amount? then reset
  messages: list = field(default_factory=list)
  # the user input
  user_query: str = ""
  # the RAG results
  retrieved_passages: list = field(default_factory=list)
  # system initial response
  response: str = ""
  # system output after safety check
  final_output: str = ""
  # counts how many questions have been asked
  # limit 10 to limit screen time
  turn_count: int = 0

### Creating Nodes for Graph

In [None]:
# get user input

def get_user_input(state: ChildMessagesState):
  """Log the current question in the history and count it as a turn.

  Appends ``state.user_query`` to ``state.messages`` under the "user" role,
  raises ``state.turn_count`` by one, and hands back the same state object.
  """
  # the question is recorded under the user role since it is user input
  question_entry = {"role": "user", "content": state.user_query}
  state.messages.append(question_entry)

  # one more question has now been asked
  state.turn_count = state.turn_count + 1

  return state

In [None]:
# check question safety --> use function from Mia's safety filtering section

def check_input_safety(state: ChildMessagesState):
  """Screen the user's question and prepare a refusal if it is not safe.

  When ``is_input_safe`` does not report the question as safe, the refusal
  text is written to ``state.final_output`` and appended to the message
  history. A safe question leaves the state untouched.

  The state is returned in both cases. A node has to return a state update,
  so the ``END`` sentinel is not returned from here; ending the turn early is
  a routing decision for the graph's edges, which can tell that a refusal was
  produced because ``state.final_output`` is non-empty.

  Returns:
    The same ``state`` object that was passed in.
  """
  # use safety filtering to check if input is safe
  # if it is not, fall back on safe response
  # `not ...` (rather than `== False`) also treats a None/empty verdict as
  # unsafe, so the filter fails closed
  if not is_input_safe(state.user_query):
    # built from separate literals so the source indentation does not end up
    # inside the text shown to the child
    state.final_output = (
        "I'm sorry, that response is not safe for me to answer.\n"
        "Please ask a trusted adult for further information on this question.\n"
        "What else would you like to learn about?"
    )

    # add system response to history as it is final response automatically
    state.messages.append({"role": "system", "content": state.final_output})

  return state

In [None]:
# make sure that user is not spending too long on system
# within certain time period? TODO

# number of turns after which the session is closed, to limit screen time
max_turns = 10

def check_exchange_count(state: ChildMessagesState):
  """End the session with a goodbye message once the turn limit is reached.

  When ``state.turn_count`` is at least ``max_turns``, the goodbye text is
  written to ``state.final_output`` and appended to the message history.
  Otherwise the state is left untouched. ``get_user_input`` has already
  counted the current question, so the question that brings the count up to
  ``max_turns`` is the one that receives the goodbye.

  The state is returned in both cases. A node has to return a state update,
  so the ``END`` sentinel is not returned from here; ending the turn early is
  a routing decision for the graph's edges, which can tell that the limit was
  hit because ``state.final_output`` is non-empty.

  Returns:
    The same ``state`` object that was passed in.
  """
  # if we have reached limit, return safe message as final output
  if state.turn_count >= max_turns:
    # adjacent literals instead of a backslash continuation, which had pulled
    # the next line's indentation into the middle of the sentence
    state.final_output = (
        "It was great talking to you! However, your librAIrian "
        "needs to take a break. Come back later to ask more questions!"
    )

    # add to history
    state.messages.append({"role": "system", "content": state.final_output})

  return state

In [None]:
# retrieve relevant passages

def retrieve_passages(state: ChildMessagesState):
  """Fetch passages for the current question and keep them on the state.

  Passes ``state.user_query`` to ``rag_passages`` and stores whatever it
  returns in ``state.retrieved_passages``; returns the same state object.
  """
  # retrieval itself is delegated to the RAG helper (ada?)
  query = state.user_query
  passages = rag_passages(query)
  state.retrieved_passages = passages
  return state

In [None]:
# generate response

def generate_response(state: ChildMessagesState):
  """Draft an answer and store it in ``state.response``.

  ``generate_answer`` (from ada) is called with, in order, the user's query,
  the passages found by RAG, and the message history. The same state object
  is returned.
  """
  question = state.user_query
  passages = state.retrieved_passages
  history = state.messages

  state.response = generate_answer(question, passages, history)
  return state

In [None]:
# validate response safety

# if answer is unsafe, only try again a set number of times
max_attempts = 3

def check_answer_safety(state: ChildMessagesState):
  """Validate the drafted answer, regenerating it a bounded number of times.

  ``state.response`` is checked with ``is_answer_safe`` (Mia's safety
  filtering layer). A safe draft becomes ``state.final_output`` and is added
  to the message history. An unsafe draft is replaced by a fresh call to
  ``generate_answer`` and checked again, for at most ``max_attempts``
  regenerations. If no draft passes, a backup message becomes the final
  output and is added to the history instead.

  Returns:
    The same ``state`` object that was passed in, with ``final_output`` set
    on every path.
  """
  # max_attempts regenerations means max_attempts + 1 checks in total:
  # the original draft plus every regenerated one
  for retries_used in range(max_attempts + 1):
    # `is_answer_safe(...)` is used as a truth value (rather than `== False`)
    # so a None/empty verdict counts as unsafe
    if is_answer_safe(state.response):
      # if the output is safe, then the draft response is the final response
      state.final_output = state.response

      # add to message history
      state.messages.append({"role": "system", "content": state.final_output})

      return state

    # the last permitted draft has just failed; stop regenerating
    if retries_used == max_attempts:
      break

    # print response letting them know this was not safe
    # not stored in messages history or state as it is a system
    # result, not the user safety
    print("Please wait one second for me to rephrase my response! "
          "My initial response was not safe.")

    # try generating a new answer, which the next pass of the loop checks
    state.response = generate_answer(
      state.user_query,
      state.retrieved_passages,
      state.messages
    )

  # we only get here once every permitted draft was rejected:
  # give back up message
  state.final_output = (
      "I'm sorry, I reached the maximum number of attempts to "
      "generate a safe response. Please try another question."
  )
  # add this to message history for system
  state.messages.append({"role": "system", "content": state.final_output})
  return state

### Creating Actual Graph

In [None]:
# create a graph builder to set nodes and edges using our messages state
# modeled this after the example from lab 7
# ChildMessagesState is handed to StateGraph as the state type that the
# node functions below receive and return

graph_builder = StateGraph(ChildMessagesState)

In [None]:
# add nodes, and give them the same name
graph_builder.add_node("get_user_input", get_user_input)
graph_builder.add_node("check_input_safety", check_input_safety)
graph_builder.add_node("check_exchange_count", check_exchange_count)
graph_builder.add_node("retrieve_passages", retrieve_passages)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("check_answer_safety", check_answer_safety)


def _continue_or_end(next_node):
  """Build a router that stops the turn once a guard node has answered.

  The two guard nodes (input safety, exchange count) write their refusal or
  goodbye text to ``final_output``. The returned router sends the graph to
  END when that field is non-empty and to ``next_node`` otherwise. This
  assumes ``final_output`` is empty when a turn starts.
  """
  def route(state):
    return END if state.final_output else next_node
  return route


# add edges between subsequent pieces of the pipeline
graph_builder.set_entry_point("get_user_input")
graph_builder.add_edge("get_user_input", "check_input_safety")
# the guard nodes either finish the turn or hand over to the next stage
graph_builder.add_conditional_edges(
    "check_input_safety",
    _continue_or_end("check_exchange_count"),
    {END: END, "check_exchange_count": "check_exchange_count"},
)
graph_builder.add_conditional_edges(
    "check_exchange_count",
    _continue_or_end("retrieve_passages"),
    {END: END, "retrieve_passages": "retrieve_passages"},
)
graph_builder.add_edge("retrieve_passages", "generate_response")
graph_builder.add_edge("generate_response", "check_answer_safety")
# one run of the graph answers one question: finish after the answer check.
# An edge back to get_user_input would cycle forever, because no node reads
# new input and nothing ever led to END. A new question is a new invoke()
# with the carried-over history and turn count.
graph_builder.add_edge("check_answer_safety", END)

In [None]:
# compile into a single graph object
graph = graph_builder.compile()


def _state_field(result, name):
  """Read one state field from invoke()'s result (a mapping or a state object)."""
  if isinstance(result, dict):
    return result[name]
  return getattr(result, name)


# initialize state object
# one state is kept for the whole session so that the message history and the
# turn count carry over from one question to the next
state = ChildMessagesState()
output = None

while True:
  # get user input; a blank line ends the session
  user_query = input("How can LibrAIrian help you?").strip()
  if not user_query:
    break

  # set user query to input from above
  state.user_query = user_query
  # clear the per-question fields so an earlier answer cannot be mistaken
  # for the result of this question
  state.retrieved_passages = []
  state.response = ""
  state.final_output = ""

  # generate response by invoking the graph we built
  output = graph.invoke(state)
  # print the final result to the user
  print(_state_field(output, "final_output"))

  # carry the accumulated history and counter into the next turn
  state.messages = _state_field(output, "messages")
  state.turn_count = _state_field(output, "turn_count")