diff --git a/.gitignore b/.gitignore index b0d27dd..bba6d04 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ dist/ *.egg-info/ .streamlit/ +.ga-switch/ .vscode/ .idea/ diff --git a/agentmain.py b/agentmain.py index 4643d3c..6a4bebe 100644 --- a/agentmain.py +++ b/agentmain.py @@ -6,14 +6,15 @@ elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace') sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from llmcore import LLMSession, ToolClient, ClaudeSession, MixinSession, NativeToolClient, NativeClaudeSession, NativeOAISession from agent_loop import agent_runner_loop from ga import GenericAgentHandler, smart_format, get_global_memory, format_error, consume_file +from ga_switch import get_service +from ga_switch.runtime_bridge import describe_runtime, load_clients, next_client, set_active_route as bridge_set_active_route script_dir = os.path.dirname(os.path.abspath(__file__)) def load_tool_schema(suffix=''): global TOOLS_SCHEMA - TS = open(os.path.join(script_dir, f'assets/tools_schema{suffix}.json'), 'r', encoding='utf-8').read() + with open(os.path.join(script_dir, f'assets/tools_schema{suffix}.json'), 'r', encoding='utf-8') as f: TS = f.read() TOOLS_SCHEMA = json.loads(TS if os.name == 'nt' else TS.replace('powershell', 'bash')) load_tool_schema() @@ -21,16 +22,24 @@ def load_tool_schema(suffix=''): mem_dir = os.path.join(script_dir, 'memory') if not os.path.exists(mem_dir): os.makedirs(mem_dir) mem_txt = os.path.join(mem_dir, 'global_mem.txt') -if not os.path.exists(mem_txt): open(mem_txt, 'w', encoding='utf-8').write('# [Global Memory - L2]\n') +if not os.path.exists(mem_txt): + with open(mem_txt, 'w', encoding='utf-8') as f: + f.write('# [Global Memory - L2]\n') mem_insight = os.path.join(mem_dir, 'global_mem_insight.txt') if not os.path.exists(mem_insight): t = os.path.join(script_dir, f'assets/global_mem_insight_template{lang_suffix}.txt') - open(mem_insight, 'w', encoding='utf-8').write(open(t, 
encoding='utf-8').read() if os.path.exists(t) else '') + template = '' + if os.path.exists(t): + with open(t, encoding='utf-8') as f: + template = f.read() + with open(mem_insight, 'w', encoding='utf-8') as f: + f.write(template) cdp_cfg = os.path.join(script_dir, 'assets/tmwd_cdp_bridge/config.js') if not os.path.exists(cdp_cfg): try: os.makedirs(os.path.dirname(cdp_cfg), exist_ok=True) - open(cdp_cfg, 'w', encoding='utf-8').write(f"const TID = '__ljq_{hex(random.randint(0, 99999999))[2:8]}';") + with open(cdp_cfg, 'w', encoding='utf-8') as f: + f.write(f"const TID = '__ljq_{hex(random.randint(0, 99999999))[2:8]}';") except Exception as e: print(f'[WARN] CDP config init failed: {e} — advanced web features (tmwebdriver) will be unavailable.') def get_system_prompt(): @@ -43,47 +52,49 @@ class GeneraticAgent: def __init__(self): script_dir = os.path.dirname(os.path.abspath(__file__)) os.makedirs(os.path.join(script_dir, 'temp'), exist_ok=True) - from llmcore import mykeys - llm_sessions = [] - for k, cfg in mykeys.items(): - if not any(x in k for x in ['api', 'config', 'cookie']): continue - try: - if 'native' in k and 'claude' in k: llm_sessions += [NativeToolClient(NativeClaudeSession(cfg=cfg))] - elif 'native' in k and 'oai' in k: llm_sessions += [NativeToolClient(NativeOAISession(cfg=cfg))] - elif 'claude' in k: llm_sessions += [ToolClient(ClaudeSession(cfg=cfg))] - elif 'oai' in k: llm_sessions += [ToolClient(LLMSession(cfg=cfg))] - elif 'mixin' in k: llm_sessions += [{'mixin_cfg': cfg}] - except: pass - for i, s in enumerate(llm_sessions): - if isinstance(s, dict) and 'mixin_cfg' in s: - try: - mixin = MixinSession(llm_sessions, s['mixin_cfg']) - if isinstance(mixin._sessions[0], (NativeClaudeSession, NativeOAISession)): llm_sessions[i] = NativeToolClient(mixin) - else: llm_sessions[i] = ToolClient(mixin) - except Exception as e: print(f'[WARN] Failed to init MixinSession with cfg {s["mixin_cfg"]}: {e}') - self.llmclients = llm_sessions self.lock = 
threading.Lock() self.task_dir = None self.history = [] self.task_queue = queue.Queue() self.is_running = False; self.stop_sig = False - self.llm_no = 0; self.inc_out = False + self.llm_no = 0; self.llmclient = None; self.llmclients = [] + self.config_source = 'legacy'; self.config_meta = {} + self.ga_switch = get_service() + self.inc_out = False self.handler = None; self.verbose = True - self.llmclient = self.llmclients[self.llm_no] + self._reload_clients(initial=True) - def next_llm(self, n=-1): - self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclients) - lastc = self.llmclient - self.llmclient = self.llmclients[self.llm_no] - self.llmclient.backend.history = lastc.backend.history - self.llmclient.last_tools = '' + def _sync_tool_schema(self): name = self.get_llm_name().lower() if 'glm' in name or 'minimax' in name or 'kimi' in name: load_tool_schema('_cn') else: load_tool_schema() - def list_llms(self): return [(i, self.get_llm_name(b), i == self.llm_no) for i, b in enumerate(self.llmclients)] - def get_llm_name(self, b=None): - b = self.llmclient if b is None else b - return f"{type(b.backend).__name__}/{b.backend.name}" if not isinstance(b, dict) else "BADCONFIG_MIXIN" + + def _reload_clients(self, *, initial=False, preserve_history=True): + return load_clients(self, preserve_history=preserve_history, initial=initial) + + def next_llm(self, n=-1): + return next_client(self, n) + + def set_active_route(self, route_id_or_idx): + return bridge_set_active_route(self, route_id_or_idx) + + def reload_llm_config(self, preserve_history=True): + if self.is_running: + raise RuntimeError('Cannot reload LLM config while agent is running.') + self._reload_clients(initial=False, preserve_history=preserve_history) + return self.describe_llms() + + def describe_llms(self): + return describe_runtime(self) + + def list_llms(self): + return [(item['idx'], f"{item['name']} [{item['backend_class']}/{item.get('provider_name') or 
self.llmclients[item['idx']].backend.name}]", item['active']) for item in self.describe_llms()] + + def get_llm_name(self): + if self.llmclient is None: + return 'No LLM' + item = self.describe_llms()[self.llm_no] + return f"{item['name']} [{item['backend_class']}/{item.get('provider_name') or self.llmclient.backend.name}]" def abort(self): if not self.is_running: return diff --git a/ga_switch/__init__.py b/ga_switch/__init__.py new file mode 100644 index 0000000..7d81462 --- /dev/null +++ b/ga_switch/__init__.py @@ -0,0 +1,14 @@ +import os +from functools import lru_cache + + +def get_default_db_path(): + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(root_dir, ".ga-switch", "ga-switch.db") + + +@lru_cache(maxsize=8) +def get_service(db_path=None): + from .service import GASwitchService + + return GASwitchService(db_path=db_path or get_default_db_path()) diff --git a/ga_switch/diagnostics.py b/ga_switch/diagnostics.py new file mode 100644 index 0000000..42fb5a9 --- /dev/null +++ b/ga_switch/diagnostics.py @@ -0,0 +1,48 @@ +from datetime import datetime, timezone + + +ERROR_KINDS = ( + "auth", + "quota", + "rate_limit", + "timeout", + "network", + "server", + "bad_request", + "model_not_found", + "unsupported_param", + "unknown", +) + + +def utcnow_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def normalize_message(message, limit=2000) -> str: + text = "" if message is None else str(message).strip() + return text[:limit] + + +def classify_error(*, status_code=None, message="", body="", exc_type="") -> str: + status = None if status_code is None else int(status_code) + hay = " ".join(x for x in (str(message or ""), str(body or ""), str(exc_type or "")) if x).lower() + + if status in (401, 403): + return "auth" + if status == 404 or "model_not_found" in hay or "model not found" in hay or "no such model" in hay: + return "model_not_found" + if 
"unsupported_param" in hay or ("unsupported" in hay and any(k in hay for k in ("param", "reasoning_effort", "reasoning.effort", "api_mode"))): + return "unsupported_param" + if status == 400: + return "unsupported_param" if "unsupported" in hay else "bad_request" + if status == 429: + quota_tokens = ("insufficient_quota", "quota", "credit", "billing", "余额", "配额") + return "quota" if any(token in hay for token in quota_tokens) else "rate_limit" + if status is not None and status >= 500: + return "server" + if any(token in hay for token in ("timeout", "timed out", "readtimeout", "connecttimeout")): + return "timeout" + if any(token in hay for token in ("connectionerror", "proxyerror", "sslerror", "name or service not known", "connection reset", "dns", "proxy", "connection refused")): + return "network" + return "unknown" diff --git a/ga_switch/models.py b/ga_switch/models.py new file mode 100644 index 0000000..3f23907 --- /dev/null +++ b/ga_switch/models.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass, field +from typing import Any + + +PROVIDER_BACKEND_KINDS = ( + "native_claude", + "native_oai", + "claude_text", + "oai_text", +) + +ROUTE_KINDS = ("single", "failover") + + +def is_native_backend_kind(kind: str) -> bool: + return str(kind or "").startswith("native_") + + +def backend_family(kind: str) -> str: + kind = str(kind or "") + if "claude" in kind: + return "claude" + return "oai" + + +@dataclass +class ProviderModel: + id: int | None = None + name: str = "" + backend_kind: str = "oai_text" + apikey: str = "" + apibase: str = "" + model: str = "" + api_mode: str = "chat_completions" + temperature: float = 1.0 + max_tokens: int = 8192 + context_win: int = 24000 + proxy: str | None = None + timeout: int = 5 + read_timeout: int = 30 + max_retries: int = 1 + reasoning_effort: str | None = None + thinking_type: str | None = None + thinking_budget_tokens: int | None = None + stream: bool = True + is_enabled: bool = True + extra: dict[str, Any] = 
field(default_factory=dict) + + +@dataclass +class RouteModel: + id: int | None = None + name: str = "" + kind: str = "single" + provider_id: int | None = None + member_provider_ids: list[int] = field(default_factory=list) + is_enabled: bool = True + is_default: bool = False + config: dict[str, Any] = field(default_factory=dict) diff --git a/ga_switch/runtime_bridge.py b/ga_switch/runtime_bridge.py new file mode 100644 index 0000000..e389710 --- /dev/null +++ b/ga_switch/runtime_bridge.py @@ -0,0 +1,424 @@ +from llmcore import ( + ClaudeSession, + LLMSession, + MixinSession, + NativeClaudeSession, + NativeOAISession, + NativeToolClient, + ToolClient, +) + +from .diagnostics import classify_error + + +def _provider_cfg(provider, override=None): + cfg = { + "name": provider["name"], + "apikey": provider["apikey"], + "apibase": provider["apibase"], + "model": provider.get("model") or "", + "api_mode": provider.get("api_mode") or "chat_completions", + "temperature": provider.get("temperature", 1.0), + "max_tokens": provider.get("max_tokens", 8192), + "context_win": provider.get("context_win", 24000), + "proxy": provider.get("proxy"), + "timeout": provider.get("timeout", 5), + "read_timeout": provider.get("read_timeout", 30), + "max_retries": provider.get("max_retries", 1), + "reasoning_effort": provider.get("reasoning_effort"), + "thinking_type": provider.get("thinking_type"), + "thinking_budget_tokens": provider.get("thinking_budget_tokens"), + "stream": provider.get("stream", True), + } + if override: + cfg.update({key: value for key, value in override.items() if value is not None}) + return cfg + + +def _backend_error_kind(backend): + return classify_error( + status_code=getattr(backend, "last_status_code", None), + message=getattr(backend, "last_error_message", ""), + ) + + +def _attach_runtime_metadata(client, *, source, route_id, route_name, route_kind, backend_kind, provider=None, members=None): + meta = { + "source": source, + "route_id": route_id, + 
"route_name": route_name, + "route_kind": route_kind, + "backend_kind": backend_kind, + "provider": provider, + "provider_id": provider["id"] if provider else None, + "provider_name": provider["name"] if provider else None, + "members": list(members or []), + } + client.ga_switch_meta = meta + client.ga_switch_route_id = route_id + client.ga_switch_route_name = route_name + client.ga_switch_route_kind = route_kind + client.ga_switch_backend_kind = backend_kind + client.ga_switch_members = list(members or ([] if provider is None else [provider])) + return client + + +def _resolve_event_provider(client): + meta = getattr(client, "ga_switch_meta", {}) + provider = meta.get("provider") + if provider is not None: + return provider + active_member_name = getattr(client.backend, "active_member_name", None) + if not active_member_name: + return None + for member in meta.get("members", []): + if member.get("name") == active_member_name: + return member + return None + + +def _wrap_client_chat(service, client): + if getattr(client, "_ga_switch_chat_wrapped", False): + return client + + original_chat = client.chat + + def wrapped_chat(messages, tools=None): + generator = original_chat(messages, tools=tools) + last_chunk = "" + response = None + try: + while True: + chunk = next(generator) + last_chunk = chunk + yield chunk + except StopIteration as stop: + response = stop.value + + backend = client.backend + provider = _resolve_event_provider(client) + meta = getattr(client, "ga_switch_meta", {}) + last_error_message = getattr(backend, "last_error_message", "") or "" + status_code = getattr(backend, "last_status_code", None) + ok = not last_error_message and not (isinstance(last_chunk, str) and last_chunk.startswith("Error:")) + message = "OK" if ok else (last_error_message or str(last_chunk or "Error")) + service.record_runtime_event( + provider, + route_id=meta.get("route_id"), + route_name=meta.get("route_name"), + backend_name=getattr(backend, "name", ""), + ok=ok, + 
message=message, + status_code=status_code, + latency_ms=getattr(backend, "last_latency_ms", None), + ttfb_ms=getattr(backend, "last_ttfb_ms", None), + ) + return response + + client.chat = wrapped_chat + client._ga_switch_chat_wrapped = True + return client + + +def _build_client_from_provider(service, provider, *, source, route_id=None, route_name=None, route_kind="single", override=None): + cfg = _provider_cfg(provider, override=override) + backend_kind = provider["backend_kind"] + if backend_kind == "native_claude": + backend = NativeClaudeSession(cfg=cfg) + client = NativeToolClient(backend) + elif backend_kind == "native_oai": + backend = NativeOAISession(cfg=cfg) + client = NativeToolClient(backend) + elif backend_kind == "claude_text": + backend = ClaudeSession(cfg=cfg) + client = ToolClient(backend) + else: + backend = LLMSession(cfg=cfg) + client = ToolClient(backend) + _attach_runtime_metadata( + client, + source=source, + route_id=route_id, + route_name=route_name or provider["name"], + route_kind=route_kind, + backend_kind=backend_kind, + provider=provider, + ) + return _wrap_client_chat(service, client) + + +def build_test_client(service, provider, override=None): + return _build_client_from_provider( + service, + provider, + source="test", + route_id=None, + route_name=f"test:{provider['name']}", + route_kind="single", + override=override, + ) + + +def _build_structured_clients(service): + routes = service.store.list_routes(enabled_only=True) + active_route_id = service.get_active_route_id() + clients = [] + for route in routes: + if route["kind"] == "single": + provider = route["provider"] + if provider is None: + raise ValueError(f"Single route {route['name']} is missing provider.") + client = _build_client_from_provider( + service, + provider, + source="store", + route_id=route["id"], + route_name=route["name"], + route_kind=route["kind"], + ) + client.ga_switch_members = [provider] + else: + members = [ + _build_client_from_provider( + service, + 
provider, + source="store", + route_id=route["id"], + route_name=route["name"], + route_kind=route["kind"], + ) + for provider in route["members"] + ] + mixin_cfg = { + "llm_nos": [member.backend.name for member in members], + "max_retries": route["config"].get("max_retries", 3), + "base_delay": route["config"].get("base_delay", 1.5), + "spring_back": route["config"].get("spring_back", 300), + } + mixin = MixinSession(members, mixin_cfg) + mixin.name = route["name"] + client = NativeToolClient(mixin) if route["members"] and route["members"][0]["is_native"] else ToolClient(mixin) + _attach_runtime_metadata( + client, + source="store", + route_id=route["id"], + route_name=route["name"], + route_kind=route["kind"], + backend_kind="mixin", + provider=None, + members=route["members"], + ) + _wrap_client_chat(service, client) + clients.append(client) + active_index = next((index for index, route in enumerate(routes) if route["id"] == active_route_id), 0) + return clients, {"source": "store", "routes": routes, "active_route_id": active_route_id, "active_index": active_index} + + +def _build_legacy_clients(): + from llmcore import mykeys + + sessions = [] + for key, cfg in mykeys.items(): + if not any(token in key for token in ("api", "config", "cookie")): + continue + route_name = cfg.get("name") or key + try: + if "native" in key and "claude" in key: + client = NativeToolClient(NativeClaudeSession(cfg=cfg)) + backend_kind = "native_claude" + elif "native" in key and "oai" in key: + client = NativeToolClient(NativeOAISession(cfg=cfg)) + backend_kind = "native_oai" + elif "claude" in key: + client = ToolClient(ClaudeSession(cfg=cfg)) + backend_kind = "claude_text" + elif "oai" in key: + client = ToolClient(LLMSession(cfg=cfg)) + backend_kind = "oai_text" + elif "mixin" in key: + sessions.append({"mixin_cfg": cfg, "route_name": route_name}) + continue + else: + continue + sessions.append(_attach_runtime_metadata( + client, + source="legacy", + route_id=None, + 
route_name=route_name, + route_kind="single", + backend_kind=backend_kind, + provider=None, + )) + except Exception as exc: + print(f"[WARN] Failed to init legacy session {key}: {exc}") + for index, item in enumerate(sessions): + if not isinstance(item, dict): + continue + try: + mixin = MixinSession(sessions, item["mixin_cfg"]) + client = NativeToolClient(mixin) if "Native" in type(mixin._sessions[0]).__name__ else ToolClient(mixin) + sessions[index] = _attach_runtime_metadata( + client, + source="legacy", + route_id=None, + route_name=item["route_name"], + route_kind="failover", + backend_kind="mixin", + provider=None, + ) + except Exception as exc: + print(f"[WARN] Failed to init MixinSession with cfg {item['mixin_cfg']}: {exc}") + clients = [client for client in sessions if not isinstance(client, dict)] + return clients, {"source": "legacy", "routes": [], "active_route_id": None, "active_index": 0} + + +def load_clients(agent, preserve_history=True, initial=False): + old_client = getattr(agent, "llmclient", None) + old_history = getattr(old_client.backend, "history", None) if old_client is not None and preserve_history else None + old_index = getattr(agent, "llm_no", 0) + + if agent.ga_switch.use_structured_config(): + clients, meta = _build_structured_clients(agent.ga_switch) + else: + clients, meta = _build_legacy_clients() + + agent.llmclients = clients + agent.config_source = meta["source"] + agent.config_meta = meta + if not clients: + agent.llm_no = 0 + agent.llmclient = None + return [] + + target_index = meta.get("active_index", 0) + if not initial and meta["source"] == "legacy" and old_index < len(clients): + target_index = old_index + agent.llm_no = target_index % len(clients) + agent.llmclient = clients[agent.llm_no] + if preserve_history and old_history is not None: + agent.llmclient.backend.history = old_history + if hasattr(agent, "_sync_tool_schema"): + agent._sync_tool_schema() + return clients + + +def _switch_index(agent, index): + if not 
agent.llmclients: + agent.llmclient = None + return None + last_client = agent.llmclient + agent.llm_no = index % len(agent.llmclients) + agent.llmclient = agent.llmclients[agent.llm_no] + if last_client is not None: + agent.llmclient.backend.history = last_client.backend.history + if hasattr(agent.llmclient, "last_tools"): + agent.llmclient.last_tools = "" + if agent.config_source == "store": + route_id = getattr(agent.llmclient, "ga_switch_route_id", None) + if route_id is not None: + agent.ga_switch.set_active_route(route_id) + if hasattr(agent, "_sync_tool_schema"): + agent._sync_tool_schema() + return agent.llmclient + + +def next_client(agent, n=-1): + if not agent.llmclients: + agent.llmclient = None + return None + index = (agent.llm_no + 1) if n < 0 else n + return _switch_index(agent, index) + + +def set_active_route(agent, route_id_or_idx): + if agent.config_source == "store": + runtime_items = describe_runtime(agent) + target = next((item for item in runtime_items if item["route_id"] == route_id_or_idx), None) + if target is None and isinstance(route_id_or_idx, int) and 0 <= route_id_or_idx < len(runtime_items): + target = runtime_items[route_id_or_idx] + if target is None: + raise ValueError(f"Unknown route id: {route_id_or_idx}") + agent.ga_switch.set_active_route(target["route_id"]) + load_clients(agent, preserve_history=True, initial=False) + return describe_runtime(agent)[agent.llm_no] + if not isinstance(route_id_or_idx, int): + raise ValueError(f"Legacy mode only supports index switching, got {route_id_or_idx!r}") + next_client(agent, route_id_or_idx) + return describe_runtime(agent)[agent.llm_no] + + +def describe_runtime(agent): + runtime_items = [] + for index, client in enumerate(getattr(agent, "llmclients", [])): + backend = client.backend + metadata = getattr(client, "ga_switch_meta", {}) + diagnostics = backend.describe_diagnostics() if hasattr(backend, "describe_diagnostics") else {} + provider_name = metadata.get("provider_name") or 
getattr(backend, "name", None) + backend_class = type(backend).__name__ + last_error_message = diagnostics.get("last_error_message", "") + runtime_items.append({ + "idx": index, + "active": index == getattr(agent, "llm_no", 0), + "source": getattr(agent, "config_source", metadata.get("source", "legacy")), + "route_id": metadata.get("route_id"), + "name": metadata.get("route_name") or getattr(backend, "name", ""), + "route_kind": metadata.get("route_kind", "single"), + "backend_class": backend_class, + "backend_kind": metadata.get("backend_kind"), + "provider_id": metadata.get("provider_id"), + "provider_name": provider_name, + "model": getattr(backend, "model", None), + "api_mode": getattr(backend, "api_mode", None), + "native_tools": isinstance(client, NativeToolClient) or "Native" in backend_class, + "member_names": [member.get("name", "") for member in metadata.get("members", [])], + "active_member_name": diagnostics.get("active_member_name", getattr(backend, "name", None)), + "last_error_kind": classify_error( + status_code=diagnostics.get("last_status_code"), + message=last_error_message, + ) if last_error_message else None, + "last_error_message": last_error_message, + "last_error_at": diagnostics.get("last_error_at"), + "last_ok_at": diagnostics.get("last_ok_at"), + "last_status_code": diagnostics.get("last_status_code"), + "last_latency_ms": diagnostics.get("last_latency_ms"), + "last_ttfb_ms": diagnostics.get("last_ttfb_ms"), + "last_switch_reason": diagnostics.get("last_switch_reason", ""), + "spring_back_seconds": diagnostics.get("spring_back_seconds"), + }) + return runtime_items + + +def build_runtime_snapshot(config_snapshot, runtime_items): + active_route_id = config_snapshot["active_route_id"] + active_route = next((route for route in config_snapshot["routes"] if route["id"] == active_route_id), None) + runtime_by_route_id = {item["route_id"]: item for item in runtime_items if item.get("route_id") is not None} + active_runtime = 
runtime_by_route_id.get(active_route_id) or next((item for item in runtime_items if item.get("active")), None) + active_route_summary = { + "route_id": active_route_id, + "route_name": (active_route or {}).get("name"), + "route_kind": (active_route or {}).get("kind"), + "provider_name": (active_runtime or {}).get("provider_name"), + "model": (active_runtime or {}).get("model"), + "backend_class": (active_runtime or {}).get("backend_class"), + "backend_kind": (active_runtime or {}).get("backend_kind"), + "api_mode": (active_runtime or {}).get("api_mode"), + "native_tools": bool((active_runtime or {}).get("native_tools")), + "member_names": list((active_runtime or {}).get("member_names", [])), + "active_member_name": (active_runtime or {}).get("active_member_name"), + "last_error_kind": (active_runtime or {}).get("last_error_kind"), + "last_error_message": (active_runtime or {}).get("last_error_message"), + "last_error_at": (active_runtime or {}).get("last_error_at"), + "last_ok_at": (active_runtime or {}).get("last_ok_at"), + "last_status_code": (active_runtime or {}).get("last_status_code"), + "last_switch_reason": (active_runtime or {}).get("last_switch_reason"), + } + return { + "use_structured_config": config_snapshot["use_structured_config"], + "active_route_id": active_route_id, + "active_route_summary": active_route_summary, + "providers": config_snapshot["providers"], + "routes": config_snapshot["routes"], + "events": config_snapshot["events"], + "runtime": runtime_items, + "stats": dict(config_snapshot["stats"], runtime_count=len(runtime_items)), + } diff --git a/ga_switch/service.py b/ga_switch/service.py new file mode 100644 index 0000000..a827e38 --- /dev/null +++ b/ga_switch/service.py @@ -0,0 +1,282 @@ +import importlib.util +import json +import os + +from .diagnostics import classify_error, normalize_message +from .store import GASwitchStore +from .testing import ModelTester + + +SAFE_PROVIDER_FIELDS = ( + "id", + "name", + "backend_kind", + 
"backend_family", + "model", + "api_mode", + "is_native", + "is_enabled", + "stream", + "timeout", + "read_timeout", + "max_retries", + "reasoning_effort", + "thinking_type", + "thinking_budget_tokens", +) + + +class GASwitchService: + def __init__(self, db_path): + self.db_path = db_path + self.store = GASwitchStore(db_path) + self.tester = ModelTester(self) + + def list_providers(self): + return self.store.list_providers(enabled_only=False) + + def upsert_provider(self, provider): + return self.store.upsert_provider(provider) + + def delete_provider(self, provider_id): + return self.store.delete_provider(provider_id) + + def list_routes(self): + return self.store.list_routes(enabled_only=False) + + def upsert_route(self, route): + return self.store.upsert_route(route) + + def delete_route(self, route_id): + return self.store.delete_route(route_id) + + def set_active_route(self, route_id): + return self.store.set_active_route(route_id) + + def set_structured_config_enabled(self, enabled): + self.store.set_setting("use_structured_config", bool(enabled)) + + def use_structured_config(self): + return bool(self.store.get_setting("use_structured_config", False)) + + def has_usable_routes(self): + return any(route["is_enabled"] for route in self.store.list_routes(enabled_only=True)) + + def get_active_route_id(self): + return self.store.get_setting("active_route_id") + + def _legacy_payload_to_provider(self, var_name, cfg): + lower_name = str(var_name).lower() + if "mixin" in lower_name: + return None + if "native" in lower_name and "claude" in lower_name: + backend_kind = "native_claude" + elif "native" in lower_name and "oai" in lower_name: + backend_kind = "native_oai" + elif "claude" in lower_name: + backend_kind = "claude_text" + elif "oai" in lower_name: + backend_kind = "oai_text" + else: + return None + return { + "name": cfg.get("name") or var_name, + "backend_kind": backend_kind, + "apikey": cfg.get("apikey", ""), + "apibase": cfg.get("apibase", ""), + 
"model": cfg.get("model", ""), + "api_mode": cfg.get("api_mode", "chat_completions"), + "temperature": cfg.get("temperature", 1.0), + "max_tokens": cfg.get("max_tokens", 8192), + "context_win": cfg.get("context_win", 24000), + "proxy": cfg.get("proxy"), + "timeout": cfg.get("timeout", cfg.get("connect_timeout", 5)), + "read_timeout": cfg.get("read_timeout", 30), + "max_retries": cfg.get("max_retries", 1), + "reasoning_effort": cfg.get("reasoning_effort"), + "thinking_type": cfg.get("thinking_type"), + "thinking_budget_tokens": cfg.get("thinking_budget_tokens"), + "stream": cfg.get("stream", True), + "is_enabled": True, + "extra": {"legacy_var_name": var_name}, + } + + def _load_legacy_config(self, path=None): + if not path: + repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + py_path = os.path.join(repo_dir, "mykey.py") + json_path = os.path.join(repo_dir, "mykey.json") + path = py_path if os.path.exists(py_path) else json_path + if not path or not os.path.exists(path): + raise FileNotFoundError("Legacy config not found.") + if path.lower().endswith(".json"): + with open(path, "r", encoding="utf-8") as handle: + payload = json.load(handle) + else: + spec = importlib.util.spec_from_file_location("ga_switch_legacy_mykey", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + payload = {key: value for key, value in vars(module).items() if not key.startswith("_")} + return payload, path + + def import_legacy_mykey(self, path=None): + payload, source_path = self._load_legacy_config(path) + ordered_providers = [] + providers_by_name = {} + for var_name, cfg in payload.items(): + if not isinstance(cfg, dict): + continue + provider = self._legacy_payload_to_provider(var_name, cfg) + if not provider: + continue + saved = self.store.upsert_provider(provider) + providers_by_name[saved["name"]] = saved + ordered_providers.append(saved) + self.store.upsert_route({ + "name": saved["name"], + "kind": "single", + 
"provider_id": saved["id"], + "is_enabled": True, + "is_default": False, + }) + for var_name, cfg in payload.items(): + if not (isinstance(cfg, dict) and "mixin" in str(var_name).lower()): + continue + member_ids = [] + for ref in cfg.get("llm_nos") or []: + if isinstance(ref, int): + if ref < 0 or ref >= len(ordered_providers): + raise ValueError(f"Invalid mixin index {ref} in {var_name}") + member_ids.append(ordered_providers[ref]["id"]) + continue + ref_name = str(ref).strip() + provider = providers_by_name.get(ref_name) + if provider is None: + raise ValueError(f"Mixin {var_name} references unknown provider name {ref_name}") + member_ids.append(provider["id"]) + self.store.upsert_route({ + "name": cfg.get("name") or var_name, + "kind": "failover", + "member_provider_ids": member_ids, + "is_enabled": True, + "is_default": False, + "config": { + "max_retries": cfg.get("max_retries", 3), + "base_delay": cfg.get("base_delay", 1.5), + "spring_back": cfg.get("spring_back", 300), + }, + }) + routes = self.store.list_routes(enabled_only=False) + if routes: + self.store.set_setting("active_route_id", routes[0]["id"]) + self.store.set_setting("use_structured_config", True) + return { + "source_path": source_path, + "providers": self.store.list_providers(enabled_only=False), + "routes": self.store.list_routes(enabled_only=False), + } + + def _safe_provider(self, provider): + if provider is None: + return None + safe = {field: provider.get(field) for field in SAFE_PROVIDER_FIELDS} + safe["health"] = dict(provider.get("health") or {}) + return safe + + def _safe_route(self, route): + if route is None: + return None + return { + "id": route["id"], + "name": route["name"], + "kind": route["kind"], + "is_enabled": route["is_enabled"], + "is_default": route["is_default"], + "active": route.get("active", False), + "config": dict(route.get("config") or {}), + "provider": self._safe_provider(route.get("provider")), + "members": [self._safe_provider(member) for member in 
route.get("members", [])], + "member_provider_ids": list(route.get("member_provider_ids", [])), + } + + def get_config_snapshot(self): + providers = [self._safe_provider(provider) for provider in self.store.list_providers(enabled_only=False)] + routes = [self._safe_route(route) for route in self.store.list_routes(enabled_only=False)] + events = self.store.list_diagnostic_events(limit=100) + return { + "use_structured_config": self.use_structured_config(), + "active_route_id": self.get_active_route_id(), + "providers": providers, + "routes": routes, + "events": events, + "stats": { + "provider_count": len(providers), + "route_count": len(routes), + }, + } + + def get_runtime_diagnostics(self, config_snapshot, runtime_items): + active_route_id = config_snapshot["active_route_id"] + active_route = next((route for route in config_snapshot["routes"] if route["id"] == active_route_id), None) + runtime_by_route_id = {item["route_id"]: item for item in runtime_items if item.get("route_id") is not None} + active_runtime = runtime_by_route_id.get(active_route_id) or next((item for item in runtime_items if item.get("active")), None) + active_route_summary = { + "route_id": active_route_id, + "route_name": (active_route or {}).get("name"), + "route_kind": (active_route or {}).get("kind"), + "provider_name": (active_runtime or {}).get("provider_name"), + "model": (active_runtime or {}).get("model"), + "backend_class": (active_runtime or {}).get("backend_class"), + "backend_kind": (active_runtime or {}).get("backend_kind"), + "api_mode": (active_runtime or {}).get("api_mode"), + "native_tools": bool((active_runtime or {}).get("native_tools")), + "member_names": list((active_runtime or {}).get("member_names", [])), + "active_member_name": (active_runtime or {}).get("active_member_name"), + "last_error_kind": (active_runtime or {}).get("last_error_kind"), + "last_error_message": (active_runtime or {}).get("last_error_message"), + "last_error_at": (active_runtime or 
{}).get("last_error_at"), + "last_ok_at": (active_runtime or {}).get("last_ok_at"), + "last_status_code": (active_runtime or {}).get("last_status_code"), + "last_switch_reason": (active_runtime or {}).get("last_switch_reason"), + } + return { + "use_structured_config": config_snapshot["use_structured_config"], + "active_route_id": active_route_id, + "active_route_summary": active_route_summary, + "providers": config_snapshot["providers"], + "routes": config_snapshot["routes"], + "events": config_snapshot["events"], + "runtime": runtime_items, + "stats": dict(config_snapshot["stats"], runtime_count=len(runtime_items)), + } + + def record_runtime_event(self, provider, *, route_id=None, route_name=None, backend_name="", ok=False, message="", status_code=None, latency_ms=None, ttfb_ms=None, body="", exc_type=""): + error_kind = None if ok else classify_error(status_code=status_code, message=message, body=body, exc_type=exc_type) + self.store.append_diagnostic_event( + provider_id=provider["id"] if provider else None, + route_id=route_id, + backend_name=backend_name or (provider["name"] if provider else ""), + ok=ok, + error_kind=error_kind, + message=normalize_message(message), + status_code=status_code, + extra={ + "route_name": route_name, + "latency_ms": latency_ms, + "ttfb_ms": ttfb_ms, + "body": normalize_message(body, 1200), + "exc_type": exc_type or None, + }, + ) + if provider is None: + return + self.store.update_provider_health( + provider["id"], + status="healthy" if ok else "failed", + latency_ms=latency_ms if ok else None, + ttfb_ms=ttfb_ms if ok else None, + last_error="" if ok else normalize_message(message), + ) + + def run_model_test(self, provider_id): + return self.tester.run(provider_id) diff --git a/ga_switch/store.py b/ga_switch/store.py new file mode 100644 index 0000000..7605d4b --- /dev/null +++ b/ga_switch/store.py @@ -0,0 +1,509 @@ +import json +import os +import sqlite3 + +from .diagnostics import utcnow_iso +from .models import 
import json
import os
import sqlite3
import time

from .diagnostics import classify_error, utcnow_iso
from .models import PROVIDER_BACKEND_KINDS, ROUTE_KINDS, backend_family, is_native_backend_kind
from .runtime_bridge import build_test_client


class GASwitchStore:
    """SQLite-backed persistence for providers, routes, health, settings and diagnostics."""

    def __init__(self, db_path):
        self.db_path = db_path
        os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)
        self._init_db()
        self._ensure_default_test_configs()

    def _connect(self):
        """Open a connection with dict-like rows and foreign-key enforcement on."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA foreign_keys = ON")
        return conn

    @staticmethod
    def _loads(payload, default):
        """Tolerant json.loads: return *default* for empty or invalid payloads."""
        if not payload:
            return default
        try:
            return json.loads(payload)
        except Exception:
            return default

    @staticmethod
    def _dumps(payload):
        """Compact JSON for storage; None serializes as '{}'."""
        return json.dumps(payload or {}, ensure_ascii=False, separators=(",", ":"))

    def _row_to_provider(self, row, health_by_id):
        """Convert a providers row into a dict with derived fields and its health record."""
        if row is None:
            return None
        provider = dict(row)
        provider["extra"] = self._loads(provider.pop("extra_json", None), {})
        provider["stream"] = bool(provider.get("stream", 1))
        provider["is_enabled"] = bool(provider.get("is_enabled", 1))
        provider["is_native"] = is_native_backend_kind(provider["backend_kind"])
        provider["backend_family"] = backend_family(provider["backend_kind"])
        provider["health"] = health_by_id.get(provider["id"], {
            "provider_id": provider["id"],
            "status": "unknown",
            "latency_ms": None,
            "ttfb_ms": None,
            "last_checked_at": None,
            "last_error": "",
        })
        return provider

    def _health_map(self, conn):
        """Map provider_id -> health row for all providers."""
        rows = conn.execute("SELECT * FROM provider_health").fetchall()
        return {row["provider_id"]: dict(row) for row in rows}

    def _init_db(self):
        """Create the full schema idempotently."""
        with self._connect() as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS providers (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    backend_kind TEXT NOT NULL,
                    apikey TEXT NOT NULL,
                    apibase TEXT NOT NULL,
                    model TEXT,
                    api_mode TEXT DEFAULT 'chat_completions',
                    temperature REAL DEFAULT 1.0,
                    max_tokens INTEGER DEFAULT 8192,
                    context_win INTEGER DEFAULT 24000,
                    proxy TEXT,
                    timeout INTEGER DEFAULT 5,
                    read_timeout INTEGER DEFAULT 30,
                    max_retries INTEGER DEFAULT 1,
                    reasoning_effort TEXT,
                    thinking_type TEXT,
                    thinking_budget_tokens INTEGER,
                    stream INTEGER DEFAULT 1,
                    is_enabled INTEGER DEFAULT 1,
                    extra_json TEXT DEFAULT '{}',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS routes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    kind TEXT NOT NULL,
                    provider_id INTEGER,
                    is_enabled INTEGER DEFAULT 1,
                    is_default INTEGER DEFAULT 0,
                    config_json TEXT DEFAULT '{}',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    FOREIGN KEY(provider_id) REFERENCES providers(id)
                );
                CREATE TABLE IF NOT EXISTS route_members (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    route_id INTEGER NOT NULL,
                    provider_id INTEGER NOT NULL,
                    position INTEGER NOT NULL,
                    FOREIGN KEY(route_id) REFERENCES routes(id) ON DELETE CASCADE,
                    FOREIGN KEY(provider_id) REFERENCES providers(id),
                    UNIQUE(route_id, position)
                );
                CREATE TABLE IF NOT EXISTS provider_health (
                    provider_id INTEGER PRIMARY KEY,
                    status TEXT NOT NULL,
                    latency_ms INTEGER,
                    ttfb_ms INTEGER,
                    last_checked_at TEXT,
                    last_error TEXT,
                    FOREIGN KEY(provider_id) REFERENCES providers(id) ON DELETE CASCADE
                );
                CREATE TABLE IF NOT EXISTS model_test_config (
                    backend_family TEXT PRIMARY KEY,
                    test_model TEXT,
                    api_mode TEXT,
                    reasoning_effort TEXT,
                    prompt TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS app_settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS diagnostic_events (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    provider_id INTEGER,
                    route_id INTEGER,
                    backend_name TEXT,
                    ok INTEGER NOT NULL DEFAULT 0,
                    error_kind TEXT,
                    message TEXT,
                    status_code INTEGER,
                    created_at TEXT NOT NULL,
                    extra_json TEXT DEFAULT '{}',
                    FOREIGN KEY(provider_id) REFERENCES providers(id) ON DELETE SET NULL,
                    FOREIGN KEY(route_id) REFERENCES routes(id) ON DELETE SET NULL
                );
                """
            )

    def _ensure_default_test_configs(self):
        """Seed per-family model-test defaults without overwriting user edits."""
        defaults = {
            "claude": {
                "test_model": "claude-3-5-haiku-latest",
                "api_mode": "chat_completions",
                "reasoning_effort": None,
                "prompt": "Reply with exactly: pong",
            },
            "oai": {
                "test_model": "gpt-4.1-mini",
                "api_mode": "chat_completions",
                "reasoning_effort": "low",
                "prompt": "Reply with exactly: pong",
            },
        }
        with self._connect() as conn:
            for family, payload in defaults.items():
                conn.execute(
                    """
                    INSERT INTO model_test_config (backend_family, test_model, api_mode, reasoning_effort, prompt)
                    VALUES (?, ?, ?, ?, ?)
                    ON CONFLICT(backend_family) DO NOTHING
                    """,
                    (family, payload["test_model"], payload["api_mode"], payload["reasoning_effort"], payload["prompt"]),
                )

    def get_setting(self, key, default=None):
        """Read a JSON-encoded setting; legacy plain-string values are returned as-is."""
        with self._connect() as conn:
            row = conn.execute("SELECT value FROM app_settings WHERE key = ?", (key,)).fetchone()
        if row is None:
            return default
        try:
            return json.loads(row["value"])
        except Exception:
            return row["value"]

    def set_setting(self, key, value):
        """Write a setting as JSON (upsert)."""
        payload = json.dumps(value, ensure_ascii=False)
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO app_settings (key, value) VALUES (?, ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (key, payload),
            )

    def list_providers(self, enabled_only=False):
        """All providers (optionally only enabled), ordered case-insensitively by name."""
        sql = "SELECT * FROM providers"
        if enabled_only:
            sql += " WHERE is_enabled = 1"
        sql += " ORDER BY name COLLATE NOCASE"
        with self._connect() as conn:
            health_by_id = self._health_map(conn)
            rows = conn.execute(sql).fetchall()
            return [self._row_to_provider(row, health_by_id) for row in rows]

    def get_provider(self, provider_id):
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM providers WHERE id = ?", (provider_id,)).fetchone()
            return self._row_to_provider(row, self._health_map(conn))

    def get_provider_by_name(self, name):
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM providers WHERE name = ?", (name,)).fetchone()
            return self._row_to_provider(row, self._health_map(conn))

    def upsert_provider(self, provider):
        """Insert or update a provider; raises ValueError on invalid kind or missing fields.

        Returns the freshly re-read provider record.
        """
        backend_kind = str(provider.get("backend_kind") or "").strip()
        if backend_kind not in PROVIDER_BACKEND_KINDS:
            raise ValueError(f"Unsupported backend_kind: {backend_kind}")
        name = str(provider.get("name") or "").strip()
        apikey = str(provider.get("apikey") or "").strip()
        apibase = str(provider.get("apibase") or "").strip()
        if not name or not apikey or not apibase:
            raise ValueError("Provider requires name, apikey and apibase.")
        now = utcnow_iso()
        payload = (
            name,
            backend_kind,
            apikey,
            apibase.rstrip("/"),
            provider.get("model") or "",
            provider.get("api_mode") or "chat_completions",
            float(provider.get("temperature", 1.0)),
            int(provider.get("max_tokens", 8192)),
            int(provider.get("context_win", 24000)),
            provider.get("proxy"),
            int(provider.get("timeout", 5)),
            int(provider.get("read_timeout", 30)),
            int(provider.get("max_retries", 1)),
            provider.get("reasoning_effort"),
            provider.get("thinking_type"),
            provider.get("thinking_budget_tokens"),
            1 if provider.get("stream", True) else 0,
            1 if provider.get("is_enabled", True) else 0,
            self._dumps(provider.get("extra")),
            now,
            now,
        )
        with self._connect() as conn:
            if provider.get("id"):
                # payload[:19] covers every column up to extra_json; created_at is preserved.
                conn.execute(
                    """
                    UPDATE providers
                    SET name=?, backend_kind=?, apikey=?, apibase=?, model=?, api_mode=?, temperature=?, max_tokens=?,
                        context_win=?, proxy=?, timeout=?, read_timeout=?, max_retries=?, reasoning_effort=?,
                        thinking_type=?, thinking_budget_tokens=?, stream=?, is_enabled=?, extra_json=?, updated_at=?
                    WHERE id = ?
                    """,
                    payload[:19] + (now, int(provider["id"])),
                )
                provider_id = int(provider["id"])
            else:
                cursor = conn.execute(
                    """
                    INSERT INTO providers
                        (name, backend_kind, apikey, apibase, model, api_mode, temperature, max_tokens, context_win,
                         proxy, timeout, read_timeout, max_retries, reasoning_effort, thinking_type, thinking_budget_tokens,
                         stream, is_enabled, extra_json, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    payload,
                )
                provider_id = int(cursor.lastrowid)
            conn.execute(
                """
                INSERT INTO provider_health (provider_id, status, latency_ms, ttfb_ms, last_checked_at, last_error)
                VALUES (?, 'unknown', NULL, NULL, NULL, '')
                ON CONFLICT(provider_id) DO NOTHING
                """,
                (provider_id,),
            )
        # Re-read after the transaction committed so the returned record is consistent.
        return self.get_provider(provider_id)

    def delete_provider(self, provider_id):
        """Delete a provider unless a route (single or failover member) still references it."""
        with self._connect() as conn:
            ref = conn.execute(
                """
                SELECT routes.name FROM routes
                LEFT JOIN route_members ON route_members.route_id = routes.id
                WHERE routes.provider_id = ? OR route_members.provider_id = ?
                LIMIT 1
                """,
                (provider_id, provider_id),
            ).fetchone()
            if ref:
                raise ValueError(f"Provider is still used by route {ref['name']}.")
            conn.execute("DELETE FROM providers WHERE id = ?", (provider_id,))

    def list_routes(self, enabled_only=False):
        """All routes with resolved provider/member records and the 'active' flag set."""
        sql = "SELECT * FROM routes"
        if enabled_only:
            sql += " WHERE is_enabled = 1"
        sql += " ORDER BY is_default DESC, id ASC"
        with self._connect() as conn:
            routes = [dict(row) for row in conn.execute(sql).fetchall()]
            providers = {provider["id"]: provider for provider in self.list_providers(enabled_only=False)}
            member_rows = conn.execute(
                """
                SELECT route_id, provider_id, position
                FROM route_members
                ORDER BY route_id ASC, position ASC
                """
            ).fetchall()
            active_route_id = self.get_setting("active_route_id")
            members_by_route = {}
            for row in member_rows:
                members_by_route.setdefault(row["route_id"], []).append(providers.get(row["provider_id"]))
            result = []
            for route in routes:
                route["config"] = self._loads(route.pop("config_json", None), {})
                route["is_enabled"] = bool(route.get("is_enabled", 1))
                route["is_default"] = bool(route.get("is_default", 0))
                route["provider"] = providers.get(route.get("provider_id"))
                # Drop dangling member references (provider deleted out-of-band).
                route["members"] = [member for member in members_by_route.get(route["id"], []) if member]
                route["member_provider_ids"] = [member["id"] for member in route["members"]]
                route["active"] = active_route_id == route["id"]
                result.append(route)
            return result

    def get_route(self, route_id):
        return next((route for route in self.list_routes(enabled_only=False) if route["id"] == route_id), None)

    def _ensure_route_defaults(self, conn, route_id=None, make_default=False):
        """Keep the invariant of exactly one default route."""
        any_default = conn.execute("SELECT id FROM routes WHERE is_default = 1 LIMIT 1").fetchone()
        if make_default or not any_default:
            conn.execute("UPDATE routes SET is_default = 0")
            if route_id:
                conn.execute("UPDATE routes SET is_default = 1 WHERE id = ?", (route_id,))

    def _validate_failover_members(self, providers):
        """Failover routes need >= 2 members, all native or all non-native."""
        if len(providers) < 2:
            raise ValueError("Failover route requires at least two providers.")
        native_groups = {is_native_backend_kind(provider["backend_kind"]) for provider in providers}
        if len(native_groups) != 1:
            kinds = [provider["backend_kind"] for provider in providers]
            raise ValueError(f"Failover route cannot mix native and non-native providers: {kinds}")

    def upsert_route(self, route):
        """Insert or update a route (single or failover) and its member list.

        Side effects: ensures one default route exists, initializes
        'active_route_id' if unset, and switches 'use_structured_config' on.
        Returns the freshly re-read route record.
        """
        kind = str(route.get("kind") or "").strip()
        if kind not in ROUTE_KINDS:
            raise ValueError(f"Unsupported route kind: {kind}")
        name = str(route.get("name") or "").strip()
        if not name:
            raise ValueError("Route requires a name.")
        now = utcnow_iso()
        with self._connect() as conn:
            if kind == "single":
                provider_id = int(route.get("provider_id") or 0)
                provider = self.get_provider(provider_id)
                if not provider:
                    raise ValueError(f"Single route provider not found: {provider_id}")
                member_provider_ids = []
            else:
                provider_id = None
                member_provider_ids = [int(pid) for pid in (route.get("member_provider_ids") or [])]
                providers = [self.get_provider(pid) for pid in member_provider_ids]
                if not all(providers):
                    raise ValueError("Failover route references missing providers.")
                self._validate_failover_members(providers)
            is_enabled = 1 if route.get("is_enabled", True) else 0
            is_default = 1 if route.get("is_default", False) else 0
            config_json = self._dumps(route.get("config"))
            if route.get("id"):
                conn.execute(
                    """
                    UPDATE routes
                    SET name = ?, kind = ?, provider_id = ?, is_enabled = ?, config_json = ?, updated_at = ?
                    WHERE id = ?
                    """,
                    (name, kind, provider_id, is_enabled, config_json, now, int(route["id"])),
                )
                route_id = int(route["id"])
                # Members are rebuilt from scratch on every update.
                conn.execute("DELETE FROM route_members WHERE route_id = ?", (route_id,))
            else:
                cursor = conn.execute(
                    """
                    INSERT INTO routes (name, kind, provider_id, is_enabled, is_default, config_json, created_at, updated_at)
                    VALUES (?, ?, ?, ?, 0, ?, ?, ?)
                    """,
                    (name, kind, provider_id, is_enabled, config_json, now, now),
                )
                route_id = int(cursor.lastrowid)
            for position, member_provider_id in enumerate(member_provider_ids):
                conn.execute(
                    "INSERT INTO route_members (route_id, provider_id, position) VALUES (?, ?, ?)",
                    (route_id, member_provider_id, position),
                )
            self._ensure_route_defaults(conn, route_id=route_id, make_default=bool(is_default))
            active_row = conn.execute("SELECT value FROM app_settings WHERE key = 'active_route_id'").fetchone()
            if active_row is None:
                conn.execute(
                    """
                    INSERT INTO app_settings (key, value) VALUES ('active_route_id', ?)
                    ON CONFLICT(key) DO UPDATE SET value = excluded.value
                    """,
                    (json.dumps(route_id, ensure_ascii=False),),
                )
            conn.execute(
                """
                INSERT INTO app_settings (key, value) VALUES ('use_structured_config', ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (json.dumps(True, ensure_ascii=False),),
            )
        # Read back only after the write transaction committed (with-block exit).
        return self.get_route(route_id)

    def delete_route(self, route_id):
        """Delete a route and repoint 'active_route_id' at the best remaining route."""
        with self._connect() as conn:
            conn.execute("DELETE FROM routes WHERE id = ?", (route_id,))
            next_route = conn.execute(
                "SELECT id FROM routes ORDER BY is_default DESC, id ASC LIMIT 1"
            ).fetchone()
            # Fix: write the setting on the SAME connection. Calling self.set_setting()
            # here would open a second connection while this write transaction is still
            # uncommitted, which can raise "database is locked" under SQLite's default
            # journal mode.
            conn.execute(
                """
                INSERT INTO app_settings (key, value) VALUES ('active_route_id', ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (json.dumps(next_route["id"] if next_route else None, ensure_ascii=False),),
            )

    def set_active_route(self, route_id):
        """Mark a route active; rejects unknown or disabled routes."""
        route = self.get_route(route_id)
        if route is None:
            raise ValueError(f"Route not found: {route_id}")
        if not route["is_enabled"]:
            raise ValueError(f"Route is disabled: {route['name']}")
        self.set_setting("active_route_id", route["id"])
        self.set_setting("use_structured_config", True)
        return route

    def get_test_config(self, family):
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM model_test_config WHERE backend_family = ?", (family,)).fetchone()
            return dict(row) if row else None

    def set_test_config(self, family, payload):
        """Upsert the per-family test configuration and return it re-read."""
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO model_test_config (backend_family, test_model, api_mode, reasoning_effort, prompt)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(backend_family) DO UPDATE SET
                    test_model = excluded.test_model,
                    api_mode = excluded.api_mode,
                    reasoning_effort = excluded.reasoning_effort,
                    prompt = excluded.prompt
                """,
                (
                    family,
                    payload.get("test_model"),
                    payload.get("api_mode"),
                    payload.get("reasoning_effort"),
                    payload.get("prompt") or "Reply with exactly: pong",
                ),
            )
        return self.get_test_config(family)

    def update_provider_health(self, provider_id, *, status, latency_ms=None, ttfb_ms=None, last_error="", checked_at=None):
        """Upsert the single health row for a provider."""
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO provider_health (provider_id, status, latency_ms, ttfb_ms, last_checked_at, last_error)
                VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(provider_id) DO UPDATE SET
                    status = excluded.status,
                    latency_ms = excluded.latency_ms,
                    ttfb_ms = excluded.ttfb_ms,
                    last_checked_at = excluded.last_checked_at,
                    last_error = excluded.last_error
                """,
                (provider_id, status, latency_ms, ttfb_ms, checked_at or utcnow_iso(), last_error or ""),
            )

    def append_diagnostic_event(self, *, provider_id=None, route_id=None, backend_name="", ok=False, error_kind=None, message="", status_code=None, extra=None):
        """Append one immutable diagnostic event row."""
        created_at = utcnow_iso()
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO diagnostic_events
                    (provider_id, route_id, backend_name, ok, error_kind, message, status_code, created_at, extra_json)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    provider_id,
                    route_id,
                    backend_name or "",
                    1 if ok else 0,
                    error_kind,
                    message or "",
                    status_code,
                    created_at,
                    self._dumps(extra),
                ),
            )

    def list_diagnostic_events(self, limit=50):
        """Most recent events first, JSON fields decoded."""
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT * FROM diagnostic_events
                ORDER BY id DESC
                LIMIT ?
                """,
                (int(limit),),
            ).fetchall()
            events = []
            for row in rows:
                event = dict(row)
                event["ok"] = bool(event["ok"])
                event["extra"] = self._loads(event.pop("extra_json", None), {})
                events.append(event)
            return events


class ModelTester:
    """Runs a live round-trip test against one provider and records its health."""

    def __init__(self, service):
        self.service = service

    def run(self, provider_id):
        """Send the configured test prompt through a throwaway client, measure
        latency/TTFB, persist health, and return a summary dict.

        Raises ValueError when the provider id is unknown.
        """
        provider = self.service.store.get_provider(provider_id)
        if not provider:
            raise ValueError(f"Provider not found: {provider_id}")
        family = provider["backend_family"]
        test_cfg = self.service.store.get_test_config(family) or {}
        client = build_test_client(
            self.service,
            provider,
            override={
                "model": test_cfg.get("test_model") or provider.get("model"),
                "api_mode": test_cfg.get("api_mode") or provider.get("api_mode"),
                # reasoning_effort is an OpenAI-family knob; other families keep their own value.
                "reasoning_effort": test_cfg.get("reasoning_effort") if family == "oai" else provider.get("reasoning_effort"),
            },
        )
        prompt = test_cfg.get("prompt") or "Reply with exactly: pong"
        started = time.perf_counter()
        first_chunk_at = None
        raw_text = ""
        response = None
        gen = client.chat([{"role": "user", "content": prompt}], tools=None)
        try:
            while True:
                chunk = next(gen)
                if first_chunk_at is None:
                    first_chunk_at = time.perf_counter()
                raw_text += chunk
        except StopIteration as stop:
            # The generator's return value is the final response object.
            response = stop.value
        finished = time.perf_counter()
        backend = client.backend
        latency_ms = int((finished - started) * 1000)
        ttfb_ms = int(((first_chunk_at or finished) - started) * 1000)
        last_error = getattr(backend, "last_error_message", "") or ""
        error_kind = classify_error(status_code=getattr(backend, "last_status_code", None), message=last_error) if last_error else None
        # Failures surface either as a backend-recorded error or an "Error:" chunk.
        success = not (raw_text.startswith("Error:") or last_error)
        status = "healthy" if success else "failed"
        if success and latency_ms >= 15000:
            status = "degraded"
        self.service.store.update_provider_health(
            provider_id,
            status=status,
            latency_ms=latency_ms,
            ttfb_ms=ttfb_ms,
            last_error=last_error,
        )
        return {
            "provider_id": provider_id,
            "provider_name": provider["name"],
            "status": status,
            "latency_ms": latency_ms,
            "ttfb_ms": ttfb_ms,
            "last_error": last_error,
            "error_kind": error_kind,
            "raw_text": raw_text[:500],
            "response_repr": repr(response) if response is not None else "",
        }
Yields text chunks, returns list[content_block].""" ml = model.lower() if 'kimi' in ml or 'moonshot' in ml: temperature = 1 @@ -287,6 +296,8 @@ def _delay(resp, attempt): try: ra = float((resp.headers or {}).get("retry-after")) except: ra = None return max(0.5, ra if ra is not None else min(30.0, 1.5 * (2 ** attempt))) + started_at = time.perf_counter() + first_chunk_at = None for attempt in range(max_retries + 1): streamed = False try: @@ -306,8 +317,23 @@ def _delay(resp, attempt): e._err_body = err_body; raise gen = _parse_openai_sse(r.iter_lines(), api_mode) try: - while True: streamed = True; yield next(gen) + while True: + chunk = next(gen) + streamed = True + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk except StopIteration as e: + if session is not None: + session._record_success( + status_code=r.status_code, + message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + "api_mode": api_mode, + }, + ) return e.value or [] except requests.HTTPError as e: resp = getattr(e, "response", None); status = getattr(resp, "status_code", None) @@ -321,6 +347,8 @@ def _delay(resp, attempt): try: h = resp.headers or {}; rid = h.get("x-request-id","") or h.get("request-id",""); ra = h.get("retry-after",""); ct = h.get("content-type","") except: pass err = f"Error: HTTP {status} {e}; content_type: {ct or ''}; retry_after: {ra or ''}; request_id: {rid or ''}; body: {body or ''}" + if session is not None: + session._record_error(err, status_code=status, body=body, exc_type=type(e).__name__, extra={"request_id": rid, "retry_after": ra, "content_type": ct, "api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] except (requests.Timeout, requests.ConnectionError) as e: if attempt < max_retries and not streamed: @@ -328,9 +356,13 @@ def _delay(resp, attempt): print(f"[LLM Retry] {type(e).__name__}, retry in {d:.1f}s 
({attempt+1}/{max_retries+1})") time.sleep(d); continue err = f"Error: {type(e).__name__}: {e}" + if session is not None: + session._record_error(err, exc_type=type(e).__name__, extra={"api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] except Exception as e: err = f"Error: {e}" + if session is not None: + session._record_error(err, exc_type=type(e).__name__, extra={"api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] def _to_responses_input(messages): @@ -424,7 +456,7 @@ def __init__(self, cfg): self.max_retries = max(0, int(cfg.get('max_retries', 1))) self.stream = cfg.get('stream', True) default_ct, default_rt = (5, 30) if self.stream else (10, 240) - self.connect_timeout = max(1, int(cfg.get('timeout', default_ct))) + self.connect_timeout = max(1, int(cfg.get('timeout', cfg.get('connect_timeout', default_ct)))) self.read_timeout = max(5, int(cfg.get('read_timeout', default_rt))) def _enum(key, valid): v = cfg.get(key); v = None if v is None else str(v).strip().lower() @@ -436,6 +468,38 @@ def _enum(key, valid): self.api_mode = 'responses' if mode in ('responses', 'response') else 'chat_completions' self.temperature = cfg.get('temperature', 1) self.max_tokens = cfg.get('max_tokens', 8192) + self.last_error_message = '' + self.last_error_at = None + self.last_ok_at = None + self.last_status_code = None + self.last_latency_ms = None + self.last_ttfb_ms = None + + def _record_success(self, status_code=200, message='OK', extra=None): + extra = dict(extra or {}) + self.last_status_code = status_code + self.last_ok_at = _utcnow_iso() + self.last_error_message = '' + if 'latency_ms' in extra: + self.last_latency_ms = extra.get('latency_ms') + if 'ttfb_ms' in extra: + self.last_ttfb_ms = extra.get('ttfb_ms') + + def _record_error(self, message, *, status_code=None, body='', exc_type='', error_kind=None, extra=None): + self.last_error_message = _normalize_message(message) + self.last_error_at = _utcnow_iso() + 
self.last_status_code = status_code + + def describe_diagnostics(self): + return { + 'last_error_message': self.last_error_message, + 'last_error_at': self.last_error_at, + 'last_ok_at': self.last_ok_at, + 'last_status_code': self.last_status_code, + 'last_latency_ms': self.last_latency_ms, + 'last_ttfb_ms': self.last_ttfb_ms, + } + def _apply_claude_thinking(self, payload): if self.thinking_type: thinking = {"type": self.thinking_type} @@ -474,11 +538,35 @@ def raw_ask(self, messages): if self.temperature != 1: payload["temperature"] = self.temperature self._apply_claude_thinking(payload) if self.system: payload["system"] = [{"type": "text", "text": self.system, "cache_control": {"type": "persistent"}}] + started_at = time.perf_counter() + first_chunk_at = None try: with requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=(self.connect_timeout, self.read_timeout)) as r: - if r.status_code != 200: raise Exception(f"HTTP {r.status_code} {r.content.decode('utf-8', errors='replace')[:500]}") - return (yield from _parse_claude_sse(r.iter_lines())) or [] + if r.status_code != 200: + body = r.content.decode('utf-8', errors='replace')[:500] + err = f"HTTP {r.status_code} {body}" + self._record_error(err, status_code=r.status_code, body=body, extra={"api_base": self.api_base}) + yield (err_text := f"Error: {err}") + return [{"type": "text", "text": err_text}] + gen = _parse_claude_sse(r.iter_lines()) + try: + while True: + chunk = next(gen) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk + except StopIteration as e: + self._record_success( + status_code=r.status_code, + message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + }, + ) + return e.value or [] except Exception as e: + self._record_error(str(e), exc_type=type(e).__name__, extra={"api_base": self.api_base}) 
yield (err := f"Error: {e}") return [{"type": "text", "text": err}] def make_messages(self, raw_list): @@ -493,7 +581,7 @@ def raw_ask(self, messages): return (yield from _openai_stream(self.api_base, self.api_key, messages, self.model, self.api_mode, temperature=self.temperature, reasoning_effort=self.reasoning_effort, max_tokens=self.max_tokens, max_retries=self.max_retries, - connect_timeout=self.connect_timeout, read_timeout=self.read_timeout, proxies=self.proxies)) + connect_timeout=self.connect_timeout, read_timeout=self.read_timeout, proxies=self.proxies, session=self)) def make_messages(self, raw_list): return _msgs_claude2oai(raw_list) def _fix_messages(messages): @@ -550,10 +638,34 @@ def raw_ask(self, messages): for idx in user_idxs[-2:]: messages[idx] = {**messages[idx], "content": list(messages[idx]["content"])} messages[idx]["content"][-1] = dict(messages[idx]["content"][-1], cache_control={"type": "ephemeral"}) + started_at = time.perf_counter() + first_chunk_at = None try: with requests.post(auto_make_url(self.api_base, "messages")+'?beta=true', headers=headers, json=payload, stream=self.stream, timeout=(self.connect_timeout, self.read_timeout)) as resp: - if resp.status_code != 200: raise Exception(f"HTTP {resp.status_code} {resp.content.decode('utf-8', errors='replace')[:500]}") - if self.stream: return (yield from _parse_claude_sse(resp.iter_lines())) or [] + if resp.status_code != 200: + body = resp.content.decode('utf-8', errors='replace')[:500] + err = f"HTTP {resp.status_code} {body}" + self._record_error(err, status_code=resp.status_code, body=body, extra={"api_base": self.api_base}) + yield (err_text := f"Error: {err}") + return [{"type": "text", "text": err_text}] + if self.stream: + gen = _parse_claude_sse(resp.iter_lines()) + try: + while True: + chunk = next(gen) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk + except StopIteration as e: + self._record_success( + status_code=resp.status_code, + 
message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + }, + ) + return e.value or [] else: data = resp.json(); content_blocks = data.get("content", []) usage = data.get("usage", {}) @@ -561,8 +673,14 @@ def raw_ask(self, messages): for b in content_blocks: if b.get("type") == "text": yield b.get("text", "") elif b.get("type") == "thinking": yield "" + self._record_success( + status_code=resp.status_code, + message="OK", + extra={"latency_ms": int((time.perf_counter() - started_at) * 1000), "ttfb_ms": 0}, + ) return content_blocks except Exception as e: + self._record_error(str(e), exc_type=type(e).__name__, extra={"api_base": self.api_base}) yield (err := f"Error: {e}") return [{"type": "text", "text": err}] @@ -603,7 +721,7 @@ def raw_ask(self, messages): temperature=self.temperature, max_tokens=self.max_tokens, tools=self.tools, reasoning_effort=self.reasoning_effort, max_retries=self.max_retries, connect_timeout=self.connect_timeout, - read_timeout=self.read_timeout, proxies=self.proxies)) + read_timeout=self.read_timeout, proxies=self.proxies, session=self)) def openai_tools_to_claude(tools): """[{type:'function', function:{name,description,parameters}}] → [{name,description,input_schema}].""" @@ -821,6 +939,15 @@ def __init__(self, all_sessions, cfg): self._sessions[0].raw_ask = self._raw_ask self.model = getattr(self._sessions[0], 'model', None) self._cur_idx, self._switched_at = 0, 0.0 + self.active_session_index = 0 + self.active_member_name = self._sessions[0].name + self.last_switch_reason = '' + self.last_error_message = '' + self.last_error_at = None + self.last_ok_at = None + self.last_status_code = None + self.last_latency_ms = None + self.last_ttfb_ms = None def __getattr__(self, name): return getattr(self._sessions[0], name) _BROADCAST_ATTRS = frozenset({'system', 'tools', 'temperature', 'max_tokens', 'reasoning_effort'}) def 
__setattr__(self, name, value): @@ -831,8 +958,28 @@ def __setattr__(self, name, value): else: object.__setattr__(self, name, value) @property def primary(self): return self._sessions[0] + def _sync_diagnostics_from(self, session): + for attr in ('last_error_message', 'last_error_at', 'last_ok_at', 'last_status_code', 'last_latency_ms', 'last_ttfb_ms'): + setattr(self, attr, getattr(session, attr, None)) + def describe_diagnostics(self): + return { + 'last_error_message': self.last_error_message, + 'last_error_at': self.last_error_at, + 'last_ok_at': self.last_ok_at, + 'last_status_code': self.last_status_code, + 'last_latency_ms': self.last_latency_ms, + 'last_ttfb_ms': self.last_ttfb_ms, + 'active_session_index': self.active_session_index, + 'active_member_name': self.active_member_name, + 'last_switch_reason': self.last_switch_reason, + 'spring_back_seconds': self._spring_sec, + } def _pick(self): - if self._cur_idx and time.time() - self._switched_at > self._spring_sec: self._cur_idx = 0 + if self._cur_idx and time.time() - self._switched_at > self._spring_sec: + self._cur_idx = 0 + self.active_session_index = 0 + self.active_member_name = self._sessions[0].name + self.last_switch_reason = 'spring_back' return self._cur_idx def _raw_ask(self, *args, **kwargs): base, n = self._pick(), len(self._sessions) @@ -850,8 +997,18 @@ def _raw_ask(self, *args, **kwargs): except StopIteration as e: return_val = e.value or [] is_err = test_error(last_chunk) if not is_err: - if attempt > 0: self._cur_idx = idx; self._switched_at = time.time() + if attempt > 0: + self._cur_idx = idx + self._switched_at = time.time() + self.last_switch_reason = f'fallback_success:{self._sessions[idx].name}' + self.active_session_index = idx + self.active_member_name = self._sessions[idx].name + self._sync_diagnostics_from(self._sessions[idx]) return return_val + self.active_session_index = idx + self.active_member_name = self._sessions[idx].name + self.last_switch_reason = (last_chunk or 
'')[:240] + self._sync_diagnostics_from(self._sessions[idx]) if attempt >= self._retries: yield last_chunk; return return_val nxt = (base + attempt + 1) % n @@ -917,4 +1074,4 @@ def chat(self, messages, tools=None): except StopIteration as e: resp = e.value if resp: _write_llm_log('Response', resp.raw) if resp and hasattr(resp, 'tool_calls') and resp.tool_calls: self._pending_tool_ids = [tc.id for tc in resp.tool_calls] - return resp \ No newline at end of file + return resp diff --git a/tests/test_ga_switch.py b/tests/test_ga_switch.py new file mode 100644 index 0000000..aeaefd6 --- /dev/null +++ b/tests/test_ga_switch.py @@ -0,0 +1,353 @@ +import ast +import inspect +import json +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class GASwitchTestCase(unittest.TestCase): + def make_service(self): + from ga_switch.service import GASwitchService + + tmp_dir = tempfile.mkdtemp(prefix="ga-switch-test-") + return GASwitchService(os.path.join(tmp_dir, "ga-switch.db")) + + def make_oai_provider(self, service, name="p1", backend_kind="oai_text", model="gpt-4.1-mini"): + return service.upsert_provider({ + "name": name, + "backend_kind": backend_kind, + "apikey": "test-key", + "apibase": "https://api.example.com/v1", + "model": model, + "proxy": "http://127.0.0.1:8080", + "extra": {"token": "secret"}, + }) + + def make_agent(self, service): + with patch("agentmain.get_service", return_value=service): + from agentmain import GeneraticAgent + + return GeneraticAgent() + + +class TestDependencyBoundaries(GASwitchTestCase): + def test_llmcore_does_not_import_ga_switch(self): + root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + with open(os.path.join(root, "llmcore.py"), encoding="utf-8") as f: + tree = ast.parse(f.read(), filename="llmcore.py") + + imports = [] + for node in ast.walk(tree): + if isinstance(node, 
ast.Import): + imports.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imports.append(node.module) + self.assertFalse(any(name == "ga_switch" or name.startswith("ga_switch.") for name in imports)) + + def test_service_does_not_import_agentmain(self): + root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + with open(os.path.join(root, "ga_switch", "service.py"), encoding="utf-8") as f: + tree = ast.parse(f.read(), filename="ga_switch/service.py") + + imports = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imports.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imports.append(node.module) + self.assertFalse(any(name == "agentmain" or name.startswith("agentmain.") for name in imports)) + + def test_get_config_snapshot_does_not_accept_agent(self): + from ga_switch.service import GASwitchService + + signature = inspect.signature(GASwitchService.get_config_snapshot) + self.assertNotIn("agent", signature.parameters) + + +class TestGASwitchImport(GASwitchTestCase): + def test_import_legacy_mykey_json_to_structured_routes(self): + service = self.make_service() + tmp_dir = tempfile.mkdtemp(prefix="ga-switch-import-") + legacy_path = os.path.join(tmp_dir, "mykey.json") + with open(legacy_path, "w", encoding="utf-8") as f: + json.dump({ + "oai_config": { + "name": "kimi-primary", + "apikey": "k1", + "apibase": "https://api.example.com/v1", + "model": "kimi-k2.5", + }, + "oai_config_backup": { + "name": "glm-backup", + "apikey": "k2", + "apibase": "https://api.example.com/v1", + "model": "glm-5.1", + }, + "mixin_config": { + "name": "fallback-pair", + "llm_nos": ["kimi-primary", "glm-backup"], + "max_retries": 2, + "spring_back": 60, + }, + }, f, ensure_ascii=False) + + result = service.import_legacy_mykey(legacy_path) + + self.assertEqual(len(result["providers"]), 2) + self.assertEqual({route["kind"] for route in result["routes"]}, 
{"single", "failover"}) + failover = next(route for route in result["routes"] if route["kind"] == "failover") + self.assertEqual([member["name"] for member in failover["members"]], ["kimi-primary", "glm-backup"]) + self.assertTrue(service.use_structured_config()) + + +class TestGASwitchValidation(GASwitchTestCase): + def test_provider_update_round_trips(self): + service = self.make_service() + provider = self.make_oai_provider(service, name="editable") + + updated = service.upsert_provider({ + "id": provider["id"], + "name": "editable", + "backend_kind": "oai_text", + "apikey": "test-key-2", + "apibase": "https://api.example.com/v1", + "model": "updated-model", + "timeout": 9, + }) + + self.assertEqual(updated["model"], "updated-model") + self.assertEqual(updated["timeout"], 9) + + def test_failover_rejects_native_and_non_native_mix(self): + service = self.make_service() + native = self.make_oai_provider(service, name="native", backend_kind="native_oai") + text = self.make_oai_provider(service, name="text", backend_kind="oai_text") + + with self.assertRaisesRegex(ValueError, "cannot mix native and non-native"): + service.upsert_route({ + "name": "bad-route", + "kind": "failover", + "member_provider_ids": [native["id"], text["id"]], + }) + + def test_failover_member_order_is_persisted(self): + service = self.make_service() + p1 = self.make_oai_provider(service, name="alpha") + p2 = self.make_oai_provider(service, name="beta") + p3 = self.make_oai_provider(service, name="gamma") + + route = service.upsert_route({ + "name": "fallback-route", + "kind": "failover", + "member_provider_ids": [p3["id"], p1["id"], p2["id"]], + "is_default": True, + }) + + self.assertEqual(route["member_provider_ids"], [p3["id"], p1["id"], p2["id"]]) + self.assertEqual([member["name"] for member in route["members"]], ["gamma", "alpha", "beta"]) + + +class TestSnapshots(GASwitchTestCase): + def test_config_snapshot_redacts_sensitive_fields(self): + service = self.make_service() + p1 = 
self.make_oai_provider(service, name="alpha", model="m1") + p2 = self.make_oai_provider(service, name="beta", model="m2") + service.upsert_route({ + "name": "fallback-route", + "kind": "failover", + "member_provider_ids": [p1["id"], p2["id"]], + "is_default": True, + }) + + snapshot = service.get_config_snapshot() + + for provider in snapshot["providers"]: + self.assertNotIn("apikey", provider) + self.assertNotIn("apibase", provider) + self.assertNotIn("proxy", provider) + self.assertNotIn("extra", provider) + for route in snapshot["routes"]: + if route["provider"] is not None: + self.assertNotIn("apikey", route["provider"]) + self.assertNotIn("apibase", route["provider"]) + for member in route["members"]: + self.assertNotIn("apikey", member) + self.assertNotIn("apibase", member) + self.assertNotIn("proxy", member) + self.assertNotIn("extra", member) + + def test_runtime_snapshot_uses_backend_safe_contract(self): + from ga_switch.runtime_bridge import build_runtime_snapshot + + service = self.make_service() + p1 = self.make_oai_provider(service, name="alpha", backend_kind="oai_text", model="m1") + p2 = self.make_oai_provider(service, name="beta", backend_kind="oai_text", model="m2") + service.upsert_route({"name": "alpha-route", "kind": "single", "provider_id": p1["id"], "is_default": True}) + service.upsert_route({"name": "fallback-route", "kind": "failover", "member_provider_ids": [p1["id"], p2["id"]]}) + + agent = self.make_agent(service) + snapshot = build_runtime_snapshot(service.get_config_snapshot(), agent.describe_llms()) + + self.assertEqual(snapshot["active_route_summary"]["route_name"], "alpha-route") + self.assertEqual(snapshot["active_route_summary"]["provider_name"], "alpha") + self.assertEqual(snapshot["stats"]["route_count"], 2) + self.assertNotIn("quick_actions", snapshot) + self.assertNotIn("edit_groups", snapshot) + self.assertNotIn("active_runtime", snapshot) + for provider in snapshot["providers"]: + self.assertNotIn("apikey", provider) + 
self.assertNotIn("apibase", provider) + + +class TestDiagnostics(unittest.TestCase): + def test_classify_error_common_cases(self): + from ga_switch.diagnostics import classify_error + + self.assertEqual(classify_error(status_code=401, message="Unauthorized"), "auth") + self.assertEqual(classify_error(status_code=429, body="insufficient_quota"), "quota") + self.assertEqual(classify_error(status_code=429, body="rate limit exceeded"), "rate_limit") + self.assertEqual(classify_error(status_code=404, body="model not found"), "model_not_found") + self.assertEqual(classify_error(message="unsupported parameter reasoning_effort"), "unsupported_param") + self.assertEqual(classify_error(exc_type="Timeout", message="timed out"), "timeout") + self.assertEqual(classify_error(exc_type="ConnectionError", message="connection refused"), "network") + + +class TestAgentRuntime(GASwitchTestCase): + def test_set_active_route_switches_future_runtime(self): + from ga_switch.runtime_bridge import build_runtime_snapshot + + service = self.make_service() + p1 = self.make_oai_provider(service, name="route-a", model="m1") + p2 = self.make_oai_provider(service, name="route-b", model="m2") + route_a = service.upsert_route({"name": "route-a", "kind": "single", "provider_id": p1["id"], "is_default": True}) + route_b = service.upsert_route({"name": "route-b", "kind": "single", "provider_id": p2["id"]}) + + agent = self.make_agent(service) + switched = agent.set_active_route(route_b["id"]) + snapshot = build_runtime_snapshot(service.get_config_snapshot(), agent.describe_llms()) + + self.assertEqual(switched["route_id"], route_b["id"]) + self.assertEqual(agent.llmclient.ga_switch_route_name, "route-b") + self.assertEqual(snapshot["active_route_id"], route_b["id"]) + self.assertEqual(snapshot["active_route_summary"]["route_name"], "route-b") + self.assertNotEqual(route_a["id"], route_b["id"]) + + def test_reload_llm_config_preserves_history_and_blocks_running(self): + service = self.make_service() + 
p1 = self.make_oai_provider(service, name="route-a", backend_kind="oai_text", model="m1") + service.upsert_route({"name": "route-a", "kind": "single", "provider_id": p1["id"], "is_default": True}) + + agent = self.make_agent(service) + agent.llmclient.backend.history = [{"role": "user", "content": [{"type": "text", "text": "keep me"}]}] + + p2 = self.make_oai_provider(service, name="route-b", backend_kind="oai_text", model="m2") + service.upsert_route({"name": "route-b", "kind": "single", "provider_id": p2["id"]}) + described = agent.reload_llm_config() + + self.assertEqual(agent.llmclient.backend.history[0]["content"][0]["text"], "keep me") + self.assertEqual(agent.llmclient.ga_switch_route_name, "route-a") + self.assertEqual(len(described), 2) + + agent.is_running = True + with self.assertRaisesRegex(RuntimeError, "Cannot reload"): + agent.reload_llm_config() + + def test_structured_mode_failure_does_not_fallback_to_legacy(self): + import llmcore + from ga_switch.diagnostics import utcnow_iso + + service = self.make_service() + now = utcnow_iso() + with service.store._connect() as conn: + cursor = conn.execute( + """ + INSERT INTO routes (name, kind, provider_id, is_enabled, is_default, config_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ("broken-route", "single", None, 1, 1, "{}", now, now), + ) + route_id = int(cursor.lastrowid) + conn.execute( + """ + INSERT INTO app_settings (key, value) VALUES ('active_route_id', ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (json.dumps(route_id),), + ) + conn.execute( + """ + INSERT INTO app_settings (key, value) VALUES ('use_structured_config', ?) 
+ ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (json.dumps(True),), + ) + + legacy_payload = { + "oai_config": { + "name": "legacy-ok", + "apikey": "legacy-key", + "apibase": "https://api.example.com/v1", + "model": "legacy-model", + } + } + llmcore.__dict__.pop("mykeys", None) + llmcore.__dict__.pop("proxies", None) + with patch.object(llmcore, "_load_mykeys", return_value=legacy_payload): + with self.assertRaisesRegex(ValueError, "missing provider"): + self.make_agent(service) + + def test_failover_runtime_keeps_member_order_and_diagnostics_fields(self): + service = self.make_service() + p1 = self.make_oai_provider(service, name="alpha", model="m1") + p2 = self.make_oai_provider(service, name="beta", model="m2") + p3 = self.make_oai_provider(service, name="gamma", model="m3") + service.upsert_route({ + "name": "fallback-route", + "kind": "failover", + "member_provider_ids": [p3["id"], p1["id"], p2["id"]], + "is_default": True, + "config": {"spring_back": 60}, + }) + + agent = self.make_agent(service) + described = agent.describe_llms() + + self.assertEqual(len(described), 1) + self.assertEqual(described[0]["member_names"], ["gamma", "alpha", "beta"]) + self.assertEqual(described[0]["active_member_name"], "gamma") + self.assertIn("last_error_message", described[0]) + self.assertIn("last_status_code", described[0]) + self.assertIn("spring_back_seconds", described[0]) + + +class TestModelTester(GASwitchTestCase): + def test_model_test_uses_temporary_session(self): + from ga_switch.runtime_bridge import build_test_client + + service = self.make_service() + provider = self.make_oai_provider(service, name="tester", backend_kind="oai_text") + fresh_client = build_test_client(service, provider) + fresh_client.backend.history = [{"role": "user", "content": [{"type": "text", "text": "real history"}]}] + + def fake_post(url, headers=None, json=None, stream=None, timeout=None, proxies=None): + resp = MagicMock() + resp.status_code = 200 + 
resp.iter_lines.return_value = iter([b"data: [DONE]"]) + resp.__enter__ = lambda s: s + resp.__exit__ = MagicMock(return_value=False) + return resp + + with patch("llmcore.requests.post", side_effect=fake_post): + result = service.run_model_test(provider["id"]) + + self.assertEqual(result["status"], "healthy") + self.assertEqual(fresh_client.backend.history[0]["content"][0]["text"], "real history") + + +if __name__ == "__main__": + unittest.main()