# Graph Parallelization

## Run sequentially

In [None]:
from IPython.display import Image, display
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END

# Define the graph state (no reducer yet)
class State(TypedDict):
    """Graph state for the first examples.

    `value` holds node labels. Every node below returns a *list* (e.g. ``["A"]``)
    and the graph is invoked with ``{"value": []}``, so the field is annotated as
    ``list``. No reducer is attached, so each node's update replaces the
    previous value instead of being merged with it.
    """
    value: list

def a(state: State):
    """Node A: log the incoming `value`, then return ``["A"]`` as the update."""
    print(f"Adding 'A' to state {state['value']}")
    return {"value": ["A"]}

def b(state: State):
    """Node B: log the incoming `value`, then return ``["B"]`` as the update."""
    print(f"Adding 'B' to state {state['value']}")
    return {"value": ["B"]}

def c(state: State):
    """Node C: log the incoming `value`, then return ``["C"]`` as the update."""
    print(f"Adding 'C' to state {state['value']}")
    return {"value": ["C"]}

def d(state: State):
    """Node D: log the incoming `value`, then return ``["D"]`` as the update."""
    print(f"Adding 'D' to state {state['value']}")
    return {"value": ["D"]}



builder = StateGraph(State)

# Register the nodes; each function's name doubles as its node id ("a", "b", ...).
for node_fn in (a, b, c, d):
    builder.add_node(node_fn)

# Wire a straight line: START -> a -> b -> c -> d -> END.
linear_path = [START, "a", "b", "c", "d", END]
for upstream, downstream in zip(linear_path, linear_path[1:]):
    builder.add_edge(upstream, downstream)

graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# `value` has no reducer, so each node's update replaces the previous one.
graph.invoke(dict(value=[]))

## Run in parallel

In [None]:
builder = StateGraph(State)

for node_fn in (a, b, c, d):
    builder.add_node(node_fn)

# Fan out from "a" to "b" and "c", then fan back in to "d".
edges = [
    (START, "a"),
    ("a", "b"),
    ("a", "c"),
    ("b", "d"),
    ("c", "d"),
    ("d", END),
]
for upstream, downstream in edges:
    builder.add_edge(upstream, downstream)
graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# This run is expected to fail: "b" and "c" execute in the same superstep and
# both write `value`, but State defines no reducer to merge the two updates.
# Catch and show the error so Restart & Run All continues past the demonstration.
from langgraph.errors import InvalidUpdateError

try:
    graph.invoke({"value": []})
except InvalidUpdateError as exc:
    print(f"{type(exc).__name__}: {exc}")

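The parallel run fails: B and C both write to `value` in the same superstep, and the state has no reducer telling LangGraph how to combine the two updates. Let's redefine the graph's state so that `value` uses `operator.add` as its reducer, which concatenates the lists returned by each node.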

In [None]:
import operator
from typing import Annotated

class State(TypedDict):
    """Graph state whose `value` list accumulates updates.

    The `operator.add` reducer concatenates each node's returned list onto the
    current one, so two nodes writing `value` in the same superstep are merged
    rather than rejected.
    """
    value: Annotated[list, operator.add]

builder = StateGraph(State)

for node_fn in (a, b, c, d):
    builder.add_node(node_fn)

# Same fan-out / fan-in topology as before; only the state schema changed.
edges = [
    (START, "a"),
    ("a", "b"),
    ("a", "c"),
    ("b", "d"),
    ("c", "d"),
    ("d", END),
]
for upstream, downstream in edges:
    builder.add_edge(upstream, downstream)
graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# With the operator.add reducer, every node's list is appended to `value`.
graph.invoke(dict(value=[]))

Nodes B and C run concurrently in the same superstep, so they share a single transactional context: if either of them fails, neither node's update is applied to the state.

## Extend the B branch

In [None]:
def b_1(state: State):
    """First step of the B branch: log the incoming `value`, then contribute ``["B_1"]``."""
    print(f"Adding 'B_1' to state {state['value']}")
    update = {"value": ["B_1"]}
    return update

def b_2(state: State):
    """Second step of the B branch: log the incoming `value`, then contribute ``["B_2"]``."""
    print(f"Adding 'B_2' to state {state['value']}")
    update = {"value": ["B_2"]}
    return update

builder = StateGraph(State)

for node_fn in (a, b_1, b_2, c, d):
    builder.add_node(node_fn)

# "d" has two separate incoming edges (from "b_2" and from "c"),
# so each of them triggers it independently.
edges = [
    (START, "a"),
    ("a", "b_1"),
    ("a", "c"),
    ("b_1", "b_2"),
    ("b_2", "d"),
    ("c", "d"),
    ("d", END),
]
for upstream, downstream in edges:
    builder.add_edge(upstream, downstream)
graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# The printed log shows which nodes share a superstep.
graph.invoke(dict(value=[]))

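Only B_1 and C run concurrently in the same superstep. D has two separate incoming edges (one from B_2, one from C), so it is triggered as soon as C finishes, running alongside B_2. It is then triggered again after B_2 finishes, so D executes twice.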

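Let's make D wait until both the B_1 → B_2 branch and C have completed. Passing a list of source nodes to `add_edge` creates a single join edge, so D runs only once all of those nodes have finished.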

In [None]:
builder = StateGraph(State)

for node_fn in (a, b_1, b_2, c, d):
    builder.add_node(node_fn)

builder.add_edge(START, "a")
# Fan out from "a" into the B branch (b_1 -> b_2) and the C branch.
for target in ("b_1", "c"):
    builder.add_edge("a", target)
builder.add_edge("b_1", "b_2")
# A list of sources creates a single join edge: "d" waits for both "b_2" and "c".
builder.add_edge(["b_2", "c"], "d")
builder.add_edge("d", END)
graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Same empty starting state; compare the log with the previous run.
graph.invoke(dict(value=[]))

## Conditional Branching

In [None]:
import operator
from typing import Annotated, Sequence

from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END


class State(TypedDict):
    """State for the conditional-branching graph.

    `value` accumulates node labels through the `operator.add` reducer;
    `route` is supplied at invoke time and read by the routing function.
    """
    value: Annotated[list, operator.add]
    route: str

def a(state: State):
    """Node A: log the incoming `value`, then contribute ``["A"]``."""
    print(f"Adding 'A' to state {state['value']}")
    update = {"value": ["A"]}
    return update

def b(state: State):
    """Node B: log the incoming `value`, then contribute ``["B"]``."""
    print(f"Adding 'B' to state {state['value']}")
    update = {"value": ["B"]}
    return update

def c(state: State):
    """Node C: log the incoming `value`, then contribute ``["C"]``."""
    print(f"Adding 'C' to state {state['value']}")
    update = {"value": ["C"]}
    return update

def d(state: State):
    """Node D: log the incoming `value`, then contribute ``["D"]``."""
    print(f"Adding 'D' to state {state['value']}")
    update = {"value": ["D"]}
    return update

def e(state: State):
    """Node E: log the incoming `value`, then contribute ``["E"]``."""
    print(f"Adding 'E' to state {state['value']}")
    update = {"value": ["E"]}
    return update


builder = StateGraph(State)
# Register all five nodes; function names become node ids.
for node_fn in (a, b, c, d, e):
    builder.add_node(node_fn)

builder.add_edge(START, "a")


def route_bc_or_cd(state: State) -> Sequence[str]:
    """Choose which intermediate nodes run after "a".

    Returns ``["b", "c"]`` when ``state["route"] == "bc"``; any other route
    value selects ``["c", "d"]``.
    """
    return ["b", "c"] if state["route"] == "bc" else ["c", "d"]
    
# Every node route_bc_or_cd can pick; passed as the path map so the
# rendered graph shows all possible branches out of "a".
intermediates = ["b", "c", "d"]
builder.add_conditional_edges("a", route_bc_or_cd, intermediates)

# Whichever intermediates run, they all feed "e", which ends the run.
for node in intermediates:
    builder.add_edge(node, "e")
builder.add_edge("e", END)

graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# route="bc" makes route_bc_or_cd send "a" on to "b" and "c".
graph.invoke(dict(value=[], route="bc"))

In [None]:
# Any route other than "bc" (here "cd") sends "a" on to "c" and "d".
graph.invoke(dict(value=[], route="cd"))

## Practical example

In [None]:
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools import TavilySearchResults


from langchain_openai import ChatOpenAI
# Chat model used by generate_answer below.
LLM_MODEL = "gpt-4o-mini"
llm = ChatOpenAI(model=LLM_MODEL)

from typing import Any


class State(TypedDict):
    """State for the web + Wikipedia research graph.

    question: the user's question, supplied at invoke time.
    answer: whatever `llm.invoke` returns inside generate_answer (a chat message
        whose `.content` is read afterwards), so it is not annotated as `str`.
    context: formatted documents; `operator.add` concatenates the lists returned
        by the two search nodes that run in parallel.
    """
    question: str
    answer: Any
    context: Annotated[list, operator.add]


def search_web(state):
    """Run a Tavily web search for ``state["question"]`` and return the hits as context.

    Each hit is wrapped in a ``<Document href="...">`` tag that carries its URL,
    so the answering model can cite sources as its system prompt asks.

    Returns:
        ``{"context": [...]}``: a list of formatted document strings, merged
        into the shared context by the state's ``operator.add`` reducer.

    Raises:
        RuntimeError: if the search tool returns an error string instead of results.
    """
    tavily_search = TavilySearchResults(max_results=5)
    search_docs = tavily_search.invoke(state['question'])

    # NOTE(review): TavilySearchResults appears to return an error string rather
    # than a list of dicts when the API call fails. Surface that clearly instead
    # of failing later with an opaque TypeError on doc["content"].
    if isinstance(search_docs, str):
        raise RuntimeError(f"Tavily search failed: {search_docs}")

    results = [
        f'<Document href="{doc.get("url", "")}">\n{doc["content"]}\n</Document>'
        for doc in search_docs
    ]

    return {"context": results}


def search_wikipedia(state):
    """Load up to five Wikipedia pages for ``state["question"]`` and return them as context.

    Each page is wrapped in a ``<Document source="...">`` tag carrying the
    loader's ``source`` metadata (when present), so the answering model can
    cite it.

    Returns:
        ``{"context": [...]}``: a list of formatted document strings, merged
        into the shared context by the state's ``operator.add`` reducer.
    """
    search_docs = WikipediaLoader(query=state['question'], load_max_docs=5).load()

    results = [
        f'<Document source="{doc.metadata.get("source", "")}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    ]

    return {"context": results}


def generate_answer(state):
    """Answer ``state["question"]`` with the LLM, grounded in ``state["context"]``.

    Sends three messages to the module-level `llm`: fixed instructions, the
    collected documents, and the user's question.

    Returns:
        ``{"answer": <value returned by llm.invoke>}``. `answer` has no reducer,
        so this replaces any previous value rather than appending to it.
    """
    # Adjacent string literals keep the source-code indentation out of the
    # prompt text that is actually sent to the model.
    system_message = SystemMessage(content=(
        "You are an AI assistant that answers questions based on the provided documents.\n"
        "Guidelines:\n"
        "    - Provide direct, concise, and accurate answers.\n"
        "    - When possible, cite the relevant document or URL.\n"
        "    - If multiple documents contain relevant information, synthesize the best answer.\n"
        "\n"
        "If a document contains conflicting information, mention both perspectives."
    ))

    # One bullet per retrieved document (web and Wikipedia results share this list).
    formatted_docs = "\n".join(f"- {doc}" for doc in state["context"])

    system_context = SystemMessage(
        content=f"Use the following documents as context for your response:\n\n{formatted_docs}"
    )

    answer = llm.invoke(
        [system_message, system_context, HumanMessage(content=state["question"])]
    )

    return {"answer": answer}

# Build the research graph: both searches start from START and run in the same
# superstep; generate_answer joins on their combined context.
builder = StateGraph(State)

# Add nodes
builder.add_node("search_web", search_web)
builder.add_node("search_wikipedia", search_wikipedia)
builder.add_node("generate_answer", generate_answer)

# Flow
builder.add_edge(START, "search_wikipedia")
builder.add_edge(START, "search_web")
# A single join edge: generate_answer runs only after both searches finish,
# even if one branch later grows extra steps.
builder.add_edge(["search_wikipedia", "search_web"], "generate_answer")
builder.add_edge("generate_answer", END)
graph = builder.compile()

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
question = "should i invest in AI stocks now"
result = graph.invoke({"question": question})
# `answer` holds the object returned by llm.invoke; print only its text content.
print(result['answer'].content)

In [None]:
# Display the full final state returned by graph.invoke (question, context, answer).
result