In [1]:
"""
4.1.3 LangChain/LangGraph Orchestration
We will use LangChain and LangGraph (from Modules 4, 9, 10) to manage the workflow.
LangGraph will handle the conversation flow by checking question safety, retrieving
relevant passages, generating a response, validating response safety, and sending it to the
user. If something fails the safety checks, the system can try again or use a backup
response.
"""

'\n4.1.3 LangChain/LangGraph Orchestration\nWe will use LangChain and LangGraph (from Modules 4, 9, 10) to manage the workflow.\nLangGraph will handle the conversation flow by checking question safety, retrieving\nrelevant passages, generating a response, validating response safety, and sending it to the\nuser. If something fails the safety checks, the system can try again or use a backup\nresponse.\n'

In [3]:
# import statements
import math
import os
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple

import numpy as np
import openai
import pandas as pd
import torch
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline

In [4]:
from input_filter import InputFilter, filter_input
from output_filter import OutputFilter, filter_output

In [5]:
from rag_class import RAG_GPT, create_embeddings, rag_passages, generate_response

In [None]:
# Read the OpenAI API key from the environment (None when the variable is unset).
openai.api_key = os.getenv("OPENAI_API_KEY")

# Build the embeddings for the fairy-tale corpus; the returned value is what the
# retrieval node later passes to rag_passages.
path_to_embeddings = create_embeddings(
    'cleaned_merged_fairy_tales_without_eos.txt'
)

The Happy Prince.
HIGH above the city, on a tall column, stood the statue of the Happy Prince.  He was gilded all over with thin leaves of fine gold, for eyes he had two bright sapphires, and a large red ruby glowed on his sword-hilt.
He was very much admired indeed.  “He is as beautiful as a weathercock,” remarked one of the Town Councillors who wished to gain a reputation for having artistic tastes; “only not quite so useful,” he added, fearing lest people should think him unpractical, which he really was not.
“Why can’t you be like the Happy Prince?” asked a sensible mother of her little boy who was crying for the moon.  “The Happy Prince never dreams of crying for anything.”
“I am glad there is some one in the world who is quite happy,” muttered a disappointed man as he gazed at the wonderful statue.
“He looks just like an angel,” said the Charity Children as they came out of the cathedral in their bright scarlet cloaks and their clean white pinafores.
“How do you know?” said the M

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

### Defining a Custom Message State for This Use Case

In [None]:
# In recitation we saw an example built on MessagesState, but that does not
# suit the specific needs of our project, so we define our own state below.
# LangGraph derives the state's channels from the schema's type annotations,
# so the state is an annotated dataclass rather than a plain class that only
# assigns attributes inside __init__ (which exposes no annotations at all).

@dataclass
class ChildMessagesState:
  """Conversation state shared by the LibrAIrian graph nodes.

  ``ChildMessagesState()`` still creates an empty state, and every field can be
  read and assigned as a plain attribute.
  """

  # message history as {"role": ..., "content": ...} dicts
  # TODO: only store history up to a certain amount? then reset
  messages: List[Dict[str, str]] = field(default_factory=list)
  # the user input
  user_query: str = ""
  # the RAG results (whatever rag_passages returns; starts as an empty list)
  retrieved_passages: Any = field(default_factory=list)
  # system initial response (presumably a string from the RAG generator)
  response: str = ""
  # system output after safety check
  final_output: str = ""
  # counts how many questions have been asked, to limit screen time
  turn_count: int = 0

### Creating the Graph Nodes

In [None]:
# get user input

def get_user_input(state: ChildMessagesState):
  """Log the pending question as a user message and count the new turn.

  The question is not read from the keyboard here; it must already be set on
  ``state.user_query`` before the graph runs.
  """
  state.messages.append({"role": "user", "content": state.user_query})
  state.turn_count = state.turn_count + 1
  return state

In [None]:
# check question safety --> use function from Mia's safety filtering section

def check_input_safety(state: ChildMessagesState):
  """Screen the user's question with the input safety filter.

  If ``filter_input`` rejects the question (its result compares equal to
  False), a fixed redirect message becomes ``state.final_output``, is appended
  to the history, and END is returned to signal that this turn stops here.
  Otherwise the state is returned unchanged.
  """
  # NOTE(review): LangGraph routes on edges, not on node return values, so
  # returning END does not by itself stop the graph; the graph wiring must
  # route on final_output for this early exit to take effect.
  # `== False` is kept on purpose: filter_input's return type isn't visible
  # here, and only results equal to False (not e.g. None) count as unsafe.
  if filter_input(state.user_query) == False:  # noqa: E712
    # fall back on a safe redirect instead of answering; the lines are joined
    # explicitly so no source indentation leaks into the text the child sees
    state.final_output = (
        "I'm sorry, that question is not safe for me to answer.\n"
        "Please ask a trusted adult for further information on this question.\n"
        "What else would you like to learn about?"
    )

    # add system response to history as it is the final response automatically
    state.messages.append({"role": "system", "content": state.final_output})

    return END

  return state

In [None]:
# make sure that user is not spending too long on system
# TODO: also limit usage within a certain time period?

# number of questions per session, to limit screen time
MAX_TURNS = 10

def check_exchange_count(state: ChildMessagesState):
  """Stop the session once the question limit has been reached.

  When ``state.turn_count`` is at least ``MAX_TURNS``, a goodbye message
  becomes ``state.final_output``, is appended to the history, and END is
  returned to signal that this turn stops here. Otherwise the state is
  returned unchanged.
  """
  # NOTE(review): get_user_input increments turn_count before this check runs,
  # so with `>=` the MAX_TURNS-th question receives this message instead of an
  # answer (MAX_TURNS - 1 questions get answered). Confirm that is intended.
  # NOTE(review): LangGraph routes on edges, not on node return values, so
  # returning END does not by itself stop the graph; the graph wiring must
  # route on final_output for this early exit to take effect.
  if state.turn_count >= MAX_TURNS:
    # adjacent literals instead of a backslash continuation, which embedded
    # the next line's indentation in the message
    state.final_output = (
        "It was great talking to you! However, your LibrAIrian "
        "needs to take a break. Come back later to ask more questions!"
    )

    # add to history
    state.messages.append({"role": "system", "content": state.final_output})

    return END

  return state

In [None]:
# retrieve relevant passages

def retrieve_passages(state: ChildMessagesState):
  """Look up passages relevant to the current question via RAG.

  Calls ``rag_passages`` with ``state.user_query`` and ``path_to_embeddings``
  and stores the result in ``state.retrieved_passages``.
  """
  query = state.user_query
  passages = rag_passages(query, path_to_embeddings)
  state.retrieved_passages = passages
  return state

In [None]:
# generate response

def generate_response(state: ChildMessagesState):
  """Draft an answer to the current question from the retrieved passages.

  Calls the RAG generator with the query, the RAG results and the message
  history, and stores its result in ``state.response``.
  """
  # This node has the same name as rag_class.generate_response, and defining it
  # rebinds that module-level name, so calling `generate_response` here would
  # call this node itself and recurse forever. Import the helper under an alias.
  from rag_class import generate_response as rag_generate_response

  state.response = rag_generate_response(
      state.user_query,
      state.retrieved_passages,
      state.messages,
  )

  return state

In [None]:
# validate response safety

# if an answer is unsafe, regenerate it at most this many times
max_attempts = 3

def check_answer_safety(state: ChildMessagesState):
  """Make sure the drafted answer passes the output safety filter.

  ``state.response`` is checked with ``filter_output``. While it is rejected
  (the filter result compares equal to False), a new draft is generated from
  the same query, passages and history, at most ``max_attempts`` times, and
  checked again. The first draft that passes becomes ``state.final_output``;
  if every draft is rejected, a fallback apology is used instead. In both
  cases the final output is appended to ``state.messages``.

  Returns:
    The updated state.
  """
  # Import the RAG generator under an alias: the `generate_response` graph node
  # defined in this notebook rebinds that name at module level.
  from rag_class import generate_response as rag_generate_response

  def _is_safe(text):
    # only a filter result equal to False counts as unsafe
    return filter_output(text) != False  # noqa: E712

  attempts = 0
  safe = _is_safe(state.response)
  while not safe and attempts < max_attempts:
    # interim status for the user; deliberately not stored in the history
    print("Please wait one second for me to rephrase my response! "
          "My initial response was not safe.")
    state.response = rag_generate_response(
        state.user_query,
        state.retrieved_passages,
        state.messages,
    )
    attempts += 1
    # re-check every regenerated draft, including the last one
    safe = _is_safe(state.response)

  if safe:
    # the safe draft response is the final response
    state.final_output = state.response
  else:
    # retries exhausted: give back the fallback message
    state.final_output = (
        "I'm sorry, I reached the maximum number of attempts to generate a "
        "safe response. Please try another question."
    )

  # add the final output to the message history
  state.messages.append({"role": "system", "content": state.final_output})
  return state

### Building the Graph

In [None]:
# create a graph builder to set nodes and edges using our messages state
# modeled this after the example from lab 7
# NOTE(review): StateGraph derives its state channels from the schema's type
# annotations, so ChildMessagesState needs annotated fields for this to work.

graph_builder = StateGraph(ChildMessagesState)

In [None]:
# Add the nodes (each registered under its function's name) and the edges
# between subsequent pieces of the pipeline.
#
# check_input_safety and check_exchange_count end a turn early by writing
# state.final_output and returning END. LangGraph routes on edges, not on a
# node's return value, so those nodes are wrapped to pass the updated state on,
# and a conditional edge after each one goes to END once final_output is set.

def _pass_state_through(node_fn):
  """Wrap an early-exit check node so that it always returns the state."""
  def node(state):
    # The check signals an early exit by writing final_output, so clear any
    # answer left over from a previous turn before running it.
    state.final_output = ""
    result = node_fn(state)
    if isinstance(result, str) and result == END:
      return state
    return result
  return node


def _stop_if_answered(state):
  """Route to END if a check already produced the final output."""
  return END if state.final_output else "continue"


graph_builder.add_node("get_user_input", get_user_input)
graph_builder.add_node("check_input_safety", _pass_state_through(check_input_safety))
graph_builder.add_node("check_exchange_count", _pass_state_through(check_exchange_count))
graph_builder.add_node("retrieve_passages", retrieve_passages)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("check_answer_safety", check_answer_safety)

graph_builder.set_entry_point("get_user_input")
graph_builder.add_edge("get_user_input", "check_input_safety")
graph_builder.add_conditional_edges(
    "check_input_safety",
    _stop_if_answered,
    {"continue": "check_exchange_count", END: END},
)
graph_builder.add_conditional_edges(
    "check_exchange_count",
    _stop_if_answered,
    {"continue": "retrieve_passages", END: END},
)
graph_builder.add_edge("retrieve_passages", "generate_response")
graph_builder.add_edge("generate_response", "check_answer_safety")
# One invocation answers one question. Looping back to get_user_input would
# re-process the same query until LangGraph's recursion limit is hit; ask
# follow-up questions by invoking the graph again instead.
graph_builder.add_edge("check_answer_safety", END)

In [None]:
# compile into a single graph object
graph = graph_builder.compile()

# get user input
user_query = input("How can LibrAIrian help you? ")

# initialize state object and seed it with the query from above
state = ChildMessagesState()
state.user_query = user_query

# generate response by invoking the graph we built
output = graph.invoke(state)

# For annotated state schemas LangGraph returns the final state as a dict of
# channel values; fall back to attribute access if a state object comes back.
final_output = (
    output["final_output"] if isinstance(output, dict) else output.final_output
)

# print the final result to the user
print(final_output)