In [None]:
# --- SUB-GRAPHS: SALES RESEARCH ASSISTANT EDITION ---

%pip install -U langchain_openai langgraph
import os, getpass

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

# Ask for the two API keys this notebook needs (skipped when already set).
_set_env("OPENAI_API_KEY")
_set_env("LANGSMITH_API_KEY")

# Set the LangSmith tracing flag and the project name runs are logged under.
os.environ.update(
    LANGSMITH_TRACING="true",
    LANGSMITH_PROJECT="langchain-academy",
)

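# --- Overview ---
# This notebook demonstrates sub-graphs with a sales-research assistant.
# The main graph handles high-level coordination: it generates products for a
# category, fans out one regional-analysis run per product, and prints the
# combined report.
# The regional-analysis subgraph does the per-product work:
#   - generates a list of regions to analyze
#   - analyzes the product's sales performance in each region
#   - summarizes those regional analyses into one product report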
import operator
from typing import Annotated

from IPython.display import Image, display
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel
from typing_extensions import TypedDict

# --- Model ---
# Shared chat model for every LLM-calling node; temperature=0 requests the
# least random sampling.
model = ChatOpenAI(temperature=0, model="gpt-4o")

# --- Prompts ---
# Placeholder: {category}
products_prompt = "Generate 3 major products from the {category} category."
# No placeholders; used verbatim.
region_prompt = "List 3 major global regions where these products are sold."
# Placeholders: {product}, {region}
performance_prompt = (
    "Analyze product {product} in region {region}.\n"
    "Write 2 bullet points on sales performance and growth opportunities."
)
# Placeholder: {product}; the regional analyses are appended after it.
summary_prompt = (
    "Combine these regional analyses into a final performance overview for {product}."
)

# --- Schemas ---
# Structured-output schemas. Each is passed to `model.with_structured_output(...)`
# and the nodes read the parsed fields (`.products`, `.regions`, `.report`).
# NOTE(review): comments rather than docstrings on purpose — a docstring is
# likely forwarded to the model as the schema description, which would change
# the request. Confirm before adding one.
class Products(BaseModel):
    # Product names for the requested category (the prompt asks for 3).
    products: list[str]

class Regions(BaseModel):
    # Region names to analyze (the prompt asks for 3).
    regions: list[str]

class Performance(BaseModel):
    # Free-text analysis; used for both per-region and per-product summaries.
    report: str

class OverallState(TypedDict):
    """State of the main (category-level) graph."""

    # Category supplied by the caller, e.g. "Electronics".
    category: str
    # Product names produced by generate_products (overwritten, no reducer).
    products: list[str]
    # One "<product>: <report>" entry per product; operator.add concatenates
    # the lists returned by the parallel regional_analysis branches.
    product_reports: Annotated[list[str], operator.add]

# --- Subgraph: Regional Analysis ---
# This subgraph analyzes one product across several regions and summarizes the results.

class RegionalState(TypedDict):
    """State of the regional-analysis subgraph for a single product."""

    # Product being analyzed; supplied by whoever invokes the subgraph.
    product: str
    # Regions chosen by generate_regions (overwritten, no reducer).
    regions: list[str]
    # One "<region>: <analysis>" entry per analyze_region branch; operator.add
    # concatenates the lists returned by the parallel branches.
    regional_reports: Annotated[list[str], operator.add]
    # Written by summarize_product and read by regional_analysis from the
    # subgraph's result. It must be declared here: LangGraph does not store
    # updates to keys missing from the state schema. The reducer matches
    # OverallState's key of the same name.
    product_reports: Annotated[list[str], operator.add]

def generate_regions(state: RegionalState):
    """Ask the model which regions to analyze for ``state["product"]``.

    Returns the partial update ``{"regions": [...]}``; ``continue_to_regions``
    then fans out one analysis per region.
    """
    # region_prompt has no placeholder, so name the product explicitly.
    # Without it, "these products" has no referent and every product would
    # get the same context-free region list.
    prompt = f"{region_prompt}\n\nProduct: {state['product']}"
    response = model.with_structured_output(Regions).invoke(prompt)
    return {"regions": response.regions}

def analyze_region(state: dict):
    """Write a short sales-performance analysis of one product in one region.

    ``state`` is the payload sent by ``continue_to_regions`` and must hold
    ``"product"`` and ``"region"``. Returns a one-element ``regional_reports``
    list so the subgraph's ``operator.add`` reducer can merge every region.
    """
    request = performance_prompt.format(product=state["product"], region=state["region"])
    analysis = model.with_structured_output(Performance).invoke(request)
    region = state["region"]
    return {"regional_reports": [f"{region}: {analysis.report}"]}

from langgraph.types import Send

def continue_to_regions(state: RegionalState):
    """Fan out one ``Send`` to ``analyze_region`` per generated region.

    Each payload carries the product plus a single region, which is exactly
    what ``analyze_region`` reads.
    """
    return [
        Send("analyze_region", {"product": state["product"], "region": region})
        for region in state["regions"]
    ]

def summarize_product(state: RegionalState):
    """Merge all regional analyses for the product into one report.

    The regional reports are joined one per line and appended to the
    formatted summary prompt. Returns ``{"product_reports": ["<product>: <report>"]}``.
    That key is only kept (and returned by ``subgraph_app.invoke``) if the
    subgraph's state schema declares it.
    """
    regional_text = "\n".join(state["regional_reports"])
    product = state["product"]
    instructions = summary_prompt.format(product=product)
    summary = model.with_structured_output(Performance).invoke(f"{instructions}\n\n{regional_text}")
    return {"product_reports": [f"{product}: {summary.report}"]}

# Build the subgraph for regional analysis
def _build_regional_subgraph():
    """Wire generate_regions -> (fan-out) analyze_region -> summarize_product."""
    builder = StateGraph(RegionalState)
    builder.add_node("generate_regions", generate_regions)
    builder.add_node("analyze_region", analyze_region)
    builder.add_node("summarize_product", summarize_product)
    builder.add_edge(START, "generate_regions")
    # continue_to_regions returns Send objects; "analyze_region" is the only target.
    builder.add_conditional_edges("generate_regions", continue_to_regions, ["analyze_region"])
    builder.add_edge("analyze_region", "summarize_product")
    builder.add_edge("summarize_product", END)
    return builder


subgraph = _build_regional_subgraph()
subgraph_app = subgraph.compile()

# --- Main Graph (calls subgraphs) ---
def generate_products(state: OverallState):
    """Ask the model for the products to research in ``state["category"]``.

    Returns ``{"products": [...]}``; ``call_subgraph`` fans out over that list.
    """
    request = products_prompt.format(category=state["category"])
    extractor = model.with_structured_output(Products)
    return {"products": extractor.invoke(request).products}

def call_subgraph(state: OverallState):
    """Fan out one ``Send`` to the ``regional_analysis`` node per product.

    Despite the name, this does not run the subgraph itself. It routes each
    product (as ``{"product": name}``) to ``regional_analysis``, which does.
    """
    return [
        Send("regional_analysis", {"product": product})
        for product in state["products"]
    ]

# Integrate subgraph as a callable node
def regional_analysis(state: dict):
    """Run the regional-analysis subgraph for a single product.

    ``state`` is the ``Send`` payload from ``call_subgraph`` and holds only
    ``"product"``. The subgraph starts with empty ``regions`` and
    ``regional_reports``. Its ``product_reports`` output is forwarded to the
    parent graph, whose ``operator.add`` reducer collects one entry per product.
    """
    initial_state = {
        "product": state["product"],
        "regions": [],
        "regional_reports": [],
    }
    sub_result = subgraph_app.invoke(initial_state)
    # NOTE(review): this key only appears in the result if the subgraph's
    # state schema declares it; otherwise the lookup raises KeyError.
    return {"product_reports": sub_result["product_reports"]}

def final_summary(state: OverallState):
    """Print every collected product report and finish the run.

    Reads ``state["product_reports"]`` (the per-product entries merged by the
    ``operator.add`` reducer), prints a banner, each report and a completion
    line, and returns an empty update.
    """
    print("\n=== FINAL CATEGORY REPORT ===\n")
    for report in state["product_reports"]:
        print(report)
    print("\n✅ Subgraph-driven research complete.")
    return {}

# --- Compose main graph ---
# Main graph: generate products, fan out one regional_analysis per product,
# then print the combined report.
graph = StateGraph(OverallState)
graph.add_node("generate_products", generate_products)
graph.add_node("regional_analysis", regional_analysis)
graph.add_node("final_summary", final_summary)
graph.add_edge(START, "generate_products")
# call_subgraph returns one Send per product; "regional_analysis" is the only target.
graph.add_conditional_edges("generate_products", call_subgraph, ["regional_analysis"])
# final_summary runs after the regional_analysis branches.
graph.add_edge("regional_analysis", "final_summary")
graph.add_edge("final_summary", END)

app = graph.compile()
# Use display(): Jupyter only renders a bare expression automatically when it
# is the last statement of a cell, and this cell continues with the run below.
display(Image(app.get_graph().draw_mermaid_png()))

# --- RUN ---
# Stream the run for one category and print each item the stream yields.
for chunk in app.stream({"category": "Electronics"}):
    print(chunk)
