# Atlas Vector Search - LangGraph Integration - RAG Chatbot

This notebook is a companion to the [Build AI Agents with LangGraph](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langgraph/) page. Refer to the page for set-up instructions and detailed explanations.

This notebook walks you through implementing agentic RAG with MongoDB Atlas as the vector database, LangChain for the retrieval tools, and LangGraph to orchestrate the agent workflow.

<a target="_blank" href="https://colab.research.google.com/github/mongodb/docs-notebooks/blob/main/ai-integrations/langgraph.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Set up the environment

In [None]:
# Install (or upgrade) the LangGraph, LangChain, and MongoDB packages that the cells below import
pip install --quiet --upgrade langgraph langgraph-checkpoint-mongodb langchain langchain_mongodb langchain-openai pymongo

In [None]:
import os

# Replace the placeholder with your OpenAI API key. It is exported as an environment
# variable (presumably picked up by the OpenAI clients created below - confirm).
os.environ["OPENAI_API_KEY"] = "<api-key>"
# Replace the placeholder with your Atlas connection string; it is passed to MongoClient below.
MONGODB_URI = "<connection-string>"

## Use Atlas as a vector database

In [None]:
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings
from pymongo import MongoClient

# Open a client for the Atlas cluster and select the sample movies collection
client = MongoClient(MONGODB_URI)
collection = client["sample_mflix"]["embedded_movies"]

# Embedding model used to embed queries for vector search
embedding_model = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    disallowed_special=(),
)

# Build the vector store on top of the collection: document text is read from
# `plot` and the stored vectors from `plot_embedding`
vector_store = MongoDBAtlasVectorSearch(
    collection=collection,
    embedding=embedding_model,
    text_key="plot",
    embedding_key="plot_embedding",
    relevance_score_fn="dotProduct",
)

In [None]:
# Create the vector search index through the vector store's helper method.
# `dimensions` is the number of dimensions of the vector embeddings to be indexed.
vector_store.create_vector_search_index(dimensions=1536)

In [None]:
from langchain_mongodb.index import create_fulltext_search_index
import time

# Create the full-text search index on the `title` field through the helper function
create_fulltext_search_index(
    collection=collection,
    field="title",
    index_name="search_index",
)

# Pause while the index builds (this can take around a minute)
time.sleep(60)

## Define agent tools


In [None]:
from langchain.agents import tool

# Define a vector search tool
@tool
def vector_search(user_query: str) -> str:
    """
    Retrieve information using vector search to answer a user query.
    """
    # NOTE: the docstring above is kept verbatim because `@tool` presumably
    # exposes it to the LLM as the tool description (confirm in LangChain docs).

    # Run a similarity search over the vector store, keeping the 5 closest documents
    top_matches = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    ).invoke(user_query)

    # Render each hit as "<title>: <content>", with a blank line between hits
    return "\n\n".join(
        f"{match.metadata['title']}: {match.page_content}" for match in top_matches
    )

# Test the tool: invoke it directly with a sample question and print the
# concatenated "title: content" string it returns
test_results = vector_search.invoke("What are some movies that take place in the ocean?")
print(test_results)

In [None]:
from langchain_mongodb.retrievers.full_text_search import MongoDBAtlasFullTextSearchRetriever

# Define a full-text search tool
@tool
def full_text_search(user_query: str) -> str:
    """
    Retrieve movie plot content based on the provided title.
    """
    # NOTE: the docstring above is kept verbatim because `@tool` presumably
    # exposes it to the LLM as the tool description (confirm in LangChain docs).

    # Initialize the retriever
    retriever = MongoDBAtlasFullTextSearchRetriever(
        collection=collection,             # MongoDB Collection in Atlas
        search_field="title",              # Name of the field to search
        search_index_name="search_index",  # Name of the search index
        top_k=1,                           # Number of top results to return
    )
    results = retriever.invoke(user_query)

    # Only the first hit is considered. An empty result set must also yield the
    # fallback string: previously the loop body never ran and the function
    # returned None, violating the declared `str` return type.
    top_hit = next(iter(results), None)
    if not top_hit:
        return "Movie not found"
    return top_hit.metadata["fullplot"]
  
# Test the tool by invoking it directly with a sample query
full_text_search.invoke("What is the plot of Titanic?")

## Prepare the LLM

In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI

# Initialize the LLM (no arguments are passed, so the library's default model and settings apply)
llm = ChatOpenAI()

# Create a chat prompt template for the agent, which includes a system prompt and a placeholder for `messages`.
# The system prompt is given as an explicit ("system", template) pair: a bare string
# (parentheses alone do not make a tuple) would be treated as a human message instead.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI chatbot."
            " You are provided with tools to answer questions about movies."
            " Think step-by-step and use these tools to get the information required to answer the user query."
            " Do not re-run tools unless absolutely necessary."
            " If you are not able to get enough information using the tools, reply with I DON'T KNOW."
            " You have access to the following tools: {tool_names}.",
        ),
        # Filled at invocation time with the conversation messages
        MessagesPlaceholder(variable_name="messages"),
    ]
)
 
# Tools the agent is allowed to call
tools = [vector_search, full_text_search]

# Fill the prompt's `tool_names` variable with the comma-separated tool names
prompt = prompt.partial(tool_names=", ".join(t.name for t in tools))

# Bind the tools to the model, then pipe the prompt into the tool-aware model
bind_tools = llm.bind_tools(tools)
llm_with_tools = prompt | bind_tools

In [None]:
# Check the 'name' of the tool that the LLM is calling.
# Here, we expect the LLM to use the 'vector_search' tool.
# The question is passed as a one-item list; `.tool_calls` is read from the returned message.
llm_with_tools.invoke(["What are some movies that take place in the ocean?"]).tool_calls

In [None]:
# Check the 'name' of the tool that the LLM is calling.
# Here, we expect the LLM to use the 'full_text_search' tool.
# The question is passed as a one-item list; `.tool_calls` is read from the returned message.
llm_with_tools.invoke(["What's the plot of Titanic?"]).tool_calls

## Build the graph

In [None]:
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages

# Define the graph state
class GraphState(TypedDict):
    """State passed between the graph's nodes: the list of chat messages."""

    # Conversation messages. Annotated with `add_messages`, which presumably tells
    # LangGraph how to merge the messages a node returns into this list (confirm in LangGraph docs).
    messages: Annotated[list, add_messages]

# Instantiate the graph with `GraphState` as its state type; nodes and edges are added to it below
graph = StateGraph(GraphState)


In [None]:
from typing import Dict, List

# Define the agent node
def agent(state: GraphState) -> Dict[str, List]:
    """Run the tool-aware LLM on the messages currently held in the graph state.

    Args:
        state: Graph state; only its `messages` entry is read.

    Returns:
        A state update whose `messages` entry is a one-item list containing
        the LLM's response.
    """
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

# Add nodes using the `add_node` function
# Register the `agent` function under the node name "agent"
graph.add_node("agent", agent)

In [None]:
from langchain_core.messages import ToolMessage

# Map each tool's name to the tool object so tool calls can be dispatched by name
tools_by_name = {t.name: t for t in tools}

# Define tools node
def tools_node(state: GraphState) -> Dict[str, List]:
    """Execute every tool call requested by the last message in the state.

    Args:
        state: Graph state; the last entry of `messages` must expose `tool_calls`.

    Returns:
        A state update whose `messages` entry holds one `ToolMessage` per
        tool call, in the order the calls were requested.

    Raises:
        KeyError: If a tool call names a tool that is not in `tools_by_name`.
    """
    requested_calls = state["messages"][-1].tool_calls
    tool_messages = [
        ToolMessage(
            # Look the tool up by the call's `name` and invoke it with the call's `args`;
            # the tool's output becomes the message content
            content=tools_by_name[call["name"]].invoke(call["args"]),
            # Tie the result back to the call that requested it
            tool_call_id=call["id"],
        )
        for call in requested_calls
    ]
    return {"messages": tool_messages}

# Register the `tools_node` function under the node name "tools"
graph.add_node("tools", tools_node)

In [None]:
from langgraph.graph import END

# Add an edge from the START node to the `agent` node, so the agent runs first
graph.add_edge(START, "agent")

# Add an edge from the `tools` node to the `agent` node, so tool results go back to the agent
graph.add_edge("tools", "agent")

# Define a conditional edge
def route_tools(state: GraphState):
    """
    Routing function for the conditional edge leaving the agent.

    Returns "tools" when the last message in the state carries at least one
    tool call, and END otherwise.

    Raises:
        ValueError: If the state contains no messages.
    """
    history = state.get("messages", [])
    # Guard clause: there is nothing to route on without at least one message
    if len(history) == 0:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    # The most recent message is the one to inspect for tool calls
    latest = history[-1]
    wants_tools = hasattr(latest, "tool_calls") and len(latest.tool_calls) > 0
    return "tools" if wants_tools else END

# Add a conditional edge from the `agent` node to the `tools` node.
# `route_tools` returns either "tools" or END; the mapping sends each return
# value to the destination of the same name.
graph.add_conditional_edges(
    "agent",
    route_tools,
    {"tools": "tools", END: END},
)

In [None]:
from IPython.display import Image, display

# Compile the graph
app = graph.compile()

# Optionally, visualize the graph. Rendering is best-effort: a failure must not
# stop the notebook, but the reason is reported instead of being silently discarded.
try:
    display(Image(app.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph visualization: {exc}")

## Execute the graph

In [None]:
# Stream outputs from the graph as they pass through its nodes
def execute_graph(user_input: str) -> None:
    """Run the compiled graph on one user message and print each node's output.

    Every output streamed from `app` is printed as it arrives, followed by
    the content of the last message in the final node's output.

    Args:
        user_input: Text of the user's message.
    """
    # Add user input to the messages attribute of the graph state.
    # The role of the message is "user" and the content is `user_input`.
    # (Named `graph_input` so the builtin `input` is not shadowed.)
    graph_input = {"messages": [("user", user_input)]}

    # Remember the most recent node output so the final answer can be printed
    # even though the loop variables go out of use after streaming.
    final_output = None

    # Pass input to the graph and stream the outputs
    for output in app.stream(graph_input):
        for node_name, node_output in output.items():
            print(f"Node {node_name}:")
            print(node_output)
            final_output = node_output

    print("\n---FINAL ANSWER---")
    if final_output is None:
        # The stream yielded nothing; previously this path raised NameError
        print("(no output was produced by the graph)")
        return
    print(final_output["messages"][-1].content)

In [None]:
# Test the graph execution to view end-to-end flow
# (per the tool-call check earlier, this question is expected to use the `vector_search` tool)
execute_graph("What are some movies that take place in the ocean?")

In [None]:
# Test the graph execution to view end-to-end flow
# (a title lookup like this one is expected to use the `full_text_search` tool)
execute_graph("What is the plot of Titanic?")

## Add memory

In [None]:
from langgraph.checkpoint.mongodb import MongoDBSaver

# Initialize a MongoDB checkpointer, reusing the existing `client` connection
checkpointer = MongoDBSaver(client)

# Recompile the graph with the checkpointer; this rebinds `app` to the new compiled graph
app = graph.compile(checkpointer=checkpointer)

In [None]:
# Update the `execute_graph` function to include the `thread_id` argument
def execute_graph(thread_id: str, user_input: str) -> None:
    """Run the compiled graph on one user message within a given thread.

    Every output streamed from `app` is printed as it arrives, followed by
    the content of the last message in the final node's output.

    Args:
        thread_id: Identifier passed to the graph as `configurable.thread_id`.
        user_input: Text of the user's message.
    """
    config = {"configurable": {"thread_id": thread_id}}
    # Named `graph_input` so the builtin `input` is not shadowed
    graph_input = {"messages": [("user", user_input)]}

    # Remember the most recent node output so the final answer can be printed
    # even though the loop variables go out of use after streaming.
    final_output = None

    for output in app.stream(graph_input, config):
        for node_name, node_output in output.items():
            print(f"Node {node_name}:")
            print(node_output)
            final_output = node_output

    print("\n---FINAL ANSWER---")
    if final_output is None:
        # The stream yielded nothing; previously this path raised NameError
        print("(no output was produced by the graph)")
        return
    print(final_output["messages"][-1].content)

In [None]:
# Test graph execution with thread ID "1"
execute_graph("1", "What's the plot of Titanic?")

In [None]:
# Follow-up question to ensure message history works: it reuses thread ID "1",
# so the question only makes sense if the earlier exchange on that thread is remembered
execute_graph("1", "What movies are similar to the one I just asked about?")