In [None]:
# These should already be installed on your workspace
!pip install --disable-pip-version-check --quiet -U langchain==0.2.16
!pip install --disable-pip-version-check --quiet -U langchain_openai==0.1.23
!pip install --disable-pip-version-check --quiet -U langgraph==0.2.19
!pip install --disable-pip-version-check --quiet -U langchainhub==0.1.21
!pip install --disable-pip-version-check --quiet -U tavily-python==0.4.0
!pip install --disable-pip-version-check --quiet -U langchain-community==0.2.16
!pip install --disable-pip-version-check --quiet -U python-dotenv==1.0.1

In [None]:
# Load the OpenAI and Tavily API keys.
# In the project folder, create a file named 'config.env' containing:
#   OPENAI_API_KEY="your key"
#   TAVILY_API_KEY="your key"
from dotenv import load_dotenv
import os

load_dotenv('config.env')

# Fail fast with an actionable message: a bare `assert` names nothing useful
# and is skipped entirely when Python runs with -O. Empty values are rejected too.
_missing_keys = [key for key in ("OPENAI_API_KEY", "TAVILY_API_KEY") if not os.getenv(key)]
if _missing_keys:
    raise RuntimeError(
        f"Missing API key(s) in the environment / config.env: {', '.join(_missing_keys)}"
    )

In [None]:
import os
import time
import requests
from pprint import pprint
from typing import Dict, List, Literal

from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import MessagesState, add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.runnables import RunnableConfig
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool
from langchain_core.messages import (
    SystemMessage,
    HumanMessage,
    AIMessage,
    ToolMessage,
)
from langchain_openai import ChatOpenAI

from IPython.display import Image, display
from tavily import TavilyClient

## Instantiate Chat Model

In [None]:
# Chat model shared by the graph nodes; temperature 0.0 to reduce output variability.
llm = ChatOpenAI(
    api_key=os.getenv('OPENAI_API_KEY'),
    model="gpt-4o-mini",
    streaming=True,
    temperature=0.0,
)

In [None]:
# Tavily search client used by the `web_search` tool.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))

## Define State

In [None]:
class State(MessagesState):
    """Graph state shared by all nodes.

    Inherits the ``messages`` channel (merged by the ``add_messages`` reducer)
    from ``MessagesState``; every other key simply holds the last value a node
    returned for it.

    ``MessagesState`` is a ``TypedDict``, so class-level default values would
    not be applied to the state (and type checkers reject them). Keys such as
    ``quiz_requested`` are absent until the node that sets them has run.
    """

    patient_question: str     # raw question typed by the user (ask_health_topic)
    answer: str               # text content of the latest perform_websearch reply
    summary: str              # layman summary (summarize_search)
    quiz_requested: str       # user's raw YES/NO reply (check_quiz_request)
    quiz_question: str        # generated quiz question (create_quiz)
    quiz_response: str        # user's free-text answer (present_quiz)
    graded_result: str        # grading text (grade_quiz)
    new_topic_requested: str  # user's raw YES/NO reply (check_new_topic)

## Create Helper Functions

In [None]:
def display_text_to_user(text: str, delay: float = 2.0) -> None:
    """Print ``text`` to the user, framed by blank lines.

    Args:
        text: Message to show. It is printed as-is so that newlines and
            Markdown-style formatting in LLM output stay readable; ``pprint``
            would print the string's repr (quotes, literal ``\\n`` escapes).
        delay: Seconds to wait after printing. The output has to render before
            a following ``input()`` prompt, or it may never show up.
    """
    print()
    print(text)
    print()
    time.sleep(delay)

In [None]:
def display_messages(messages: List, delay: float = 2.0) -> None:
    """Pretty-print every message in ``messages``.

    Args:
        messages: List of LangChain message objects, e.g. ``state["messages"]``
            (not the state dict itself); each item must provide ``pretty_print()``.
        delay: Seconds to wait afterwards so the output renders before any
            following ``input()`` prompt.
    """
    print()
    for message in messages:
        message.pretty_print()
    time.sleep(delay)

In [None]:
def ask_user_for_input(input_description: str) -> str:
    """Prompt the user and return what they typed.

    Args:
        input_description: Text shown as the ``input()`` prompt.

    Returns:
        The raw reply as a ``str`` (``input()`` strips only the trailing newline).
    """
    return input(input_description)

## Define Tools

In [None]:
@tool
def web_search(question: str) -> Dict:
    """
    Return top search results for a given search query
    """
    # The docstring above doubles as the tool description the LLM sees, so it is
    # left verbatim. The raw Tavily response dict is handed back untouched.
    return tavily_client.search(query=question)

In [None]:
tools = [web_search]
# Name -> tool lookup so tool calls emitted by the model can be dispatched.
tools_by_name = dict((registered_tool.name, registered_tool) for registered_tool in tools)

In [None]:
# Variant of `llm` that may answer with `web_search` tool calls.
llm_with_tools = llm.bind_tools(tools=tools)

## Define Nodes and Routers

In [None]:
# tool node:
def web_search_node(state: State):
    """Execute every tool call requested by the latest AI message.

    Returns only the new ``ToolMessage`` objects; the ``add_messages`` reducer
    appends them to the history. (Appending to ``state["messages"]`` in place
    would mutate the graph's stored list behind the reducer's back.)
    Every tool call gets a reply, even for an unknown tool name, because the
    OpenAI API rejects a follow-up request that leaves a tool call unanswered.
    """
    new_messages = []
    for tool_call in state["messages"][-1].tool_calls:
        selected_tool = tools_by_name.get(tool_call["name"])
        if selected_tool is None:
            content = f"Error: unknown tool '{tool_call['name']}'."
        else:
            observation = selected_tool.invoke(tool_call["args"])
            # Only the hit list is forwarded to the model.
            content = str(observation["results"])
        new_messages.append(ToolMessage(content=content, tool_call_id=tool_call["id"]))

    return {"messages": new_messages}

In [None]:
# agent nodes:
def entry_point(state: State):
    """Greet the user and seed the conversation with the system prompt."""
    display_text_to_user("Hello. How are you doing today?")

    persona = SystemMessage(
        "You are an experienced medical professional. "
        "You conduct a web search to respond to a user's question."
    )
    return {"messages": [persona]}


def ask_health_topic(state: State):
    """Ask the user for a health question and wrap it in a web-search instruction.

    Returns the raw question as ``patient_question`` plus only the new
    ``HumanMessage`` under ``messages`` (the ``add_messages`` reducer appends
    it) instead of mutating ``state["messages"]`` in place.
    """
    input_description = "What is your health-related question? Please type a question on a health topic or a medical treatment. "
    human_input = ask_user_for_input(input_description)

    human_message = HumanMessage(
        "Perform a web search to answer the following question: "
        # Trailing space keeps the next sentence from fusing onto the closing fence.
        f"```{human_input}``` "
        "Use the appropriate tool and conduct the web research for this question. "
        "Return a detailed response including health topic, medical condition, treatment options, post-treatment care, "
        "and important citations from the medical literature or medical news outlets."
    )

    return {
        "patient_question": human_input,
        "messages": [human_message],
    }


def perform_websearch(state: State):
    """Run the tool-enabled model over the conversation so far.

    The reply either requests ``web_search`` tool calls (its text content is
    then typically empty) or carries the final answer. ``answer`` is
    overwritten on every pass, so after the tool loop it holds the last reply's
    text. Only the new message is returned; ``add_messages`` appends it.
    """
    ai_message = llm_with_tools.invoke(state["messages"])
    return {"messages": [ai_message], "answer": ai_message.content}


def summarize_search(state: State):
    """Rewrite the web-search answer as a sectioned, layman-friendly summary.

    Uses the plain ``llm`` rather than ``llm_with_tools``: with tools bound,
    the model may reply with a tool call whose text content is empty, which
    would yield an empty summary. The summary is recorded as an ``AIMessage``;
    returning the bare string would make ``add_messages`` coerce it into a
    *human* message.
    """
    answer = state["answer"]

    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an excellent medical writer who can summarize complex medical topics \n"
                       "into easy to understand language that a layman, i.e. somebody who is not a medical expert, can understand \n"
                       "so they can make appropriate decisions for their health with the information provided by you."),
            ("human", "Create a concise summary of the provided {answer} for patients \n"
                      "that are not medical experts. \n"
                      "Create nicely formatted, easily readable text as output \n"
                      "with several sections with the headers 'health topic', 'medical condition', 'treatment options', 'post-treatment care' and 'citations'. \n"
                      "'health topic': a short description of health topic being discussed, \n"
                      "'medical condition': a description of the medical condition, \n"
                      "'treatment options': a brief discussion of several treatment options, \n"
                      "'post-treatment care': a listing of recommended patient activities during post-treatment care, \n"
                      "'citations': include relevant citations from the literature and important medical news outlets if appropriate. \n"
            ),
        ]
    )

    chain = template | llm | StrOutputParser()
    summary = chain.invoke({"answer": answer})

    return {"messages": [AIMessage(content=summary)], "summary": summary}


def present_summary(state: State):
    """Show the generated summary to the user; contributes no state update."""
    summary_text = state["summary"]
    display_text_to_user(summary_text)
    return None


def check_quiz_request(state: State):
    """Ask whether the user wants a comprehension quiz; store the raw reply."""
    reply = ask_user_for_input(
        "Do you want to check your comprehension based on a generated quiz? (YES or NO)"
    )
    return {"quiz_requested": reply}


def create_quiz(state: State):
    """Generate one open-ended quiz question from the summary.

    Uses the plain ``llm`` (a tool-bound model could answer with an empty tool
    call). The question is stored as ``quiz_question`` and recorded as an
    ``AIMessage``; only the new message is returned for ``add_messages``.
    """
    summary = state["summary"]

    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an experienced medical educator and experienced in the creation of medical exams."),
            ("human", "Generate a single, relevant quiz question based on the provided information: {summary}. "
                      "Create a quiz question that expects a comprehensive answer in 3 to 4 sentences as a response. "
                      "Create a quiz question that tests knowledge and understanding of "
                      "a medical condition, a treatment option, or of a post-treatment care. "
                      "Do not create a multiple-choice quiz. Do not show the solution. "
            ),
        ]
    )

    chain = template | llm | StrOutputParser()
    quiz_question = chain.invoke({"summary": summary})

    return {"messages": [AIMessage(content=quiz_question)], "quiz_question": quiz_question}


def present_quiz(state: State):
    """Show the quiz instructions, then prompt with the quiz question itself.

    Returns the user's free-text answer as ``quiz_response``.
    """
    # Each sentence ends with a space so the concatenated literals stay readable.
    text = ("Please answer the following quiz question on the health topic. "
            "Please answer in about 3 to 4 sentences in a comprehensive response to the question. "
            "Please type your answer in the box provided and hit RETURN:")
    display_text_to_user(text)

    quiz_question = state["quiz_question"]
    # The question is the input() prompt, so it sits directly above the answer box.
    human_input = ask_user_for_input(quiz_question)
    return {"quiz_response": human_input}


def grade_quiz(state: State):
    """Have the LLM grade the user's quiz answer against the summary.

    The prompt asks for a model solution, a letter grade and an explanation.
    Uses the plain ``llm`` (a tool-bound model could answer with an empty tool
    call). The text is stored as ``graded_result`` and recorded as an
    ``AIMessage``; only the new message is returned for ``add_messages``.
    """
    quiz_response = state["quiz_response"]
    quiz_question = state["quiz_question"]
    summary = state["summary"]

    template = ChatPromptTemplate.from_messages(
    [
    ("system", "You are an expert educator with medical expertise tasked with grading a student's answer to a quiz question."),
    ("human",  '''Evaluate the response {quiz_response} carefully and assign a letter grade (A, B, C, D, or F)
                  along with a detailed explanation.

                    ## Input Format
                    - **Question:** [The original quiz question] {quiz_question}
                    - **Summary:** [Summary on the health topic] {summary}
                    - **Student's Response:** [The student's submitted answer] {quiz_response}

                    ## Grading Criteria

                    Evaluate based on the following dimensions:

                    1. **Accuracy (40%)** - Is the information factually correct?
                    2. **Completeness (30%)** - Does it address all parts of the question?
                    3. **Understanding (20%)** - Does it demonstrate comprehension of underlying concepts?
                    4. **Clarity (10%)** - Is it well-organized and clearly communicated?

                    ## Grading Scale

                    - **A (90-100%):** Exceptional. Accurate, complete, demonstrates deep understanding, well-articulated.
                    - **B (80-89%):** Good. Mostly accurate and complete with minor gaps or unclear points.
                    - **C (70-79%):** Satisfactory. Basic understanding but missing key elements or contains some errors.
                    - **D (60-69%):** Poor. Significant gaps, multiple errors, or fundamental misunderstandings.
                    - **F (<60%):** Failing. Incorrect, incomplete, or demonstrates lack of understanding.

                    ## Output Format

                    **Grade:** [Letter Grade]

                    **Explanation:**
                    - **Strengths:** [What the student did well]
                    - **Weaknesses:** [What was missing, incorrect, or unclear]
                    - **Key Missing Elements:** [Specific points from the expected answer that were omitted]
                    - **Suggestions for Improvement:** [Constructive feedback]

                    **Overall Assessment:** [5 to 6 sentence summary justifying the grade]

                  Steps:
                  - Generate a solution for the {quiz_question} from the {summary} in 3 to 4 sentences.
                    Print your solution titled with 'AI response:'.
                  - Then show the grade for the student's response {quiz_response} to the question {quiz_question}
                    titled with 'Grade:' and an explanation for your grading titled with 'Explanation of Grading:'.
                  - Be fair, constructive, and specific in your evaluation. Focus on helping the student understand
                    where they succeeded and where they can improve.
                  - Include relevant citations of the {summary} at the end of the explanation.
                  '''
    ),
    ]
    )

    chain = template | llm | StrOutputParser()

    graded_result = chain.invoke({
        "quiz_response": quiz_response,
        "quiz_question": quiz_question,
        "summary": summary,
        }
    )

    return {"messages": [AIMessage(content=graded_result)], "graded_result": graded_result}


def present_grade(state: State):
    """Show a short heading followed by the grading text from ``grade_quiz``."""
    heading = "Here is your graded result: "
    display_text_to_user(heading)
    graded_text = state["graded_result"]
    display_text_to_user(graded_text)


def check_new_topic(state: State):
    """Ask whether the user wants another health topic; store the raw reply."""
    prompt = "Do you want to ask another question about another health-related topic? (YES or NO):"
    return {"new_topic_requested": ask_user_for_input(prompt)}

In [None]:
# router functions:
def tool_router(state: MessagesState):
    """Return ``"web_search"`` while the latest message requests tool calls, else END.

    ``getattr`` is used because only AI messages carry ``tool_calls``; reading
    the attribute on other message types would raise ``AttributeError``.
    """
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "web_search"
    return END


def check_quiz_request_router(state: State):
    """Return ``"create_quiz"`` if the user asked for a quiz, else END.

    The reply is compared case- and whitespace-insensitively and "y" is also
    accepted, so "YES", " yes " and "y" all start the quiz. A missing key
    counts as "no".
    """
    quiz_requested = state.get("quiz_requested", "")
    if quiz_requested.strip().lower() in ("yes", "y"):
        return "create_quiz"
    return END


def check_new_topic_router(state: State):
    """Loop back to ``"ask_health_topic"`` if the user wants another topic, else END.

    The reply is compared case- and whitespace-insensitively and "y" is also
    accepted. A missing key counts as "no".
    """
    new_topic_requested = state.get("new_topic_requested", "")
    if new_topic_requested.strip().lower() in ("yes", "y"):
        return "ask_health_topic"
    return END

## Create Workflow 

In [None]:
# Graph builder over the shared `State` schema.
workflow = StateGraph(state_schema=State)

In [None]:
# Register every node under the name the edges refer to. "web_search" is served
# by the hand-written `web_search_node` rather than the prebuilt ToolNode.
for node_name, node_fn in (
    ("entry_point", entry_point),
    ("ask_health_topic", ask_health_topic),
    ("perform_websearch", perform_websearch),
    ("web_search", web_search_node),
    ("summarize_search", summarize_search),
    ("present_summary", present_summary),
    ("check_quiz_request", check_quiz_request),
    ("create_quiz", create_quiz),
    ("present_quiz", present_quiz),
    ("grade_quiz", grade_quiz),
    ("present_grade", present_grade),
    ("check_new_topic", check_new_topic),
):
    workflow.add_node(node_name, node_fn)
del node_name, node_fn

In [None]:
# Greeting -> question -> tool-enabled model.
workflow.add_edge(START, "entry_point")
workflow.add_edge("entry_point", "ask_health_topic")
workflow.add_edge("ask_health_topic", "perform_websearch")

# Tool loop: perform_websearch <-> web_search until the model stops requesting
# tools, and only then summarize. `tool_router` returns END when there are no
# tool calls; the dict maps that outcome to summarize_search. A separate
# unconditional perform_websearch -> summarize_search edge must not exist: it
# would also fire while tool calls are pending, summarizing an empty answer in
# parallel with the search.
workflow.add_conditional_edges(
    source="perform_websearch",
    path=tool_router,
    path_map={"web_search": "web_search", END: "summarize_search"},
)
workflow.add_edge("web_search", "perform_websearch")

workflow.add_edge("summarize_search", "present_summary")
workflow.add_edge("present_summary", "check_quiz_request")

# Optional quiz; declining ends the run.
workflow.add_conditional_edges(
    source="check_quiz_request",
    path=check_quiz_request_router,
    path_map=["create_quiz", END]
)

workflow.add_edge("create_quiz", "present_quiz")
workflow.add_edge("present_quiz", "grade_quiz")
workflow.add_edge("grade_quiz", "present_grade")
workflow.add_edge("present_grade", "check_new_topic")

# Another topic loops back to the question prompt; otherwise the run ends.
workflow.add_conditional_edges(
    source="check_new_topic",
    path=check_new_topic_router,
    path_map=["ask_health_topic", END]
)

## Add Memory Management and Display the Workflow

In [None]:
# In-memory checkpointer: graph state is kept per `thread_id` (set in `config` below).
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

In [None]:
# `draw_mermaid_png()` renders through the mermaid.ink web API by default, so a
# network problem would abort Run All; fall back to the Mermaid source text.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:  # best-effort visualization only
    print(f"Could not render the graph image ({exc!r}); Mermaid source follows:")
    print(graph.get_graph().draw_mermaid())

## Run Initial Workflow

Example input to type when the assistant asks for your health-related question:

'What are treatment methods for skin melanoma?'

In [None]:
THREADID = 1
# Checkpoints are keyed by thread_id; recursion_limit caps graph steps per run
# (the quiz / new-topic loop can cycle many times).
config = RunnableConfig(configurable={"thread_id": THREADID}, recursion_limit=2000)

In [None]:
# Initial input only; the actual question is typed when `ask_health_topic` prompts.
input_question = dict(patient_question="")

In [None]:
# Alternative to the `graph.invoke` cell below: stream full state snapshots and
# print the newest message after each step. Run one or the other, not both --
# each run prompts for interactive input. Early snapshots may not contain a
# "messages" key yet, hence `.get`.
# for event in graph.stream(
#     input=input_question,
#     config=config,
#     stream_mode="values"
#     ):
#     if not event.get('messages'):
#         continue
#     event['messages'][-1].pretty_print()

In [None]:
# Runs the whole interactive session and returns the final state values.
output = graph.invoke(input_question, config)

In [None]:
# The question as typed by the user.
output["patient_question"]

In [None]:
# Text of the last `perform_websearch` reply (the unsummarized answer).
output["answer"]

In [None]:
# Layman summary produced by `summarize_search`.
output["summary"]

In [None]:
# The user's raw reply to the quiz prompt (not normalized).
output["quiz_requested"]

In [None]:
# Print the quiz round, if one took place. Checking for the keys instead of
# comparing the raw reply to "yes" works however the reply was capitalised --
# the prompt asks for "YES", which a case-sensitive `== "yes"` never matches.
_quiz_fields = ("quiz_question", "quiz_response", "graded_result", "new_topic_requested")
if all(field in output for field in _quiz_fields):
    for field in _quiz_fields:
        print(output[field])

In [None]:
# Full conversation history from the final state.
for msg in output["messages"]:
    msg.pretty_print()

In [None]:
# Complete final state dict returned by `graph.invoke`.
output