# In[None]:
"""
Production-style Tool Orchestration:
- Hugging Face embeddings for semantic tool retrieval (LangChain)
- LangGraph workflow orchestration
- Optional LlamaIndex ObjectIndex retriever

Install (minimal):
  pip install langchain-core langchain-community langchain-huggingface langgraph faiss-cpu

Optional (for LlamaIndex retriever):
  pip install llama-index

Optional (LangSmith tracing):
  export LANGCHAIN_TRACING_V2=true
  export LANGCHAIN_API_KEY=...
  export LANGCHAIN_PROJECT="tool-orchestration"
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol, Tuple, TypedDict, Literal

# --- LangChain core types
from langchain_core.documents import Document
from langchain_core.tools import BaseTool, StructuredTool, Tool

# --- HF embeddings (LangChain's Hugging Face integration)
from langchain_huggingface import HuggingFaceEmbeddings

# --- Vector store
from langchain_community.vectorstores import FAISS

# --- LangGraph
from langgraph.graph import StateGraph, START, END


# ============================================================
# 1) Tools (examples)
# ============================================================

def weather_api(city: str) -> str:
    """Return a canned (demo) weather report for ``city``; no external service is called."""
    report = f"[demo] Weather in {city}: Sunny, 30C"
    return report

def fx_rate(base: str, quote: str) -> str:
    """Return a demo FX quote for the ``base``/``quote`` pair.

    The rate is the fixed placeholder 3.67 regardless of the currencies passed.
    """
    message = f"[demo] FX rate {base}/{quote} = 3.67"
    return message

def sql_query(query: str) -> str:
    """Demo SQL tool: echo ``query`` back instead of executing it.

    NOTE(review): a real implementation would receive model-generated SQL and
    should run it on a read-only connection with parameterized inputs.
    """
    echoed = f"[demo] Executed SQL: {query}"
    return echoed


# The tools are built with StructuredTool.from_function, so each tool's argument
# schema is inferred from the wrapped function's signature. A plain ``Tool`` is
# single-input: invoking it with a multi-key dict such as
# {"base": "AED", "quote": "USD"} raises ToolException ("Too many arguments to
# single-input tool"), which made fx_rate unusable. StructuredTool maps dict keys
# onto parameters. A single-parameter StructuredTool still accepts a bare string.
WEATHER_TOOL = StructuredTool.from_function(
    func=weather_api,
    name="weather_api",
    description="Get current weather for a city. Inputs: city (string).",
)

FX_TOOL = StructuredTool.from_function(
    func=fx_rate,
    name="fx_rate",
    description="Get FX exchange rate between two currencies. Inputs: base, quote (ISO codes).",
)

SQL_TOOL = StructuredTool.from_function(
    func=sql_query,
    name="sql_query",
    description="Execute an analytics SQL query (select/aggregate/join) over a database. Input: query (string).",
)


# ============================================================
# 2) Tool Registry (stores tools + descriptions)
# ============================================================

@dataclass
class ToolSpec:
    """A registered tool together with the metadata used when indexing it for retrieval."""

    tool: BaseTool  # the LangChain tool object that execute_node invokes
    tags: Tuple[str, ...] = ()  # free-form keywords added to the tool's retrieval text
    # production metadata you might want to include
    cost: float = 1.0  # relative cost hint; its unit is not defined in this file
    latency_ms: int = 100  # expected latency hint (milliseconds, per the field name)

class ToolRegistry:
    """In-memory catalogue of tool specs, keyed by ``spec.tool.name``.

    Registering a spec whose tool name already exists replaces the earlier spec.
    """

    def __init__(self) -> None:
        self._tools: Dict[str, ToolSpec] = {}

    def register(self, spec: ToolSpec) -> None:
        """Store ``spec`` under its tool name, replacing any existing entry."""
        self._tools[spec.tool.name] = spec

    def get(self, name: str) -> ToolSpec:
        """Return the spec registered under ``name``.

        Raises:
            KeyError: if no tool with that name is registered. The message lists
                the registered names to make misspelled/unregistered tools obvious.
        """
        try:
            return self._tools[name]
        except KeyError:
            known = ", ".join(sorted(self._tools)) or "<none>"
            raise KeyError(f"Unknown tool {name!r}; registered tools: {known}") from None

    def all(self) -> List[ToolSpec]:
        """Return a new list of all specs, in the dict's insertion order."""
        return list(self._tools.values())

    def tool_objects(self) -> List[BaseTool]:
        """Return the bare tool objects, in the same order as ``all()``."""
        return [spec.tool for spec in self._tools.values()]


# ============================================================
# 3) Retriever interface + (A) LangChain VectorStore retriever
# ============================================================

@dataclass
class RetrievedTool:
    """One retrieval hit: a registered tool's name and its similarity score."""

    name: str  # the tool's name, i.e. its ToolRegistry key
    score: float  # retriever-specific relevance score; higher means more relevant

class ToolRetriever(Protocol):
    """Structural interface for objects that map a query to ranked tool candidates."""

    def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.2) -> List[RetrievedTool]:
        """Return up to ``top_k`` tools whose score is at least ``min_score``.

        The implementations in this module return hits sorted by descending score.
        """
        ...

class LangChainVectorToolRetriever:
    """
    Semantic tool retriever backed by a LangChain FAISS vector store.

    Production pattern:
    - Embed each tool's name, description and tags using HuggingFaceEmbeddings
    - Store in FAISS (one Document per tool, identified by ``tool_name`` metadata)
    - Similarity search at runtime (returns relevance scores)

    The index is built once in ``__init__``. Tools registered afterwards are not
    searchable until ``_build_index()`` is called again.
    """

    def __init__(
        self,
        registry: ToolRegistry,
        *,
        hf_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    ) -> None:
        self.registry = registry
        # LangChain converts FAISS's default L2 distance into a relevance score
        # under the assumption of unit-length embeddings. Normalizing explicitly
        # keeps scores, and therefore ``min_score``, on that scale for any model,
        # not only for models that happen to normalize their output.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=hf_model_name,
            encode_kwargs={"normalize_embeddings": True},
        )
        self.vstore: Optional[FAISS] = None
        self._build_index()

    def _tool_doc(self, spec: ToolSpec) -> Document:
        """Build the Document that represents ``spec`` in the index.

        Only semantic fields (name, description, tags) are embedded. Operational
        numbers such as cost and latency add noise to the vector without
        describing what the tool does. They are attached as metadata instead, where
        they remain available for filtering or re-ranking.
        """
        txt = (
            f"name: {spec.tool.name}\n"
            f"description: {spec.tool.description or ''}\n"
            f"tags: {', '.join(spec.tags)}\n"
        )
        return Document(
            page_content=txt,
            metadata={
                "tool_name": spec.tool.name,
                "cost": spec.cost,
                "latency_ms": spec.latency_ms,
            },
        )

    def _build_index(self) -> None:
        """(Re)build the FAISS index from the registry; ``vstore`` is None when the registry is empty."""
        docs = [self._tool_doc(spec) for spec in self.registry.all()]
        self.vstore = FAISS.from_documents(docs, self.embeddings) if docs else None

    def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.2) -> List[RetrievedTool]:
        """Return up to ``top_k`` tools with relevance >= ``min_score``, best first.

        Returns an empty list when nothing was indexed or ``top_k`` is not positive.
        """
        if self.vstore is None or top_k <= 0:
            return []

        # returns List[Tuple[Document, relevance_score]]
        results = self.vstore.similarity_search_with_relevance_scores(query, k=top_k)

        out: List[RetrievedTool] = []
        for doc, score in results:
            name = doc.metadata.get("tool_name")
            relevance = float(score)
            # Skip documents without a tool name (not produced by _tool_doc).
            if name and relevance >= min_score:
                out.append(RetrievedTool(name=name, score=relevance))

        out.sort(key=lambda hit: hit.score, reverse=True)
        return out


# ============================================================
# 4) Optional retriever (B) LlamaIndex ObjectIndex retriever
# ============================================================

class LlamaIndexObjectToolRetriever:
    """
    Optional: use a LlamaIndex ObjectIndex to index/retrieve tools.
    This is helpful when you want LlamaIndex-native tool retrieval patterns.

    Each tool is indexed as the plain string
    ``"<name>: <description> tags=<tag1,tag2>"`` and embedded with the same
    Hugging Face model the LangChain retriever uses. ``retrieve`` queries the
    ObjectIndex's underlying vector index because that path yields similarity
    scores. The ObjectIndex's own retriever returns the stored objects without scores.
    """

    def __init__(
        self,
        registry: ToolRegistry,
        *,
        hf_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        top_k_default: int = 5,
    ) -> None:
        # Lazy imports so the file runs without llama-index installed
        from llama_index.core.bridge.pydantic import PrivateAttr
        from llama_index.core.embeddings import BaseEmbedding
        from llama_index.core.objects import ObjectIndex

        class _HFEmbedAdapter(BaseEmbedding):
            """Let LlamaIndex call a LangChain HuggingFaceEmbeddings model."""

            # BaseEmbedding is a pydantic model, which rejects assignment to
            # undeclared attributes, so the wrapped model lives in a private attr.
            _lc: Any = PrivateAttr()

            def __init__(self, lc_embeddings: HuggingFaceEmbeddings, **kwargs: Any) -> None:
                super().__init__(**kwargs)
                self._lc = lc_embeddings

            def _get_query_embedding(self, query: str) -> List[float]:
                return self._lc.embed_query(query)

            def _get_text_embedding(self, text: str) -> List[float]:
                return self._lc.embed_documents([text])[0]

            async def _aget_query_embedding(self, query: str) -> List[float]:
                return self._get_query_embedding(query)

            async def _aget_text_embedding(self, text: str) -> List[float]:
                return self._get_text_embedding(text)

        self.registry = registry
        self.top_k_default = top_k_default
        self.lc_embeddings = HuggingFaceEmbeddings(model_name=hf_model_name)
        self.li_embed = _HFEmbedAdapter(self.lc_embeddings)

        # Index plain strings. For arbitrary objects ObjectIndex's default node
        # mapping uses str(obj) as the node text, so a string round-trips
        # unchanged. Each retrieved node's text then maps back to its tool exactly.
        self._name_by_text: Dict[str, str] = {}
        for spec in registry.all():
            text = f"{spec.tool.name}: {spec.tool.description or ''} tags={','.join(spec.tags)}"
            self._name_by_text[text] = spec.tool.name

        self.object_index = ObjectIndex.from_objects(
            list(self._name_by_text), embed_model=self.li_embed
        )
        # Kept for callers who want LlamaIndex's ObjectRetriever directly. It
        # yields the indexed strings themselves, without scores.
        self.retriever = self.object_index.as_retriever(similarity_top_k=self.top_k_default)

    def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.0) -> List[RetrievedTool]:
        """Return up to ``top_k`` registered tools scoring at least ``min_score``, best first.

        Scores are the underlying vector index's similarity scores (higher is better).
        Returns an empty list when no tools were indexed or ``top_k`` is not positive.
        """
        if top_k <= 0 or not self._name_by_text:
            return []

        # ObjectRetriever.retrieve() returns bare objects (no scores). Assigning
        # ``similarity_top_k`` on it is not forwarded to the retriever it wraps.
        # So query the underlying vector index with the requested top_k instead.
        hits = self.object_index.index.as_retriever(similarity_top_k=top_k).retrieve(query)

        # Only report tools that are still registered; build the lookup set once.
        registered = {ts.tool.name for ts in self.registry.all()}

        out: List[RetrievedTool] = []
        for hit in hits:
            tool_name = self._tool_name_for(getattr(hit, "node", None))
            score = float(getattr(hit, "score", None) or 0.0)
            if tool_name in registered and score >= min_score:
                out.append(RetrievedTool(name=tool_name, score=score))

        out.sort(key=lambda x: x.score, reverse=True)
        return out

    def _tool_name_for(self, node: Any) -> Optional[str]:
        """Map a retrieved LlamaIndex node back to a tool name, or None if unrecognized."""
        if node is None:
            return None
        meta = getattr(node, "metadata", None) or {}
        if meta.get("tool_name"):
            return meta["tool_name"]

        get_content = getattr(node, "get_content", None)
        text = get_content() if callable(get_content) else getattr(node, "text", None)
        if not isinstance(text, str):
            return None
        if text in self._name_by_text:
            return self._name_by_text[text]
        # Fallback: every indexed string starts with "<tool_name>:".
        head, sep, _ = text.partition(":")
        return head.strip() if sep else None


# ============================================================
# 5) LangGraph orchestration (retrieve -> choose -> execute)
# ============================================================

class AgentState(TypedDict):
    """State dict threaded through the retrieve -> choose -> execute graph nodes."""

    user_query: str                  # natural-language request used to retrieve tools
    # In real systems, you’ll also keep message history + observations
    tool_args: Dict[str, Any]        # provided by caller (or produced by an LLM in another node)
    candidates: List[RetrievedTool]  # retrieval hits, highest score first
    chosen_tool: Optional[str]       # name of the tool to run; None when no candidate was found
    result: Optional[Any]            # value returned by the tool invocation
    error: Optional[str]             # failure message, or None when the last step succeeded

def retrieve_node(
    state: AgentState,
    retriever: ToolRetriever,
    *,
    top_k: int = 5,
    min_score: float = 0.2,
) -> AgentState:
    """Fill ``state["candidates"]`` with tools matching ``state["user_query"]``.

    Args:
        state: graph state; mutated in place and returned.
        retriever: any object implementing ``ToolRetriever``.
        top_k: maximum number of candidates to request (previously hard-coded to 5).
        min_score: minimum relevance score to keep (previously hard-coded to 0.2).
    """
    state["candidates"] = retriever.retrieve(
        state["user_query"], top_k=top_k, min_score=min_score
    )
    return state

def choose_node(state: AgentState) -> AgentState:
    """
    Select the tool to run: deterministically, the highest-ranked candidate.

    Production note:
    - In real agentic systems, an LLM chooses among candidates and generates args.
    - For a coding round, keep it deterministic: pick the top candidate.

    Sets ``chosen_tool`` (None when there are no candidates). Sets ``error`` to
    "No suitable tool found." when nothing usable was chosen, otherwise clears it.
    """
    ranked = state["candidates"]
    if not ranked:
        state["chosen_tool"] = None
        state["error"] = "No suitable tool found."
        return state

    state["chosen_tool"] = ranked[0].name
    # An empty tool name is treated the same as "no tool".
    state["error"] = None if state["chosen_tool"] else "No suitable tool found."
    return state

def execute_node(state: AgentState, registry: ToolRegistry) -> AgentState:
    """Invoke the chosen tool with ``state["tool_args"]`` and record the outcome.

    - No tool chosen: the state is returned untouched (choose_node already set ``error``).
    - Success: ``result`` holds the tool's return value and ``error`` is None.
    - Failure (unknown tool or tool error): ``result`` is reset to None, so no
      stale value survives, and ``error`` holds a message.

    This node is the graph's error boundary for tool execution: exceptions are
    reported in the state instead of aborting the graph run.
    """
    name = state.get("chosen_tool")
    if not name:
        return state

    # The whole args dict is passed as the single tool input; tools with a schema
    # (e.g. StructuredTool) map its keys onto parameters. None is treated as {}.
    tool_input = state.get("tool_args") or {}
    try:
        try:
            tool = registry.get(name).tool
        except KeyError:
            # Give unknown tools a readable message instead of KeyError's bare "'name'".
            raise LookupError(f"Unknown tool: {name!r}") from None
        result = tool.invoke(tool_input)
    except Exception as e:
        state["result"] = None
        state["error"] = str(e)
    else:
        state["result"] = result
        state["error"] = None
    return state

def build_graph(registry: ToolRegistry, retriever: ToolRetriever):
    """Compile the linear LangGraph workflow START -> retrieve -> choose -> execute -> END.

    ``retriever`` and ``registry`` are bound into the retrieve/execute nodes via
    closures, so the compiled graph only needs an ``AgentState`` as input.
    """
    workflow = StateGraph(AgentState)

    steps = {
        "retrieve": lambda s: retrieve_node(s, retriever),
        "choose": choose_node,
        "execute": lambda s: execute_node(s, registry),
    }
    for step_name, step_fn in steps.items():
        workflow.add_node(step_name, step_fn)

    # Chain consecutive stops along the route with plain edges.
    route = [START, *steps, END]
    for src, dst in zip(route, route[1:]):
        workflow.add_edge(src, dst)

    return workflow.compile()


# ============================================================
# 6) Demo
# ============================================================

def main():
    """Register the demo tools, then run two example queries through the graph.

    For each query, prints the chosen tool, the tool result and any error.
    """
    registry = ToolRegistry()
    demo_specs = (
        ToolSpec(tool=WEATHER_TOOL, tags=("weather", "forecast"), cost=1.0, latency_ms=120),
        ToolSpec(tool=FX_TOOL, tags=("finance", "fx"), cost=0.5, latency_ms=80),
        ToolSpec(tool=SQL_TOOL, tags=("sql", "analytics", "database"), cost=2.0, latency_ms=200),
    )
    for spec in demo_specs:
        registry.register(spec)

    # Retriever choice: (A) the LangChain FAISS retriever below is recommended
    # for this coding round. To try (B), construct
    # LlamaIndexObjectToolRetriever(registry, hf_model_name=...) instead;
    # it requires llama-index.
    retriever: ToolRetriever = LangChainVectorToolRetriever(
        registry,
        hf_model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    graph = build_graph(registry, retriever)

    examples = (
        ("I need the weather forecast for Dubai", {"city": "Dubai"}),
        ("Give me exchange rate from AED to USD", {"base": "AED", "quote": "USD"}),
    )
    for query, args in examples:
        initial: AgentState = {
            "user_query": query,
            "tool_args": args,
            "candidates": [],
            "chosen_tool": None,
            "result": None,
            "error": None,
        }
        final = graph.invoke(initial)
        print("Chosen:", final["chosen_tool"])
        print("Result:", final["result"])
        print("Error:", final["error"])


if __name__ == "__main__":
    main()
