In [1]:
import os
from dotenv import load_dotenv
load_dotenv()
# The key is stored under GG_API_KEY and re-exported as GOOGLE_API_KEY
# (presumably the variable ChatGoogleGenerativeAI reads -- confirm against its docs).
# Fail early with a clear message: assigning None into os.environ raises an
# opaque TypeError instead.
_gg_api_key = os.getenv("GG_API_KEY")
if not _gg_api_key:
    raise RuntimeError(
        "GG_API_KEY is not set; define it in the environment or in a .env file."
    )
os.environ["GOOGLE_API_KEY"] = _gg_api_key
from langchain_google_genai import ChatGoogleGenerativeAI

# Shared chat model; the `chatbot` node below calls it through `llm.invoke`.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    max_retries=2,
    temperature=0,
    timeout=None,
    max_tokens=None,
)

In [2]:
from typing import Annotated, TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

#### Architecture #1: LLM Call

In [3]:
class State(TypedDict):
    """Graph state for the single-node chatbot.

    `messages` is a list annotated with the `add_messages` reducer, so an
    update returned by a node is appended to the existing list instead of
    replacing it.
    """

    messages: Annotated[list, add_messages]

def chatbot(state: State):
    """Call `llm` on the messages held in `state` and return its reply.

    The reply is returned as a one-element list under the "messages" key.
    """
    return {"messages": [llm.invoke(state["messages"])]}

# Wire the single-node graph: START -> chatbot -> END.
builder = StateGraph(State)
builder.add_node("chatbot", chatbot)
builder.add_edge(START, 'chatbot')
builder.add_edge('chatbot', END)

# Compile the builder into the runnable `graph` that the next cell streams from.
graph = builder.compile()

In [4]:

from langchain_core.messages import HumanMessage
# Named `inputs` rather than `input` so the builtin `input()` is not shadowed
# for the rest of the kernel session.
inputs = {"messages": [HumanMessage("hi!")]}
# Print each update chunk the graph yields while it runs.
for chunk in graph.stream(inputs):
    print(chunk)

{'chatbot': {'messages': [AIMessage(content='Hi there! How can I help you today?', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []}, id='run--0a6dd5fd-cfd8-4cbb-8fa9-9cbb8808a5e3-0', usage_metadata={'input_tokens': 2, 'output_tokens': 11, 'total_tokens': 13, 'input_token_details': {'cache_read': 0}})]}}


#### Architecture #2: Chain

In [5]:
# The original chained assignment (`model_low_temp = llm = ...`) also rebound
# the global `llm`, which `chatbot` reads at call time; that silently changed
# its temperature from 0 to 0.7. Each model now gets only its own name.

# Lower-temperature model, used by `generate_sql`.
model_low_temp = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.1,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

# Higher-temperature model, used by `explain_sql`.
model_high_temp = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

In [6]:
class State(TypedDict):
    """Full graph state for the SQL generate-then-explain chain.

    Fields:
        messages: conversation history, merged with the `add_messages` reducer.
        user_query: the graph's input.
        sql_query, sql_explanation: the graph's outputs.
    """

    messages: Annotated[list, add_messages]
    user_query: str
    sql_query: str
    sql_explanation: str

class Input(TypedDict):
    """Input schema of the graph: only the user's question."""

    user_query: str

class Output(TypedDict):
    """Output schema of the graph: the SQL query and its explanation."""

    sql_query: str
    sql_explanation: str

In [7]:
from langchain_core.messages import SystemMessage
# System prompt for the SQL-generation step; `generate_sql` places it first in
# the message list it sends to the model.
generate_prompt = SystemMessage(
    """You are a helpful data analyst who generates SQL queries for users based 
    on their questions."""
)

In [8]:
def generate_sql(state: State) -> State:
    """Ask the low-temperature model for a SQL query answering `user_query`.

    Returns a state update with the model's text under "sql_query" and the
    new human/AI message pair under "messages".
    """
    question = HumanMessage(state["user_query"])
    # Prompt first, then prior history, then the new question.
    reply = model_low_temp.invoke([generate_prompt, *state["messages"], question])
    update = {"sql_query": reply.content}
    # Record both turns in the conversation history.
    update["messages"] = [question, reply]
    return update

In [9]:
# System prompt for the explanation step; `explain_sql` sends it ahead of the history.
explain_prompt = SystemMessage("You are a helpful data analyst who explains SQL queries to users.")

def explain_sql(state: State) -> State:
    """Ask the high-temperature model to explain the SQL produced earlier.

    The message history already holds the user's question and the generated
    SQL from the previous step, so only the system prompt is prepended.

    Returns a state update with the explanation text under "sql_explanation"
    and the model's reply appended to "messages".
    """
    messages = [
        explain_prompt,
        # contains user's query and SQL query from prev step
        *state["messages"],
    ]
    res = model_high_temp.invoke(messages)
    return {
        "sql_explanation": res.content,
        # Update conversation history. Wrapped in a list to match the declared
        # `list` type of `messages` and the shape returned by `generate_sql`.
        "messages": [res],
    }

In [11]:
# Two-step chain: START -> generate_sql -> explain_sql -> END.
# `Input` and `Output` are passed as the graph's input/output schemas, while
# `State` is the full state shared by the nodes.
builder = StateGraph(State, input_schema=Input, output_schema=Output)
builder.add_node("generate_sql", generate_sql)
builder.add_node("explain_sql", explain_sql)
builder.add_edge(START, "generate_sql")
builder.add_edge("generate_sql", "explain_sql")
builder.add_edge("explain_sql", END)

# Rebinds `graph` to the compiled chain (replacing the earlier chatbot graph).
graph = builder.compile()

In [12]:
# Run the chain once; as the cell's last expression, the result is displayed.
graph.invoke({"user_query": "What is the total sales for each product?"})

{'sql_query': '```sql\nSELECT\n    product_name,\n    SUM(sales) AS total_sales\nFROM\n    sales_table\nGROUP BY\n    product_name;\n```',
 'sql_explanation': "\n\nOkay, I can help you understand this SQL query!  Let's break it down:\n\n**Purpose:**\n\nThe query is designed to calculate the total sales amount for each distinct product in your sales data.\n\n**Explanation:**\n\n1.  **`SELECT product_name, SUM(sales) AS total_sales`**:\n    *   `SELECT product_name`: This specifies that you want to retrieve the name of the product in your result.\n    *   `SUM(sales)`: This is an aggregate function that calculates the sum of the `sales` column.  `SUM()` adds up all the sales values for each group (defined later by the `GROUP BY` clause).\n    *   `AS total_sales`: This assigns an alias (a temporary name) to the calculated sum of sales.  Instead of the column being named `SUM(sales)`, it will be named `total_sales` in the output, which is more descriptive.\n\n2.  **`FROM sales_table`**:\n

In [13]:
# Stream the same question and print each update chunk as it is yielded.
for s in graph.stream({"user_query": "What is the total sales for each product?"}):
    print(s)

{'generate_sql': {'sql_query': '```sql\nSELECT\n    product_name,\n    SUM(sales) AS total_sales\nFROM\n    sales_table\nGROUP BY\n    product_name;\n```', 'messages': [HumanMessage(content='What is the total sales for each product?', additional_kwargs={}, response_metadata={}, id='11fea22b-0c73-40f2-9ca8-9f6e7cd4aae3'), AIMessage(content='```sql\nSELECT\n    product_name,\n    SUM(sales) AS total_sales\nFROM\n    sales_table\nGROUP BY\n    product_name;\n```', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []}, id='run--ae89aaa7-24d5-4a52-a1c5-e6254bc7fe5c-0', usage_metadata={'input_tokens': 29, 'output_tokens': 38, 'total_tokens': 67, 'input_token_details': {'cache_read': 0}})]}}
{'explain_sql': {'sql_explanation': '\n\n**Explanation:**\n\n1.  **`SELECT product_name, SUM(sales) AS total_sales`**:\n    *   This part specifies what we want to retrieve.\n 

In [14]:
# Render the compiled graph as a Mermaid PNG and save it next to the notebook.
img = graph.get_graph().draw_mermaid_png()
filename = "my_langgraph_workflow_2.png"
try:
    with open(filename, "wb") as f:
        f.write(img)
except OSError as exc:
    # Saving the picture is best-effort, so a failed write is reported rather
    # than raised. Only file-system errors are caught: a bare `except:` would
    # also swallow KeyboardInterrupt and hide the real cause.
    print(f"Error: could not write {filename}: {exc}")

#### Architecture #3: Router

In [15]:
from typing import Annotated, Literal, TypedDict

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

In [16]:
# Embedding model used to build the in-memory vector stores below.
embeddings = HuggingFaceEmbeddings()

# The original chained assignment (`model_low_temp = llm = ...`) also rebound
# the global `llm` as a side effect; each model now gets only its own name.

# Lower-temperature model, used by `router_node`.
model_low_temp = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.1,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

# Higher-temperature model, used by `generate_answer`.
model_high_temp = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

In [17]:
class State(TypedDict):
    """Full graph state for the router architecture.

    Fields:
        messages: conversation history, merged with the `add_messages` reducer.
        user_query: the graph's input.
        domain: which of the two domains the router chose.
        documents, answer: retrieved documents and the final answer.
    """

    messages: Annotated[list, add_messages]
    user_query: str
    domain: Literal["records", "insurance"]
    documents: list[Document]
    answer: str

class Input(TypedDict):
    """Input schema of the router graph: only the user's question."""

    user_query: str

class Output(TypedDict):
    """Output schema of the router graph: retrieved documents and the answer."""

    documents: list[Document]
    answer: str

In [18]:
# Both stores start empty (no documents are indexed) and share `embeddings`.
medical_records_store = InMemoryVectorStore.from_documents([], embeddings)
insurance_faqs_store = InMemoryVectorStore.from_documents([], embeddings)

# One retriever per store; the retrieval nodes call `.invoke` on these.
medical_records_retriever = medical_records_store.as_retriever()
insurance_faqs_retriever = insurance_faqs_store.as_retriever()

In [19]:
# System prompt for the routing step; `router_node` sends it first. It asks the
# model to reply with just a domain name ("records" or "insurance").
router_prompt = SystemMessage(
    """You need to decide which domain to route the user query to. You have two 
        domains to choose from:
          - records: contains medical records of the patient, such as 
          diagnosis, treatment, and prescriptions.
          - insurance: contains frequently asked questions about insurance 
          policies, claims, and coverage.

Output only the domain name."""
)

In [20]:
def router_node(state: State) -> State:
    """Ask the low-temperature model which domain the user's query belongs to.

    Returns a state update with the chosen domain under "domain" and the new
    human/AI message pair under "messages".
    """
    user_message = HumanMessage(state["user_query"])
    messages = [router_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)

    domain = res.content
    # The domain is compared with exact string equality downstream, so stray
    # whitespace or capitalisation in the reply (e.g. "Records\n") would
    # silently fall through to the insurance branch. Normalise text replies.
    if isinstance(domain, str):
        domain = domain.strip().lower()

    return {
        "domain": domain,
        # update conversation history
        "messages": [user_message, res],
    }

In [21]:
def pick_retriever(
    state: State,
) -> Literal["retrieve_medical_records", "retrieve_insurance_faqs"]:
    """Return the name of the retrieval node matching `state["domain"]`.

    "records" selects the medical-records node; any other value selects the
    insurance-FAQ node.
    """
    wants_records = state["domain"] == "records"
    return "retrieve_medical_records" if wants_records else "retrieve_insurance_faqs"


In [22]:
def retrieve_medical_records(state: State) -> State:
    """Query the medical-records retriever with the user's question.

    Returns a state update holding the retriever's result under "documents".
    """
    return {"documents": medical_records_retriever.invoke(state["user_query"])}

def retrieve_insurance_faqs(state: State) -> State:
    """Query the insurance-FAQ retriever with the user's question.

    Returns a state update holding the retriever's result under "documents".
    """
    return {"documents": insurance_faqs_retriever.invoke(state["user_query"])}

In [23]:
# System prompt that `generate_answer` uses when the routed domain is "records".
medical_records_prompt = SystemMessage(
    """You are a helpful medical chatbot who answers questions based on the 
        patient's medical records, such as diagnosis, treatment, and 
        prescriptions."""
)

# System prompt that `generate_answer` uses for any domain other than "records".
insurance_faqs_prompt = SystemMessage(
    """You are a helpful medical insurance chatbot who answers frequently asked 
        questions about insurance policies, claims, and coverage."""
)

In [27]:
def generate_answer(state: State) -> State:
    """Answer the user's query with the prompt matching the routed domain.

    The model receives the domain-specific system prompt, the conversation
    history, and a final human message carrying the retrieved documents.

    Returns a state update with the answer text under "answer" and the model's
    reply appended to "messages".
    """
    if state["domain"] == "records":
        prompt = medical_records_prompt
    else:
        prompt = insurance_faqs_prompt
    messages = [
        prompt,
        *state["messages"],
        HumanMessage(f'Documents: {state["documents"]}'),
    ]
    res = model_high_temp.invoke(messages)
    return {
        "answer": res.content,
        # Update conversation history. Wrapped in a list to match the declared
        # `list` type of `messages` and the shape returned by `router_node`.
        "messages": [res],
    }

In [29]:
# Router graph: START -> router -> (one of the two retrieval nodes)
# -> generate_answer -> END.
builder = StateGraph(State, input_schema=Input, output_schema=Output)
builder.add_node("router", router_node)
builder.add_node("retrieve_medical_records", retrieve_medical_records)
builder.add_node("retrieve_insurance_faqs", retrieve_insurance_faqs)
builder.add_node("generate_answer", generate_answer)
builder.add_edge(START, "router")
# `pick_retriever` returns the name of the retrieval node to run after "router".
builder.add_conditional_edges("router", pick_retriever)
builder.add_edge("retrieve_medical_records", "generate_answer")
builder.add_edge("retrieve_insurance_faqs", "generate_answer")
builder.add_edge("generate_answer", END)

# Rebinds `graph` to the compiled router graph.
graph = builder.compile()

In [30]:
# Render the compiled graph as a Mermaid PNG and save it next to the notebook.
img = graph.get_graph().draw_mermaid_png()
filename = "my_langgraph_workflow_3.png"
try:
    with open(filename, "wb") as f:
        f.write(img)
except OSError as exc:
    # Saving the picture is best-effort, so a failed write is reported rather
    # than raised. Only file-system errors are caught: a bare `except:` would
    # also swallow KeyboardInterrupt and hide the real cause.
    print(f"Error: could not write {filename}: {exc}")

In [31]:
# Named `inputs` rather than `input` so the builtin `input()` is not shadowed
# for the rest of the kernel session.
inputs = {
    "user_query": "Am I covered for COVID-19 treatment?"
}
# Print each update chunk the graph yields while it runs.
for c in graph.stream(inputs):
    print(c)

{'router': {'domain': 'insurance', 'messages': [HumanMessage(content='Am I covered for COVID-19 treatment?', additional_kwargs={}, response_metadata={}, id='fc208307-328f-4608-962d-65328e1b3ba9'), AIMessage(content='insurance', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []}, id='run--f5d6f606-da7a-4ebc-bbbd-b946a763ffdd-0', usage_metadata={'input_tokens': 86, 'output_tokens': 2, 'total_tokens': 88, 'input_token_details': {'cache_read': 0}})]}}
{'retrieve_insurance_faqs': {'documents': []}}
{'generate_answer': {'answer': 'Yes, your policy covers COVID-19 treatment. Please check your policy \n        documents for specific details on coverage limits and cost-sharing.', 'messages': AIMessage(content='Yes, your policy covers COVID-19 treatment. Please check your policy \n        documents for specific details on coverage limits and cost-sharing.', additio