diff --git a/pybatfish/mcp/__init__.py b/pybatfish/mcp/__init__.py index b2091c76..45e78676 100644 --- a/pybatfish/mcp/__init__.py +++ b/pybatfish/mcp/__init__.py @@ -30,6 +30,9 @@ # Or run with a specific Batfish host: BATFISH_HOST=my-batfish-host python -m pybatfish.mcp + + # Or configure sessions in ~/.batfish/sessions.json: + # {"default": {"type": "bf", "params": {"host": "localhost"}}} """ from pybatfish.mcp.server import create_server diff --git a/pybatfish/mcp/__main__.py b/pybatfish/mcp/__main__.py index 422442ea..3a61bc28 100644 --- a/pybatfish/mcp/__main__.py +++ b/pybatfish/mcp/__main__.py @@ -26,17 +26,39 @@ batfish-mcp -Environment variables: +Session configuration: -* ``BATFISH_HOST`` — hostname of the Batfish server (default: ``localhost``). +Sessions can be configured in ``~/.batfish/sessions.json``:: + + { + "default": {"type": "bf", "params": {"host": "localhost"}}, + "other": {"type": "bf", "params": {"host": "batfish2.example.com"}} + } + +Precedence for the default session: + +1. ``"default"`` entry in the sessions config file (if present) +2. ``BATFISH_HOST`` environment variable +3. ``localhost`` """ +import argparse +from pathlib import Path + from pybatfish.mcp.server import create_server def main() -> None: """Start the Batfish MCP server using stdio transport.""" - server = create_server() + parser = argparse.ArgumentParser(description="Batfish MCP server (Beta)") + parser.add_argument( + "--sessions-config", + type=Path, + default=None, + help="Path to sessions JSON config file (default: ~/.batfish/sessions.json)", + ) + args = parser.parse_args() + server = create_server(sessions_config=args.sessions_config) server.run(transport="stdio") diff --git a/pybatfish/mcp/server.py b/pybatfish/mcp/server.py index 6747753c..a1fc3543 100644 --- a/pybatfish/mcp/server.py +++ b/pybatfish/mcp/server.py @@ -37,6 +37,8 @@ "The 'mcp' package is required to use the Batfish MCP server. Install it with: pip install 'pybatfish[mcp]'" ) from e +from pathlib import Path + from pybatfish.client.session import Session from pybatfish.datamodel import HeaderConstraints, Interface @@ -53,68 +55,84 @@ ] ) -# Per-host Session cache. Question templates are downloaded from the Batfish -# service exactly once per host per process lifetime, covering both management -# and analysis operations. +# Default path for the sessions configuration file. +_SESSIONS_CONFIG_PATH = Path.home() / ".batfish" / "sessions.json" + +# Named session registry. Sessions are created lazily from their stored +# configs and cached for the lifetime of the process. +_session_configs: dict[str, dict[str, Any]] = {} _session_cache: dict[str, Session] = {} _session_cache_lock = threading.Lock() -def _get_session(host: str) -> Session: - """Return the cached Batfish Session for the given host. +def _load_sessions_config(path: Path = _SESSIONS_CONFIG_PATH) -> None: + """Load session configurations from a JSON file. - The session is retrieved from (or added to) a process-level cache keyed by - *host*, so that question templates are downloaded from the Batfish service - **at most once per process** rather than on every tool call. + The file should contain a JSON object mapping session names to + ``{"type": "", "params": {}}``. + If no ``"default"`` session is configured (either because the file + does not exist or because it doesn't define one), a default ``bf`` + session is created using the ``BATFISH_HOST`` environment variable + (falling back to ``localhost``). """ + if path.exists(): + with open(path) as f: + configs = json.load(f) + for name, cfg in configs.items(): + _session_configs[name] = cfg + if "default" not in _session_configs: + host = os.environ.get("BATFISH_HOST", "localhost") + _session_configs["default"] = {"type": "bf", "params": {"host": host}} + + +def _register_session(name: str, type_: str, **params: Any) -> Session: + """Register and immediately create a named session.""" + _session_configs[name] = {"type": type_, "params": params} with _session_cache_lock: - if host not in _session_cache: - _session_cache[host] = Session(host=host) - return _session_cache[host] + _session_cache.pop(name, None) + return _get_session(name) -def _clear_session_cache() -> None: - """Clear the per-host session cache. +def _get_session(name: str = "default") -> Session: + """Return the cached Session for the given name. - Intended for use in tests and in situations where the caller wants to - force question templates to be re-fetched from the Batfish service. + Creates the session lazily from ``_session_configs`` on first access. """ with _session_cache_lock: - _session_cache.clear() + if name not in _session_cache: + cfg = _session_configs.get(name) + if cfg is None: + raise ValueError( + f"No session named '{name}'. " + f"Available sessions: {sorted(_session_configs.keys())}. " + "Use the register_session tool to create one." + ) + _session_cache[name] = Session.get(cfg["type"], **cfg.get("params", {})) + return _session_cache[name] -def _resolve_host(host: str) -> str: - """Return the effective Batfish hostname. +def _clear_session_cache() -> None: + """Clear the session cache and configs. - Returns *host* if non-empty; otherwise falls back to the - ``BATFISH_HOST`` environment variable, and finally to ``'localhost'``. + Intended for use in tests and in situations where the caller wants to + force sessions to be re-created. """ - return host or os.environ.get("BATFISH_HOST", "localhost") - + with _session_cache_lock: + _session_cache.clear() + _session_configs.clear() -def _mgmt_session(host: str, network: str = "") -> Session: - """Return the cached session with an optional network set. - Resolves the effective hostname, fetches (or creates) the per-host cached - session, and optionally calls :meth:`~Session.set_network` when *network* - is provided. Use this for tools that perform network or snapshot management - operations (e.g. ``list_networks``, ``init_snapshot``, ``delete_snapshot``). - """ - bf = _get_session(_resolve_host(host)) +def _mgmt_session(session: str, network: str = "") -> Session: + """Return the named session with an optional network set.""" + bf = _get_session(session) if network: bf.set_network(network) return bf -def _analysis_session(host: str, network: str, snapshot: str) -> Session: - """Return the cached session with network and snapshot set. - - Resolves the effective hostname, fetches (or creates) the per-host cached - session, then calls :meth:`~Session.set_network` and - :meth:`~Session.set_snapshot`. Use this for all tools that invoke Batfish - questions. - """ - bf = _get_session(_resolve_host(host)) +def _analysis_session(session: str, network: str, snapshot: str) -> Session: + """Return the named session with network and snapshot set.""" + bf = _get_session(session) bf.set_network(network) bf.set_snapshot(snapshot) return bf @@ -143,7 +161,11 @@ def _drop_legacy_nexthop_columns(df: Any) -> Any: return df -def create_server(name: str = "Batfish") -> FastMCP: +def create_server( + name: str = "Batfish", + default_session: Session | None = None, + sessions_config: Path | None = None, +) -> FastMCP: """Create and return a configured Batfish MCP server (Beta). .. warning:: @@ -151,8 +173,20 @@ def create_server(name: str = "Batfish") -> FastMCP: return formats may change in future releases without prior notice. :param name: Name for the MCP server (default: "Batfish") + :param default_session: Optional pre-created session to register as "default". + :param sessions_config: Path to sessions JSON config file. + Defaults to ``~/.batfish/sessions.json``. :return: Configured FastMCP server instance """ + # Load session configs from file (or set up BATFISH_HOST default). + _load_sessions_config(sessions_config or _SESSIONS_CONFIG_PATH) + + # If a pre-created session was provided, register it as "default". + if default_session is not None: + with _session_cache_lock: + _session_configs["default"] = {"type": "precreated", "params": {}} + _session_cache["default"] = default_session + mcp = FastMCP( name, instructions=( @@ -160,48 +194,82 @@ def create_server(name: str = "Batfish") -> FastMCP: "Note: this MCP server is in beta — tool names and parameters may change in future releases. " "Use these tools to load network snapshots, run traceroutes, analyze reachability, " "inspect ACLs/firewall rules, query routing tables, and compare snapshots. " - "Most tools require a 'host' parameter (Batfish server hostname, defaults to " - "the BATFISH_HOST environment variable or 'localhost'), a 'network' parameter " - "(the network name in Batfish), and a 'snapshot' parameter (the snapshot name). " + "Most tools require a 'network' parameter (the network name in Batfish) " + "and a 'snapshot' parameter (the snapshot name). " + "All tools accept an optional 'session' parameter to select a named session " + "(default: 'default'). Use register_session to configure additional sessions. " "Start by listing networks or initializing a snapshot, then run analysis tools." ), ) + # ------------------------------------------------------------------------- + # Session management tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def register_session( + name: str, + type: str = "bf", + params: str = "{}", + ) -> str: + """Register a new named session for use with all other tools. + + The session type must be a registered pybatfish session entry point + (e.g. 'bf' for standard Batfish, 'dhalperianvdemo' for ANVDemo). + + :param name: Name for the session (used as the 'session' parameter in other tools). + :param type: Session type entry point name (default: 'bf'). + :param params: JSON object of constructor keyword arguments for the session type + (e.g. '{"host": "localhost"}' for bf). + :return: JSON object confirming registration. + """ + parsed_params = json.loads(params) if isinstance(params, str) else params + _register_session(name, type, **parsed_params) + return json.dumps({"registered": name, "type": type}) + + @mcp.tool() + def list_sessions() -> str: + """List all registered session names and their types. + + :return: JSON object mapping session names to their types. + """ + return json.dumps({name: cfg.get("type", "unknown") for name, cfg in _session_configs.items()}) + # ------------------------------------------------------------------------- # Network management tools # ------------------------------------------------------------------------- @mcp.tool() - def list_networks(host: str = "") -> str: + def list_networks(session: str = "default") -> str: """List all available networks on the Batfish server. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of network names. """ - bf = _mgmt_session(host) + bf = _mgmt_session(session) return json.dumps(bf.list_networks()) @mcp.tool() - def set_network(network: str, host: str = "") -> str: + def set_network(network: str, session: str = "default") -> str: """Create or select a network on the Batfish server. :param network: Name of the network to create or select. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object with the active network name. """ - bf = _mgmt_session(host) + bf = _mgmt_session(session) name = bf.set_network(network) return json.dumps({"network": name}) @mcp.tool() - def delete_network(network: str, host: str = "") -> str: + def delete_network(network: str, session: str = "default") -> str: """Delete a network from the Batfish server. :param network: Name of the network to delete. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object confirming deletion. """ - bf = _mgmt_session(host) + bf = _mgmt_session(session) bf.delete_network(network) return json.dumps({"deleted": network}) @@ -210,14 +278,14 @@ def delete_network(network: str, host: str = "") -> str: # ------------------------------------------------------------------------- @mcp.tool() - def list_snapshots(network: str, host: str = "") -> str: + def list_snapshots(network: str, session: str = "default") -> str: """List all snapshots within a network. :param network: Name of the network. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of snapshot names. """ - bf = _mgmt_session(host, network) + bf = _mgmt_session(session, network) return json.dumps(bf.list_snapshots()) @mcp.tool() @@ -226,7 +294,7 @@ def init_snapshot( snapshot_path: str, snapshot_name: str = "", overwrite: bool = False, - host: str = "", + session: str = "default", ) -> str: """Initialize a new snapshot from a local directory or zip file. @@ -237,10 +305,10 @@ def init_snapshot( :param snapshot_path: Local path to a snapshot directory or zip file. :param snapshot_name: Optional name for the snapshot. Auto-generated if empty. :param overwrite: Whether to overwrite an existing snapshot with the same name. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object with the initialized snapshot name. """ - bf = _mgmt_session(host, network) + bf = _mgmt_session(session, network) name = bf.init_snapshot( snapshot_path, name=snapshot_name or None, @@ -256,7 +324,7 @@ def init_snapshot_from_text( snapshot_name: str = "", platform: str = "", overwrite: bool = False, - host: str = "", + session: str = "default", ) -> str: """Initialize a single-device snapshot from configuration text. @@ -270,10 +338,10 @@ def init_snapshot_from_text( :param platform: RANCID platform string (e.g. 'cisco-nx', 'arista', 'juniper'). If empty, the platform is inferred from the configuration header. :param overwrite: Whether to overwrite an existing snapshot with the same name. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object with the initialized snapshot name. """ - bf = _mgmt_session(host, network) + bf = _mgmt_session(session, network) name = bf.init_snapshot_from_text( config_text, filename=filename, @@ -284,15 +352,15 @@ def init_snapshot_from_text( return json.dumps({"snapshot": name}) @mcp.tool() - def delete_snapshot(network: str, snapshot: str, host: str = "") -> str: + def delete_snapshot(network: str, snapshot: str, session: str = "default") -> str: """Delete a snapshot from a network. :param network: Name of the network containing the snapshot. :param snapshot: Name of the snapshot to delete. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object confirming deletion. """ - bf = _mgmt_session(host, network) + bf = _mgmt_session(session, network) bf.delete_snapshot(snapshot) return json.dumps({"deleted": snapshot}) @@ -306,7 +374,7 @@ def fork_snapshot( restore_nodes: str = "", restore_interfaces: str = "", overwrite: bool = False, - host: str = "", + session: str = "default", ) -> str: """Fork an existing snapshot, optionally deactivating or restoring nodes/interfaces. @@ -321,10 +389,10 @@ def fork_snapshot( :param restore_nodes: Comma-separated list of node names to restore. :param restore_interfaces: Comma-separated list of 'node[interface]' pairs to restore. :param overwrite: Whether to overwrite an existing snapshot with the same name. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON object with the forked snapshot name. """ - bf = _mgmt_session(host, network) + bf = _mgmt_session(session, network) deactivate_nodes_list = [n.strip() for n in deactivate_nodes.split(",") if n.strip()] or None restore_nodes_list = [n.strip() for n in restore_nodes.split(",") if n.strip()] or None @@ -358,7 +426,7 @@ def run_traceroute( ip_protocols: str = "", src_ports: str = "", dst_ports: str = "", - host: str = "", + session: str = "default", ) -> str: """Simulate a traceroute from a location to a destination IP address. @@ -374,10 +442,10 @@ def run_traceroute( :param ip_protocols: IP protocol(s) e.g. 'TCP' (optional). :param src_ports: Source port(s) e.g. '1024-65535' (optional). :param dst_ports: Destination port(s) e.g. '22' (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of traceroute result rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) headers = _build_header_constraints( dst_ips=dst_ips, @@ -401,7 +469,7 @@ def run_bidirectional_traceroute( ip_protocols: str = "", src_ports: str = "", dst_ports: str = "", - host: str = "", + session: str = "default", ) -> str: """Simulate a bidirectional traceroute (forward + reverse paths). @@ -417,10 +485,10 @@ def run_bidirectional_traceroute( :param ip_protocols: IP protocol(s) (optional). :param src_ports: Source port(s) (optional). :param dst_ports: Destination port(s) (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of bidirectional traceroute result rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) headers = _build_header_constraints( dst_ips=dst_ips, @@ -445,7 +513,7 @@ def check_reachability( src_ports: str = "", dst_ports: str = "", actions: str = "", - host: str = "", + session: str = "default", ) -> str: """Check reachability between network locations. @@ -462,10 +530,10 @@ def check_reachability( :param src_ports: Source port(s) (optional). :param dst_ports: Destination port(s) (optional). :param actions: Disposition filter, e.g. 'DENIED_IN,DENIED_OUT,DROP' (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of reachability result rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) headers = _build_header_constraints( dst_ips=dst_ips, @@ -494,7 +562,7 @@ def analyze_acl( snapshot: str, filters: str = "", nodes: str = "", - host: str = "", + session: str = "default", ) -> str: """Identify unreachable (shadowed) lines in ACLs and firewall rules. @@ -505,10 +573,10 @@ def analyze_acl( :param snapshot: Name of the snapshot. :param filters: Filter specifier to restrict analysis (optional). :param nodes: Node specifier to restrict analysis (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of unreachable ACL/filter line rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if filters: @@ -532,7 +600,7 @@ def search_filters( src_ports: str = "", dst_ports: str = "", action: str = "", - host: str = "", + session: str = "default", ) -> str: """Search for flows that match specific filter (ACL/firewall) criteria. @@ -550,10 +618,10 @@ def search_filters( :param src_ports: Source port(s) (optional). :param dst_ports: Destination port(s) (optional). :param action: Filter action: 'PERMIT' or 'DENY' (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of matched flow rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) headers = _build_header_constraints( dst_ips=dst_ips, @@ -586,7 +654,7 @@ def get_routes( vrfs: str = "", network_prefix: str = "", protocols: str = "", - host: str = "", + session: str = "default", ) -> str: """Retrieve the routing table (RIB) from one or more devices. @@ -599,10 +667,10 @@ def get_routes( :param vrfs: VRF specifier to restrict results (optional). :param network_prefix: Prefix to filter routes by (optional). :param protocols: Routing protocol(s) to filter by, e.g. 'bgp,ospf' (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of routing table rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -626,7 +694,7 @@ def compare_routes( vrfs: str = "", network_prefix: str = "", protocols: str = "", - host: str = "", + session: str = "default", ) -> str: """Compare routing tables between two snapshots to identify route changes. @@ -643,10 +711,10 @@ def compare_routes( :param vrfs: VRF specifier to restrict results (optional). :param network_prefix: Prefix to filter routes by (optional). :param protocols: Routing protocol(s) to filter by (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array showing route differences (added/removed routes). """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -674,7 +742,7 @@ def get_bgp_session_status( nodes: str = "", remote_nodes: str = "", status: str = "", - host: str = "", + session: str = "default", ) -> str: """Get the status of BGP sessions in a snapshot. @@ -685,10 +753,10 @@ def get_bgp_session_status( :param nodes: Node specifier for local BGP speakers (optional). :param remote_nodes: Node specifier for remote BGP speakers (optional). :param status: BGP session status specifier to filter by (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of BGP session status rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -708,7 +776,7 @@ def get_bgp_session_compatibility( nodes: str = "", remote_nodes: str = "", status: str = "", - host: str = "", + session: str = "default", ) -> str: """Check BGP session compatibility between peers. @@ -725,10 +793,10 @@ def get_bgp_session_compatibility( :param nodes: Node specifier for local BGP speakers (optional). :param remote_nodes: Node specifier for remote BGP speakers (optional). :param status: BGP compatibility status specifier to filter by (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of BGP compatibility rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -751,7 +819,7 @@ def get_node_properties( snapshot: str, nodes: str = "", properties: str = "", - host: str = "", + session: str = "default", ) -> str: """Retrieve configuration properties of network nodes (routers/switches). @@ -759,10 +827,10 @@ def get_node_properties( :param snapshot: Name of the snapshot. :param nodes: Node specifier to restrict results (optional). :param properties: Comma-separated list of property names to retrieve (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of node property rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -780,7 +848,7 @@ def get_interface_properties( nodes: str = "", interfaces: str = "", properties: str = "", - host: str = "", + session: str = "default", ) -> str: """Retrieve configuration properties of network interfaces. @@ -789,10 +857,10 @@ def get_interface_properties( :param nodes: Node specifier to restrict results (optional). :param interfaces: Interface specifier to restrict results (optional). :param properties: Comma-separated list of property names to retrieve (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of interface property rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -810,17 +878,17 @@ def get_ip_owners( network: str, snapshot: str, duplicates_only: bool = False, - host: str = "", + session: str = "default", ) -> str: """Get the mapping of IP addresses to network interfaces. :param network: Name of the network. :param snapshot: Name of the snapshot. :param duplicates_only: If True, return only IPs assigned to multiple interfaces. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of IP ownership rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) result = bf.q.ipOwners(duplicatesOnly=duplicates_only).answer().frame() # type: ignore[attr-defined] return _df_to_json(result) @@ -836,7 +904,7 @@ def compare_filters( reference_snapshot: str, filters: str = "", nodes: str = "", - host: str = "", + session: str = "default", ) -> str: """Compare ACL/firewall filter behavior between two snapshots. @@ -848,10 +916,10 @@ def compare_filters( :param reference_snapshot: Name of the reference (baseline) snapshot. :param filters: Filter specifier to restrict comparison (optional). :param nodes: Node specifier to restrict comparison (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of filter difference rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if filters: @@ -867,7 +935,7 @@ def get_undefined_references( network: str, snapshot: str, nodes: str = "", - host: str = "", + session: str = "default", ) -> str: """Find undefined references in device configurations. @@ -877,10 +945,10 @@ def get_undefined_references( :param network: Name of the network. :param snapshot: Name of the snapshot. :param nodes: Node specifier to restrict results (optional). - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of undefined reference rows. """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) kwargs: dict[str, Any] = {} if nodes: @@ -893,7 +961,7 @@ def get_undefined_references( def detect_loops( network: str, snapshot: str, - host: str = "", + session: str = "default", ) -> str: """Detect forwarding loops in the network snapshot. @@ -902,10 +970,10 @@ def detect_loops( :param network: Name of the network. :param snapshot: Name of the snapshot. - :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :param session: Named session to use (default: 'default'). :return: JSON array of forwarding loop rows (empty if no loops found). """ - bf = _analysis_session(host, network, snapshot) + bf = _analysis_session(session, network, snapshot) result = bf.q.detectLoops().answer().frame() # type: ignore[attr-defined] return _df_to_json(result) diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 941b0e5f..fb2feb8a 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -18,6 +18,7 @@ import asyncio import json +from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -31,12 +32,24 @@ _clear_session_cache, _df_to_json, _drop_legacy_nexthop_columns, + _get_session, + _load_sessions_config, _mgmt_session, _parse_interfaces, - _resolve_host, + _register_session, + _session_configs, create_server, ) +# Path that does not exist, used to avoid loading ~/.batfish/sessions.json +_NO_CONFIG = Path("/nonexistent/sessions.json") + + +def _init_default_sessions() -> None: + """Load session configs from a nonexistent file to get the BATFISH_HOST/localhost default.""" + _load_sessions_config(_NO_CONFIG) + + # --------------------------------------------------------------------------- # Helper factories # --------------------------------------------------------------------------- @@ -148,74 +161,74 @@ def test_all_fields(self): assert hc.dstPorts == "22" -class TestSessionCache: - """Tests for the per-host session cache in _get_session.""" +class TestSessionRegistry: + """Tests for the session registry and config loading.""" def setup_method(self): - """Clear the cache before each test to ensure isolation.""" _clear_session_cache() def teardown_method(self): - """Clear the cache after each test.""" _clear_session_cache() - def test_session_is_cached(self): - """Session must be created only once for the same host.""" + def test_default_session_from_env(self, monkeypatch): + monkeypatch.setenv("BATFISH_HOST", "env-host") + _init_default_sessions() + assert _session_configs["default"] == {"type": "bf", "params": {"host": "env-host"}} + + def test_default_session_localhost(self, monkeypatch): + monkeypatch.delenv("BATFISH_HOST", raising=False) + _init_default_sessions() + assert _session_configs["default"]["params"]["host"] == "localhost" + + def test_load_from_config_file(self, tmp_path): + config = {"prod": {"type": "bf", "params": {"host": "prod-host"}}} + config_file = tmp_path / "sessions.json" + config_file.write_text(json.dumps(config)) + _load_sessions_config(config_file) + assert "prod" in _session_configs + assert _session_configs["prod"]["params"]["host"] == "prod-host" + # Default is still added + assert "default" in _session_configs + + def test_config_file_overrides_default(self, tmp_path): + config = {"default": {"type": "bf", "params": {"host": "custom-host"}}} + config_file = tmp_path / "sessions.json" + config_file.write_text(json.dumps(config)) + _load_sessions_config(config_file) + assert _session_configs["default"]["params"]["host"] == "custom-host" + + def test_register_session_creates_and_caches(self): mock_session = MagicMock() - with patch("pybatfish.mcp.server.Session", return_value=mock_session) as MockSession: - from pybatfish.mcp.server import _get_session + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + result = _register_session("test", "bf", host="test-host") + assert result is mock_session + assert _session_configs["test"] == {"type": "bf", "params": {"host": "test-host"}} - s1 = _get_session("bf-host") - s2 = _get_session("bf-host") + def test_get_session_uses_session_get(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + result = _get_session("default") + MockSession.get.assert_called_once_with("bf", host="localhost") + assert result is mock_session - # Session constructor called only once - assert MockSession.call_count == 1 - # Both calls return the same cached object + def test_get_session_caches(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + s1 = _get_session("default") + s2 = _get_session("default") + assert MockSession.get.call_count == 1 assert s1 is s2 - def test_different_hosts_get_different_cached_sessions(self): - """Each host gets its own independent cache entry.""" - mock_a = MagicMock() - mock_b = MagicMock() - sessions = [mock_a, mock_b] - with patch("pybatfish.mcp.server.Session", side_effect=sessions) as MockSession: - from pybatfish.mcp.server import _get_session - - sa = _get_session("host-a") - sb = _get_session("host-b") - - assert MockSession.call_count == 2 - assert sa is not sb - - def test_clear_session_cache_forces_new_session(self): - """After _clear_session_cache(), the next call creates a fresh session.""" - from pybatfish.mcp.server import _clear_session_cache, _get_session - - mock_first = MagicMock() - mock_second = MagicMock() - sessions = [mock_first, mock_second] - with patch("pybatfish.mcp.server.Session", side_effect=sessions) as MockSession: - s1 = _get_session("bf-host") - _clear_session_cache() - s2 = _get_session("bf-host") - - assert MockSession.call_count == 2 - assert s1 is not s2 - - -class TestResolveHost: - """Tests for the _resolve_host() helper.""" - - def test_returns_explicit_host(self): - assert _resolve_host("my-host") == "my-host" - - def test_falls_back_to_env_var(self, monkeypatch): - monkeypatch.setenv("BATFISH_HOST", "env-host") - assert _resolve_host("") == "env-host" - - def test_falls_back_to_localhost(self, monkeypatch): - monkeypatch.delenv("BATFISH_HOST", raising=False) - assert _resolve_host("") == "localhost" + def test_get_session_unknown_raises(self): + _init_default_sessions() + with pytest.raises(ValueError, match="No session named"): + _get_session("nonexistent") class TestMgmtSession: @@ -227,39 +240,39 @@ def setup_method(self): def teardown_method(self): _clear_session_cache() - def test_creates_cached_session(self): + def test_returns_session(self): + mock_session = MagicMock() with patch("pybatfish.mcp.server.Session") as MockSession: - MockSession.return_value = MagicMock() - _mgmt_session("localhost") - MockSession.assert_called_once_with(host="localhost") + MockSession.get.return_value = mock_session + _init_default_sessions() + result = _mgmt_session("default") + assert result is mock_session def test_sets_network_when_provided(self): mock_session = MagicMock() - with patch("pybatfish.mcp.server.Session", return_value=mock_session): - _mgmt_session("localhost", "my-network") + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + _mgmt_session("default", "my-network") mock_session.set_network.assert_called_once_with("my-network") def test_skips_set_network_when_empty(self): mock_session = MagicMock() - with patch("pybatfish.mcp.server.Session", return_value=mock_session): - _mgmt_session("localhost", "") - mock_session.set_network.assert_not_called() - - def test_resolves_host_from_env(self, monkeypatch): - monkeypatch.setenv("BATFISH_HOST", "env-bf") with patch("pybatfish.mcp.server.Session") as MockSession: - MockSession.return_value = MagicMock() - _mgmt_session("") - MockSession.assert_called_once_with(host="env-bf") + MockSession.get.return_value = mock_session + _init_default_sessions() + _mgmt_session("default", "") + mock_session.set_network.assert_not_called() def test_shares_cache_with_analysis_session(self): """_mgmt_session and _analysis_session must return the same cached session.""" mock_session = MagicMock() - with patch("pybatfish.mcp.server.Session", return_value=mock_session) as MockSession: - s1 = _mgmt_session("localhost") - s2 = _analysis_session("localhost", "net1", "snap1") - # Session constructor called only once — both helpers share the cache - assert MockSession.call_count == 1 + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + s1 = _mgmt_session("default") + s2 = _analysis_session("default", "net1", "snap1") + assert MockSession.get.call_count == 1 assert s1 is s2 @@ -274,25 +287,13 @@ def teardown_method(self): def test_sets_network_and_snapshot(self): mock_session = MagicMock() - with patch("pybatfish.mcp.server.Session", return_value=mock_session): - _analysis_session("localhost", "net1", "snap1") + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + _init_default_sessions() + _analysis_session("default", "net1", "snap1") mock_session.set_network.assert_called_once_with("net1") mock_session.set_snapshot.assert_called_once_with("snap1") - def test_creates_cached_session(self): - with patch("pybatfish.mcp.server.Session") as MockSession: - MockSession.return_value = MagicMock() - _analysis_session("localhost", "net1", "snap1") - MockSession.assert_called_once_with(host="localhost") - - def test_resolves_host_from_env(self, monkeypatch): - _clear_session_cache() - monkeypatch.setenv("BATFISH_HOST", "env-bf") - with patch("pybatfish.mcp.server.Session") as MockSession: - MockSession.return_value = MagicMock() - _analysis_session("", "net1", "snap1") - MockSession.assert_called_once_with(host="env-bf") - class TestDropLegacyNexthopColumns: def test_drops_known_legacy_columns(self): @@ -337,6 +338,12 @@ def test_drops_camelcase_variants(self): class TestCreateServer: + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + def test_returns_fastmcp_instance(self): from mcp.server.fastmcp import FastMCP @@ -347,22 +354,61 @@ def test_custom_name(self): server = create_server(name="MyBatfish") assert server.name == "MyBatfish" + def test_default_session_injection(self): + mock_session = MagicMock() + mock_session.list_networks.return_value = ["net1"] + server = create_server(default_session=mock_session) + data = _call_tool(server, "list_networks", {}) + assert data == ["net1"] + + +class TestRegisterSessionTool: + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + + def test_registers_session(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.get.return_value = mock_session + server = create_server() + data = _call_tool( + server, + "register_session", + {"name": "test", "type": "bf", "params": '{"host": "test-host"}'}, + ) + assert data == {"registered": "test", "type": "bf"} + + +class TestListSessionsTool: + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + + def test_lists_sessions(self): + server = create_server() + data = _call_tool(server, "list_sessions", {}) + assert "default" in data + class TestListNetworksTool: def test_returns_network_list(self): mock_session = _make_session_mock(list_networks=["net1", "net2"]) with patch(PATCH_TARGET, return_value=mock_session): server = create_server() - data = _call_tool(server, "list_networks", {"host": "localhost"}) + data = _call_tool(server, "list_networks", {}) assert data == ["net1", "net2"] - def test_uses_env_host(self, monkeypatch): - monkeypatch.setenv("BATFISH_HOST", "my-bf-host") + def test_explicit_session_param(self): mock_session = _make_session_mock(list_networks=["net1"]) with patch(PATCH_TARGET, return_value=mock_session) as mock_get: server = create_server() - _call_tool(server, "list_networks", {}) - mock_get.assert_called_once_with("my-bf-host") + _call_tool(server, "list_networks", {"session": "default"}) + mock_get.assert_called_once_with("default") class TestSetNetworkTool: @@ -371,7 +417,7 @@ def test_returns_network_name(self): mock_session.set_network.return_value = "my-network" with patch(PATCH_TARGET, return_value=mock_session): server = create_server() - data = _call_tool(server, "set_network", {"network": "my-network", "host": "localhost"}) + data = _call_tool(server, "set_network", {"network": "my-network"}) assert data == {"network": "my-network"} @@ -380,7 +426,7 @@ def test_returns_deleted_name(self): mock_session = MagicMock() with patch(PATCH_TARGET, return_value=mock_session): server = create_server() - data = _call_tool(server, "delete_network", {"network": "old-net", "host": "localhost"}) + data = _call_tool(server, "delete_network", {"network": "old-net"}) assert data == {"deleted": "old-net"} mock_session.delete_network.assert_called_once_with("old-net") @@ -390,7 +436,7 @@ def test_returns_snapshot_list(self): mock_session = _make_session_mock(list_snapshots=["snap1", "snap2"]) with patch(PATCH_TARGET, return_value=mock_session): server = create_server() - data = _call_tool(server, "list_snapshots", {"network": "net1", "host": "localhost"}) + data = _call_tool(server, "list_snapshots", {"network": "net1"}) assert data == ["snap1", "snap2"] @@ -403,7 +449,7 @@ def test_returns_snapshot_name(self): data = _call_tool( server, "init_snapshot", - {"network": "net1", "snapshot_path": "/path/to/snap", "host": "localhost"}, + {"network": "net1", "snapshot_path": "/path/to/snap"}, ) assert data == {"snapshot": "my-snap"} @@ -420,7 +466,6 @@ def test_passes_name_and_overwrite(self): "snapshot_path": "/path", "snapshot_name": "named-snap", "overwrite": True, - "host": "localhost", }, ) mock_session.init_snapshot.assert_called_once_with("/path", name="named-snap", overwrite=True) @@ -435,7 +480,7 @@ def test_returns_snapshot_name(self): data = _call_tool( server, "init_snapshot_from_text", - {"network": "net1", "config_text": "hostname router1", "host": "localhost"}, + {"network": "net1", "config_text": "hostname router1"}, ) assert data == {"snapshot": "text-snap"} @@ -451,7 +496,6 @@ def test_passes_platform_when_set(self): "network": "net1", "config_text": "config", "platform": "arista", - "host": "localhost", }, ) call_kwargs = mock_session.init_snapshot_from_text.call_args[1] @@ -465,7 +509,7 @@ def test_passes_none_platform_when_empty(self): _call_tool( server, "init_snapshot_from_text", - {"network": "net1", "config_text": "config", "host": "localhost"}, + {"network": "net1", "config_text": "config"}, ) call_kwargs = mock_session.init_snapshot_from_text.call_args[1] assert call_kwargs["platform"] is None @@ -476,7 +520,7 @@ def test_returns_deleted_name(self): mock_session = MagicMock() with patch(PATCH_TARGET, return_value=mock_session): server = create_server() - data = _call_tool(server, "delete_snapshot", {"network": "net1", "snapshot": "snap1", "host": "localhost"}) + data = _call_tool(server, "delete_snapshot", {"network": "net1", "snapshot": "snap1"}) assert data == {"deleted": "snap1"} mock_session.delete_snapshot.assert_called_once_with("snap1") @@ -490,7 +534,7 @@ def test_basic_fork(self): data = _call_tool( server, "fork_snapshot", - {"network": "net1", "base_snapshot": "base", "new_snapshot": "forked", "host": "localhost"}, + {"network": "net1", "base_snapshot": "base", "new_snapshot": "forked"}, ) assert data == {"snapshot": "forked-snap"} @@ -506,7 +550,6 @@ def test_deactivate_nodes(self): "network": "net1", "base_snapshot": "base", "deactivate_nodes": "r1,r2", - "host": "localhost", }, ) call_kwargs = mock_session.fork_snapshot.call_args[1] @@ -524,7 +567,6 @@ def test_deactivate_interfaces(self): "network": "net1", "base_snapshot": "base", "deactivate_interfaces": "r1[Gi0/0]", - "host": "localhost", }, ) call_kwargs = mock_session.fork_snapshot.call_args[1] @@ -546,7 +588,6 @@ def test_returns_json_rows(self): "snapshot": "snap1", "start_location": "router1", "dst_ips": "10.0.0.1", - "host": "localhost", }, ) assert len(data) == 1 @@ -570,7 +611,6 @@ def test_optional_header_params_passed(self): "ip_protocols": "TCP", "src_ports": "1024", "dst_ports": "22", - "host": "localhost", }, ) call_kwargs = mock_session.q.traceroute.call_args[1] @@ -593,37 +633,11 @@ def test_returns_json_rows(self): "snapshot": "snap1", "start_location": "router1", "dst_ips": "10.0.0.1", - "host": "localhost", }, ) assert len(data) == 1 assert data[0]["Forward_Flow"] == "f1" - def test_optional_header_params_passed(self): - mock_session = MagicMock() - mock_session.q.bidirectionalTraceroute.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "run_bidirectional_traceroute", - { - "network": "net1", - "snapshot": "snap1", - "start_location": "router1", - "dst_ips": "10.0.0.1", - "src_ips": "192.168.0.1", - "applications": "ssh", - "ip_protocols": "TCP", - "src_ports": "1024", - "dst_ports": "22", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.bidirectionalTraceroute.call_args[1] - assert call_kwargs["headers"].dstIps == "10.0.0.1" - assert call_kwargs["headers"].srcIps == "192.168.0.1" - class TestCheckReachabilityTool: def test_basic_call(self): @@ -634,7 +648,7 @@ def test_basic_call(self): data = _call_tool( server, "check_reachability", - {"network": "net1", "snapshot": "snap1", "dst_ips": "8.8.8.8", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1", "dst_ips": "8.8.8.8"}, ) assert data[0]["Action"] == "ACCEPT" @@ -651,7 +665,6 @@ def test_optional_params_passed(self): "snapshot": "snap1", "src_locations": "router1", "actions": "DENIED_IN,DROP", - "host": "localhost", }, ) call_kwargs = mock_session.q.reachability.call_args[1] @@ -668,24 +681,10 @@ def test_returns_acl_rows(self): data = _call_tool( server, "analyze_acl", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Filter"] == "acl1" - def test_optional_params_passed(self): - mock_session = MagicMock() - mock_session.q.filterLineReachability.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "analyze_acl", - {"network": "net1", "snapshot": "snap1", "filters": "acl1", "nodes": "r1", "host": "localhost"}, - ) - call_kwargs = mock_session.q.filterLineReachability.call_args[1] - assert call_kwargs["filters"] == "acl1" - assert call_kwargs["nodes"] == "r1" - class TestSearchFiltersTool: def test_permit_action_passed(self): @@ -696,31 +695,11 @@ def test_permit_action_passed(self): _call_tool( server, "search_filters", - {"network": "net1", "snapshot": "snap1", "action": "PERMIT", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1", "action": "PERMIT"}, ) call_kwargs = mock_session.q.searchFilters.call_args[1] assert call_kwargs["action"] == "PERMIT" - def test_optional_filters_and_nodes_passed(self): - mock_session = MagicMock() - mock_session.q.searchFilters.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "search_filters", - { - "network": "net1", - "snapshot": "snap1", - "filters": "acl1", - "nodes": "r1", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.searchFilters.call_args[1] - assert call_kwargs["filters"] == "acl1" - assert call_kwargs["nodes"] == "r1" - class TestGetRoutesTool: def test_returns_routes(self): @@ -731,7 +710,7 @@ def test_returns_routes(self): data = _call_tool( server, "get_routes", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Node"] == "r1" @@ -753,7 +732,7 @@ def test_legacy_nexthop_columns_dropped(self): data = _call_tool( server, "get_routes", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert "Next_Hop" in data[0] assert "Next_Hop_IP" not in data[0] @@ -774,7 +753,6 @@ def test_filters_passed(self): "vrfs": "default", "network_prefix": "10.0.0.0/8", "protocols": "bgp", - "host": "localhost", }, ) call_kwargs = mock_session.q.routes.call_args[1] @@ -802,41 +780,10 @@ def test_calls_differential_answer(self): "network": "net1", "snapshot": "snap-new", "reference_snapshot": "snap-old", - "host": "localhost", }, ) mock_answer_obj.answer.assert_called_once_with(snapshot="snap-new", reference_snapshot="snap-old") - def test_optional_filters_passed(self): - mock_frame_obj = MagicMock() - mock_frame_obj.frame.return_value = pd.DataFrame([]) - mock_answer_obj = MagicMock() - mock_answer_obj.answer.return_value = mock_frame_obj - mock_session = MagicMock() - mock_session.q.routes.return_value = mock_answer_obj - - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "compare_routes", - { - "network": "net1", - "snapshot": "snap-new", - "reference_snapshot": "snap-old", - "nodes": "r1", - "vrfs": "default", - "network_prefix": "10.0.0.0/8", - "protocols": "bgp", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.routes.call_args[1] - assert call_kwargs["nodes"] == "r1" - assert call_kwargs["vrfs"] == "default" - assert call_kwargs["network"] == "10.0.0.0/8" - assert call_kwargs["protocols"] == "bgp" - class TestGetBgpSessionStatusTool: def test_returns_bgp_rows(self): @@ -847,32 +794,10 @@ def test_returns_bgp_rows(self): data = _call_tool( server, "get_bgp_session_status", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Status"] == "ESTABLISHED" - def test_optional_params_passed(self): - mock_session = MagicMock() - mock_session.q.bgpSessionStatus.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "get_bgp_session_status", - { - "network": "net1", - "snapshot": "snap1", - "nodes": "r1", - "remote_nodes": "r2", - "status": "ESTABLISHED", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.bgpSessionStatus.call_args[1] - assert call_kwargs["nodes"] == "r1" - assert call_kwargs["remoteNodes"] == "r2" - assert call_kwargs["status"] == "ESTABLISHED" - class TestGetBgpSessionCompatibilityTool: def test_returns_compat_rows(self): @@ -883,32 +808,10 @@ def test_returns_compat_rows(self): data = _call_tool( server, "get_bgp_session_compatibility", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Node"] == "r1" - def test_optional_params_passed(self): - mock_session = MagicMock() - mock_session.q.bgpSessionCompatibility.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "get_bgp_session_compatibility", - { - "network": "net1", - "snapshot": "snap1", - "nodes": "r1", - "remote_nodes": "r2", - "status": "UNIQUE_MATCH", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.bgpSessionCompatibility.call_args[1] - assert call_kwargs["nodes"] == "r1" - assert call_kwargs["remoteNodes"] == "r2" - assert call_kwargs["status"] == "UNIQUE_MATCH" - class TestGetNodePropertiesTool: def test_returns_node_properties(self): @@ -919,23 +822,10 @@ def test_returns_node_properties(self): data = _call_tool( server, "get_node_properties", - {"network": "net1", "snapshot": "snap1", "nodes": "r1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1", "nodes": "r1"}, ) assert data[0]["Node"] == "r1" - def test_properties_param_passed(self): - mock_session = MagicMock() - mock_session.q.nodeProperties.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "get_node_properties", - {"network": "net1", "snapshot": "snap1", "properties": "Hostname,NTP_Servers", "host": "localhost"}, - ) - call_kwargs = mock_session.q.nodeProperties.call_args[1] - assert call_kwargs["properties"] == "Hostname,NTP_Servers" - class TestGetInterfacePropertiesTool: def test_returns_interface_properties(self): @@ -946,32 +836,10 @@ def test_returns_interface_properties(self): data = _call_tool( server, "get_interface_properties", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Interface"] == "r1[Gi0/0]" - def test_optional_params_passed(self): - mock_session = MagicMock() - mock_session.q.interfaceProperties.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "get_interface_properties", - { - "network": "net1", - "snapshot": "snap1", - "nodes": "r1", - "interfaces": "Gi0/0", - "properties": "Active,Description", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.interfaceProperties.call_args[1] - assert call_kwargs["nodes"] == "r1" - assert call_kwargs["interfaces"] == "Gi0/0" - assert call_kwargs["properties"] == "Active,Description" - class TestGetIpOwnersTool: def test_returns_ip_rows(self): @@ -982,7 +850,7 @@ def test_returns_ip_rows(self): data = _call_tool( server, "get_ip_owners", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["IP"] == "10.0.0.1" @@ -994,7 +862,7 @@ def test_duplicates_only_flag(self): _call_tool( server, "get_ip_owners", - {"network": "net1", "snapshot": "snap1", "duplicates_only": True, "host": "localhost"}, + {"network": "net1", "snapshot": "snap1", "duplicates_only": True}, ) mock_session.q.ipOwners.assert_called_once_with(duplicatesOnly=True) @@ -1017,37 +885,10 @@ def test_calls_differential_answer(self): "network": "net1", "snapshot": "snap-new", "reference_snapshot": "snap-old", - "host": "localhost", }, ) mock_answer_obj.answer.assert_called_once_with(snapshot="snap-new", reference_snapshot="snap-old") - def test_optional_params_passed(self): - mock_frame_obj = MagicMock() - mock_frame_obj.frame.return_value = pd.DataFrame([]) - mock_answer_obj = MagicMock() - mock_answer_obj.answer.return_value = mock_frame_obj - mock_session = MagicMock() - mock_session.q.compareFilters.return_value = mock_answer_obj - - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "compare_filters", - { - "network": "net1", - "snapshot": "snap-new", - "reference_snapshot": "snap-old", - "filters": "acl1", - "nodes": "r1", - "host": "localhost", - }, - ) - call_kwargs = mock_session.q.compareFilters.call_args[1] - assert call_kwargs["filters"] == "acl1" - assert call_kwargs["nodes"] == "r1" - class TestGetUndefinedReferencesTool: def test_returns_reference_rows(self): @@ -1058,23 +899,10 @@ def test_returns_reference_rows(self): data = _call_tool( server, "get_undefined_references", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Ref_Name"] == "acl-foo" - def test_nodes_param_passed(self): - mock_session = MagicMock() - mock_session.q.undefinedReferences.return_value = _make_answer_frame([]) - with patch(PATCH_TARGET, return_value=mock_session): - server = create_server() - _call_tool( - server, - "get_undefined_references", - {"network": "net1", "snapshot": "snap1", "nodes": "r1", "host": "localhost"}, - ) - call_kwargs = mock_session.q.undefinedReferences.call_args[1] - assert call_kwargs["nodes"] == "r1" - class TestDetectLoopsTool: def test_returns_loop_rows(self): @@ -1085,7 +913,7 @@ def test_returns_loop_rows(self): data = _call_tool( server, "detect_loops", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data[0]["Node"] == "r1" @@ -1097,7 +925,7 @@ def test_no_loops(self): data = _call_tool( server, "detect_loops", - {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + {"network": "net1", "snapshot": "snap1"}, ) assert data == [] @@ -1106,6 +934,8 @@ class TestToolListCompleteness: """Verify the server exposes the expected set of tools.""" EXPECTED_TOOLS = { + "register_session", + "list_sessions", "list_networks", "set_network", "delete_network", @@ -1131,6 +961,12 @@ class TestToolListCompleteness: "detect_loops", } + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + def test_all_expected_tools_registered(self): server = create_server() tools = asyncio.run(server.list_tools())