In [None]:
import os
from google.colab import userdata

# Configure LangSmith tracing for this session; the API key is read through
# Colab's `userdata` helper.
os.environ["LANGSMITH_API_KEY"] = userdata.get('LANGSMITH_API_KEY')
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
# The project name used to be assigned twice ("default", then "weekend_party").
# Only the last assignment took effect, so keep just that one.
os.environ["LANGSMITH_PROJECT"] = "weekend_party"

Let's load and explore a dataset:

In [None]:
from datasets import load_dataset
# Load the "high_school_geography" configuration of the cais/mmlu dataset.
ds = load_dataset("cais/mmlu", "high_school_geography")

In [None]:
# Take 100 rows of the "test" split and convert them to a dict that is indexed
# by column name first and row position second (as the lookups below use it).
ds_dict = ds["test"].take(100).to_dict()
print(ds_dict["question"][0])

In [None]:
# Answer options that belong to the first sampled question.
print(ds_dict["choices"][0])

In [None]:
# Value of the dataset's "answer" column for the first sampled question.
ds_dict["answer"][0]

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
# Read the Google API key through Colab's `userdata` helper and build the chat
# model that the agents and chains in the cells below share.
google_api_key = userdata.get('GOOGLE_API_KEY')
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001", google_api_key=google_api_key)

In [None]:
from langchain.agents import load_tools
from langgraph.prebuilt import create_react_agent


# Tools handed to the research agents: DuckDuckGo search, arXiv and Wikipedia.
research_tools = load_tools(
  tool_names=["ddg-search", "arxiv", "wikipedia"],
  llm=llm
)

# System prompt shared by the research agents.
# Fixed: the "quesion" typo, and the missing space after "Think step by step."
# that glued two sentences together in the rendered prompt.
system_prompt = (
    "You're a hard-working, curious and creative student. "
    "You're working on an exam question. Think step by step. "
    "Always provide an argumentation for your answer. "
    "Do not assume anything, use available tools to search "
    "for evidence and supporting statements."
)


In [None]:
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langgraph.graph import MessagesState
from langgraph.prebuilt.chat_agent_executor import AgentState

# User-turn template; `{question}` and `{options}` are filled in when the
# prompt is rendered. One source line per rendered line.
raw_prompt_template = (
    "Answer the following multiple-choice question. \n"
    "QUESTION:\n"
    "{question}\n"
    "\n"
    "ANSWER OPTIONS:\n"
    "{options}\n"
)
# Chat prompt for the first-pass research agent: system instructions, the
# templated question, then a placeholder for the `messages` variable.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("user", raw_prompt_template),
    ("placeholder", "{messages}"),
])

class ResearchState(AgentState):
  """Agent state extended with the two variables the research prompt renders."""
  # Exam question text (`{question}` in the prompt template).
  question: str
  # Pre-formatted answer options (`{options}` in the prompt template).
  options: str

# First-pass research agent; it uses the `prompt` defined just above.
research_agent = create_react_agent(model=llm, tools=research_tools, state_schema=ResearchState, prompt=prompt)

In [None]:
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate


# User-turn template for a revision round: besides the question and options it
# carries the earlier answer (`{answer}`) and the reviewer's `{feedback}`.
raw_prompt_template_with_critique = (
    "You tried to answer the exam question and you get feedback from your professor. "
    "Work on improving your answer and incorporating the feedback. \n"
    "QUESTION:\n{question}\n\n"
    "ANSWER OPTIONS:\n{options}\n\n"
    "INITIAL ANSWER:\n{answer}\n\n"
    "FEEDBACK:\n{feedback}"
)
# Chat prompt for the revision agent. This rebinds the notebook-level name
# `prompt`, so it has to run before the critique agent is created.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("user", raw_prompt_template_with_critique),
    ("placeholder", "{messages}"),
])

class ReflectionState(ResearchState):
  """Research state extended with the earlier answer and the feedback on it."""
  # Previous answer (`{answer}` in the critique prompt template).
  answer: str
  # Reviewer feedback (`{feedback}` in the critique prompt template).
  feedback: str

# Revision agent; it uses the critique-aware `prompt` defined just above.
research_agent_with_critique = create_react_agent(model=llm, tools=research_tools, state_schema=ReflectionState, prompt=prompt)

In [None]:
from typing import Optional
from pydantic import BaseModel, Field


# Reviewer ("professor") prompt; expects `question`, `options` and `answer`.
# Fixed: "nQUESTION" was missing the backslash of its newline escape, a stray
# "." sat on its own line after the options, and "asnwer" was misspelled. The
# section layout now matches the student-side prompt templates.
reflection_prompt = (
    "You are a university professor and you're supervising a student who is "
    "working on multiple-choice exam question. "
    "\nQUESTION:\n{question}\n\nANSWER OPTIONS:\n{options}\n\n"
    "STUDENT'S ANSWER:\n{answer}\n\n"
    "Reflect on the answer and provide a feedback whether the answer "
    "is right or wrong. If you think the student's answer is correct, rewrite the final answer "
    "in the `answer` field. "
    "Only provide critique if you think the answer is "
    "incorrect or there are reasoning flaws. Do not assume anything, "
    "evaluate only the reasoning the student provided and whether there is "
    "enough evidence for their answer."
)

class Response(BaseModel):
    """A final response to the user."""

    # NOTE(review): this model is passed to `llm.with_structured_output`, so the
    # docstring above and the field descriptions below are presumably sent to
    # the model as part of the schema. Treat them as prompt text.

    # Set when the reviewer accepts the student's answer.
    answer: Optional[str] = Field(
        description="The final answer to the original question. Always provide one if it's right and there's no critique.",
        default=None,
    )
    # Set when the reviewer wants a revision (typo "acitonable" fixed).
    critique: Optional[str] = Field(
        description="A critique of the student's answer. If you think it is incorrect, provide actionable feedback",
        default=None,
    )


In [None]:
from typing import Annotated, Literal, TypedDict
from langchain_core.runnables.config import RunnableConfig
from operator import add
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END, Graph


class ReflectionAgentState(TypedDict):
    """State passed between the nodes of the research/reflect graph."""
    # Exam question text.
    question: str
    # Pre-formatted answer options.
    options: str
    # Latest answer produced by a research node.
    answer: str
    # Annotated with `operator.add`: the `1` returned by each reflection step is
    # presumably accumulated by LangGraph into a running count.
    steps: Annotated[int, add]
    # Latest structured verdict returned by the reflection chain.
    response: Response


def _should_end(state: ReflectionAgentState, config: RunnableConfig) -> Literal["research", END]:
    """Route after a reflection step: finish, or send the student back to research.

    Args:
        state: Current graph state; `response` and `steps` are read if present.
        config: Runnable config; `configurable["max_reasoning_steps"]`
            (default 10) is the step budget.

    Returns:
        `END` when the reviewer supplied a final answer or `steps` exceeds the
        budget, otherwise `"research"`.
    """
    # "configurable" is an optional key of RunnableConfig, so do not index it
    # directly; fall back to an empty mapping when it is absent.
    configurable = (config or {}).get("configurable") or {}
    max_reasoning_steps = configurable.get("max_reasoning_steps", 10)

    response = state.get("response")
    if response and response.answer:
        return END
    if state.get("steps", 1) > max_reasoning_steps:
        return END
    return "research"

# Reviewer chain: render `reflection_prompt`, then ask the model for a structured `Response`.
reflection_chain = PromptTemplate.from_template(reflection_prompt) | llm.with_structured_output(Response)

def _reflection_step(state: ReflectionAgentState):
    """Run the reviewer chain on the current state and record one more step."""
    return {"response": reflection_chain.invoke(state), "steps": 1}


def _research_start(state: ReflectionAgentState):
  """Produce the first answer by running the research agent on the graph state."""
  agent_output = research_agent.invoke(state)
  final_message = agent_output["messages"][-1]
  return {"answer": final_message.content}


def _research(state: ReflectionAgentState):
  """Revise the previous answer using the critique from the last reflection."""
  # Copy the fields the critique prompt needs, then attach the reviewer's critique.
  revision_input = {key: state[key] for key in ("answer", "question", "options")}
  revision_input["feedback"] = state["response"].critique

  agent_output = research_agent_with_critique.invoke(revision_input)
  last_message = agent_output["messages"][-1]
  return {"answer": last_message.content}

In [None]:
# Assemble the reflection loop: research_start -> reflect, then research -> reflect
# repeats until `_should_end` returns END.
builder = StateGraph(ReflectionAgentState)
builder.add_node("research_start", _research_start)
builder.add_node("research", _research)
builder.add_node("reflect", _reflection_step)

builder.add_edge(START, "research_start")
builder.add_edge("research_start", "reflect")
builder.add_edge("research", "reflect")
# `_should_end` decides after each reflection whether to finish or revise.
builder.add_conditional_edges("reflect", _should_end)
graph = builder.compile()


# Render the compiled graph inline as a Mermaid PNG.
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

Let's test it out:

In [None]:
i = 3
question = ds_dict["question"][i]
options = "\n".join([f"{i}. {a}" for i, a in enumerate(ds_dict["choices"][i])])

async for _, event in graph.astream({"question": question, "options": options}, stream_mode=["updates"]):
  print(event)

In [None]:
# Run the same graph once more without streaming and keep its return value.
result = await graph.ainvoke({"question": question, "options": options})

In [None]:
# Display the value returned by `graph.ainvoke` in the previous cell.
result

# Communication through a shared list of messages

In [None]:
from langchain.agents import load_tools
from langgraph.prebuilt import create_react_agent


# Same tool set as earlier in the notebook, loaded again for this section.
research_tools = load_tools(
  tool_names=["ddg-search", "arxiv", "wikipedia"],
  llm=llm
)

# System prompt for the message-based research agent.
# Fixed: the "quesion" typo, and the missing space after "Think step by step."
# that glued two sentences together in the rendered prompt.
system_prompt = (
    "You're a hard-working, curious and creative student. "
    "You're working on an exam question. Think step by step. "
    "Always provide an argumentation for your answer. "
    "Do not assume anything, use available tools to search "
    "for evidence and supporting statements."
)

# Rebind `research_agent`: this variant takes the plain system prompt and no custom state schema.
research_agent = create_react_agent(model=llm, tools=research_tools, prompt=system_prompt)

# Reviewer prompt for the message-based variant. It is appended to the dialogue
# as a human turn, and an empty reply means "no critique".
# Fixed: adjacent literals produced a double space ("a feedback  if needed").
reflection_prompt = (
    "You are a university professor and you're supervising a student who is "
    "working on multiple-choice exam question. "
    "Given the dialogue above, reflect on the answer provided and give feedback "
    "if needed. If you think the final answer is correct, reply with "
    "an empty message. Only provide critique if you think the last answer might "
    "be incorrect or there are reasoning flaws. Do not assume anything, "
    "evaluate only the reasoning the student provided and whether there is "
    "enough evidence for their answer."
)

In [None]:
from langchain_core.prompts import PromptTemplate
from langgraph.types import Command
from langchain_core.runnables import RunnableConfig


# Template for the opening human message; one source line per rendered line.
question_template = PromptTemplate.from_template(
    "QUESTION:\n"
    "{question}\n"
    "\n"
    "ANSWER OPTIONS:\n"
    "{options}\n"
    "\n"
)

def _ask_question(state):
  """Seed the conversation with the rendered question as a human message."""
  rendered = question_template.invoke(state)
  return {"messages": [("human", rendered.text)]}

def _give_feedback(state, config: RunnableConfig) -> Command[Literal["research", END]]:
  """Review the dialogue so far and either request a revision or finish.

  Args:
    state: Graph state; its `messages` list holds the dialogue to review.
    config: Runnable config; `configurable["max_messages"]` (default 20)
      bounds the size of the dialogue that is sent for review.

  Returns:
    A `Command` that routes to "research" with the critique and a follow-up
    request appended to the messages, or to `END` when the dialogue is too
    long or the reviewer replied with empty content.
  """
  # Read the dialogue from this node's own state. The previous version read the
  # notebook-global `event` left over from an earlier streaming cell, so the
  # reviewer saw whatever that global happened to hold.
  messages = state["messages"] + [("human", reflection_prompt)]

  # "configurable" is an optional key of RunnableConfig; do not index it directly.
  configurable = (config or {}).get("configurable") or {}
  max_messages = configurable.get("max_messages", 20)

  if len(messages) > max_messages:
    return Command(update={}, goto=END)

  result = llm.invoke(messages)
  if not result.content:
    # The reflection prompt asks for an empty reply when the answer is acceptable.
    return Command(update={}, goto=END)

  return Command(
      update={"messages": [
          ("assistant", result.content),
          ("human", "Please, address the feedback above and give an answer."),
      ]},
      goto="research",
  )

In [None]:
class ReflectionAgentStateAlternative(MessagesState):
  """Message-based graph state that also carries the question and its options."""
  # Exam question text, rendered by `question_template`.
  question: str
  # Pre-formatted answer options, rendered by `question_template`.
  options: str


# Wire the message-based variant: ask_question -> research -> reflect.
# This rebinds the notebook-level names `builder` and `graph`.
builder = StateGraph(ReflectionAgentStateAlternative)
builder.add_node("ask_question", _ask_question)
builder.add_node("research", research_agent)
builder.add_node("reflect", _give_feedback)

builder.add_edge(START, "ask_question")
builder.add_edge("ask_question", "research")
builder.add_edge("research", "reflect")
# No static edge leaves "reflect": `_give_feedback` returns a `Command` whose
# `goto` names the next node.
graph = builder.compile()


# Render the compiled graph inline as a Mermaid PNG.
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Stream the run and print how many messages each streamed state holds.
# `event` stays bound to the last streamed item; the next cells read it.
async for _, event in graph.astream({"question": question, "options": options}, stream_mode=["values"]):
  print(len(event["messages"]))

In [None]:
# Show the class of each message in the last streamed state (`event` comes from the streaming cell above).
for m in event["messages"]:
  print(type(m))

In [None]:
# Pretty-print the full dialogue from the last streamed state.
for m in event["messages"]:
  m.pretty_print()

# Installation (run this cell first on a fresh runtime)

In [None]:
# Installs the third-party packages used in this notebook. `%pip` (rather than
# `!pip`) installs into the environment of the running kernel.
%pip install --upgrade langsmith langchain-google-genai duckduckgo-search langchain-community langgraph arxiv wikipedia datasets huggingface_hub fsspec