diff --git a/.gitignore b/.gitignore index b0d27dd..ada1720 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ dist/ *.egg-info/ .streamlit/ +.ga-switch/ .vscode/ .idea/ @@ -22,6 +23,7 @@ dist/ Thumbs.db *.log +api_server.pid .env auth.json model_responses.txt diff --git a/API_SERVER_README.md b/API_SERVER_README.md new file mode 100644 index 0000000..7de76bf --- /dev/null +++ b/API_SERVER_README.md @@ -0,0 +1,45 @@ +# GA Switch API Server + +## 快速启动 + +### Windows +```bash +start_api_server.bat +``` + +### Linux/macOS +```bash +python api_server.py +``` + +服务器将在 `http://127.0.0.1:8765` 启动。 + +## API 端点 + +- `GET /api/health` - 健康检查 +- `GET /api/snapshot` - 获取完整快照 +- `GET /api/routes` - 路由列表 +- `POST /api/routes` - 创建路由 +- `PUT /api/routes/{id}` - 更新路由 +- `DELETE /api/routes/{id}` - 删除路由 +- `POST /api/routes/{id}/activate` - 激活路由 +- `GET /api/providers` - Provider 列表 +- `POST /api/providers` - 创建 Provider +- `PUT /api/providers/{id}` - 更新 Provider +- `DELETE /api/providers/{id}` - 删除 Provider +- `POST /api/providers/{id}/test` - 测试 Provider +- `GET /api/diagnostics` - 诊断事件 +- `POST /api/reload` - 软重载 +- `POST /api/import-legacy` - 导入配置 + +## 依赖安装 + +```bash +pip install -r requirements-api.txt +``` + +## 测试 + +```bash +curl http://127.0.0.1:8765/api/health +``` diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 0b95e1d..354030a 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -171,6 +171,12 @@ Agent 会自己读代码、找出需要的包、全部装好。 python3 launch.pyw ``` +如果额外安装了 `PySide6`,`launch.pyw` 会优先进入 Qt 桌面前端;否则自动回退到现有的 `Streamlit + pywebview` 窗口。 + +```bash +pip install PySide6 +``` + 启动后会出现一个桌面悬浮窗,直接在里面输入任务指令。 ### 可选:让 Agent 帮你做的事 @@ -266,4 +272,4 @@ GenericAgent 不预设技能,而是**靠使用进化**。每完成一个新任 > Agent 会自动 pull 最新代码并解读 commit log,告诉你新增了什么能力。 -> 更多细节请参阅 [README.md](README.md) 或 [详细版图文教程](https://my.feishu.cn/wiki/CGrDw0T76iNFuskmwxdcWrpinPb)。 \ No newline at end of file +> 更多细节请参阅 [README.md](README.md) 或 
[详细版图文教程](https://my.feishu.cn/wiki/CGrDw0T76iNFuskmwxdcWrpinPb)。 diff --git a/README.md b/README.md index 8c6cb8b..e7595a5 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,9 @@ cd GenericAgent # 2. Install minimal dependencies pip install streamlit pywebview +# 2.5 Optional: install the preferred Qt desktop shell +pip install PySide6 + # 3. Configure API Key cp mykey_template.py mykey.py # Edit mykey.py and fill in your LLM API Key @@ -88,6 +91,8 @@ cp mykey_template.py mykey.py python launch.pyw ``` +`launch.pyw` now prefers the Qt desktop frontend when `PySide6` is available, and falls back to the Streamlit + `pywebview` shell otherwise. + Full guide: [GETTING_STARTED.md](GETTING_STARTED.md) --- @@ -108,10 +113,10 @@ python frontends/tgapp.py ### Alternative App Frontends -Besides the default Streamlit web UI, you can also try other frontend styles: +Besides the default launch flow, you can also try other frontend styles directly: ```bash -python frontends/qtapp.py # Qt-based desktop app +python frontends/qtapp.py # Qt desktop app with Route Center streamlit run frontends/stapp2.py # Alternative Streamlit UI ``` @@ -271,6 +276,7 @@ cd GenericAgent # 2. 安装最小依赖 pip install streamlit pywebview +pip install PySide6 # 可选:启用默认优先的 Qt 桌面壳 # 3. 
配置 API Key cp mykey_template.py mykey.py @@ -280,6 +286,8 @@ cp mykey_template.py mykey.py python launch.pyw ``` +如果本机已安装 `PySide6`,`launch.pyw` 会优先进入 Qt 桌面前端;否则自动回退到 `Streamlit + pywebview`。 + 完整引导流程见 [GETTING_STARTED.md](GETTING_STARTED.md)。 📖 新手使用指南(图文版):[飞书文档](https://my.feishu.cn/wiki/CGrDw0T76iNFuskmwxdcWrpinPb) @@ -370,10 +378,10 @@ dingtalk_allowed_users = ["your_staff_id"] # 或 ['*'] ### 其他 App 前端 -除默认的 Streamlit Web UI 外,还可以尝试不同风格的前端: +除默认启动流程外,还可以直接尝试不同风格的前端: ```bash -python frontends/qtapp.py # 基于 Qt 的桌面应用 +python frontends/qtapp.py # 基于 Qt 的桌面应用(含 Route Center) streamlit run frontends/stapp2.py # 另一种 Streamlit 风格 UI ``` diff --git a/agentmain.py b/agentmain.py index 4643d3c..275fc38 100644 --- a/agentmain.py +++ b/agentmain.py @@ -9,11 +9,12 @@ from llmcore import LLMSession, ToolClient, ClaudeSession, MixinSession, NativeToolClient, NativeClaudeSession, NativeOAISession from agent_loop import agent_runner_loop from ga import GenericAgentHandler, smart_format, get_global_memory, format_error, consume_file +from ga_switch import get_service script_dir = os.path.dirname(os.path.abspath(__file__)) def load_tool_schema(suffix=''): global TOOLS_SCHEMA - TS = open(os.path.join(script_dir, f'assets/tools_schema{suffix}.json'), 'r', encoding='utf-8').read() + with open(os.path.join(script_dir, f'assets/tools_schema{suffix}.json'), 'r', encoding='utf-8') as f: TS = f.read() TOOLS_SCHEMA = json.loads(TS if os.name == 'nt' else TS.replace('powershell', 'bash')) load_tool_schema() @@ -43,47 +44,181 @@ class GeneraticAgent: def __init__(self): script_dir = os.path.dirname(os.path.abspath(__file__)) os.makedirs(os.path.join(script_dir, 'temp'), exist_ok=True) + self.lock = threading.Lock() + self.task_dir = None + self.history = [] + self.task_queue = queue.Queue() + self.is_running = False; self.stop_sig = False + self.llm_no = 0; self.llmclient = None; self.llmclients = [] + self.config_source = 'legacy'; self.config_meta = {} + self.ga_switch = get_service() + 
self.inc_out = False + self.handler = None; self.verbose = True + self._reload_clients(initial=True) + + def _sync_tool_schema(self): + name = self.get_llm_name().lower() + if 'glm' in name or 'minimax' in name or 'kimi' in name: load_tool_schema('_cn') + else: load_tool_schema() + + def _tag_client(self, client, *, route_name=None, route_kind='single', backend_kind=None, members=None): + client.ga_switch_route_id = getattr(client, 'ga_switch_route_id', None) + client.ga_switch_route_name = route_name or getattr(client, 'ga_switch_route_name', getattr(client.backend, 'name', '')) + client.ga_switch_route_kind = route_kind + client.ga_switch_backend_kind = backend_kind or getattr(client, 'ga_switch_backend_kind', None) + client.ga_switch_members = list(members or getattr(client, 'ga_switch_members', [])) + return client + + def build_llmclients_from_store(self): + return self.ga_switch.build_clients_from_store() + + def build_llmclients_from_legacy_mykey(self): from llmcore import mykeys llm_sessions = [] for k, cfg in mykeys.items(): - if not any(x in k for x in ['api', 'config', 'cookie']): continue + if not any(x in k for x in ['api', 'config', 'cookie']): + continue try: - if 'native' in k and 'claude' in k: llm_sessions += [NativeToolClient(NativeClaudeSession(cfg=cfg))] - elif 'native' in k and 'oai' in k: llm_sessions += [NativeToolClient(NativeOAISession(cfg=cfg))] - elif 'claude' in k: llm_sessions += [ToolClient(ClaudeSession(cfg=cfg))] - elif 'oai' in k: llm_sessions += [ToolClient(LLMSession(cfg=cfg))] - elif 'mixin' in k: llm_sessions += [{'mixin_cfg': cfg}] - except: pass + if 'native' in k and 'claude' in k: + llm_sessions.append(self._tag_client(NativeToolClient(NativeClaudeSession(cfg=cfg)), route_name=cfg.get('name') or k, backend_kind='native_claude')) + elif 'native' in k and 'oai' in k: + llm_sessions.append(self._tag_client(NativeToolClient(NativeOAISession(cfg=cfg)), route_name=cfg.get('name') or k, backend_kind='native_oai')) + elif 'claude' 
in k: + llm_sessions.append(self._tag_client(ToolClient(ClaudeSession(cfg=cfg)), route_name=cfg.get('name') or k, backend_kind='claude_text')) + elif 'oai' in k: + llm_sessions.append(self._tag_client(ToolClient(LLMSession(cfg=cfg)), route_name=cfg.get('name') or k, backend_kind='oai_text')) + elif 'mixin' in k: + llm_sessions.append({'mixin_cfg': cfg, 'route_name': cfg.get('name') or k}) + except Exception as e: + print(f'[WARN] Failed to init legacy session {k}: {e}') for i, s in enumerate(llm_sessions): if isinstance(s, dict) and 'mixin_cfg' in s: try: mixin = MixinSession(llm_sessions, s['mixin_cfg']) - if isinstance(mixin._sessions[0], (NativeClaudeSession, NativeOAISession)): llm_sessions[i] = NativeToolClient(mixin) - else: llm_sessions[i] = ToolClient(mixin) - except Exception as e: print(f'[WARN] Failed to init MixinSession with cfg {s["mixin_cfg"]}: {e}') - self.llmclients = llm_sessions - self.lock = threading.Lock() - self.task_dir = None - self.history = [] - self.task_queue = queue.Queue() - self.is_running = False; self.stop_sig = False - self.llm_no = 0; self.inc_out = False - self.handler = None; self.verbose = True + client = NativeToolClient(mixin) if isinstance(mixin._sessions[0], (NativeClaudeSession, NativeOAISession)) else ToolClient(mixin) + llm_sessions[i] = self._tag_client(client, route_name=s['route_name'], route_kind='failover', backend_kind='mixin') + except Exception as e: + print(f'[WARN] Failed to init MixinSession with cfg {s["mixin_cfg"]}: {e}') + llm_sessions = [s for s in llm_sessions if not isinstance(s, dict)] + return llm_sessions, {'source': 'legacy', 'active_index': min(self.llm_no, max(len(llm_sessions) - 1, 0)), 'routes': []} + + def _build_client_set(self): + if self.ga_switch.use_structured_config() and self.ga_switch.has_usable_routes(): + try: + clients, meta = self.build_llmclients_from_store() + if clients: + meta = dict(meta or {}, source='store') + return clients, meta + except Exception as e: + print(f'[WARN] 
Structured config load failed, fallback to legacy: {e}') + return self.build_llmclients_from_legacy_mykey() + + def _reload_clients(self, *, initial=False, preserve_history=True): + old_client = self.llmclient + old_history = getattr(old_client.backend, 'history', None) if old_client and preserve_history else None + old_route_id = getattr(old_client, 'ga_switch_route_id', None) if old_client else None + old_idx = self.llm_no + clients, meta = self._build_client_set() + self.llmclients = clients + self.config_source = meta.get('source', 'legacy') + self.config_meta = meta + if not self.llmclients: + self.llm_no = 0 + self.llmclient = None + return [] + target_idx = meta.get('active_index', 0) + if not initial and preserve_history: + if self.config_source == 'store' and old_route_id is not None: + matched_idx = next((i for i, client in enumerate(self.llmclients) if getattr(client, 'ga_switch_route_id', None) == old_route_id), None) + if matched_idx is not None: + target_idx = matched_idx + elif old_idx < len(self.llmclients): + target_idx = old_idx + self.llm_no = target_idx % len(self.llmclients) self.llmclient = self.llmclients[self.llm_no] + if preserve_history and old_history is not None: + self.llmclient.backend.history = old_history + if self.config_source == 'store' and getattr(self.llmclient, 'ga_switch_route_id', None) is not None: + self.ga_switch.set_active_route(self.llmclient.ga_switch_route_id) + self._sync_tool_schema() + return self.llmclients def next_llm(self, n=-1): + if not self.llmclients: + self.llmclient = None + return None self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclients) lastc = self.llmclient self.llmclient = self.llmclients[self.llm_no] - self.llmclient.backend.history = lastc.backend.history - self.llmclient.last_tools = '' - name = self.get_llm_name().lower() - if 'glm' in name or 'minimax' in name or 'kimi' in name: load_tool_schema('_cn') - else: load_tool_schema() - def list_llms(self): return [(i, 
self.get_llm_name(b), i == self.llm_no) for i, b in enumerate(self.llmclients)] - def get_llm_name(self, b=None): - b = self.llmclient if b is None else b - return f"{type(b.backend).__name__}/{b.backend.name}" if not isinstance(b, dict) else "BADCONFIG_MIXIN" + if lastc is not None: + self.llmclient.backend.history = lastc.backend.history + if hasattr(self.llmclient, 'last_tools'): + self.llmclient.last_tools = '' + if self.config_source == 'store' and getattr(self.llmclient, 'ga_switch_route_id', None) is not None: + self.ga_switch.set_active_route(self.llmclient.ga_switch_route_id) + self._sync_tool_schema() + return self.llmclient + + def set_active_route(self, route_id_or_idx): + if self.config_source == 'store': + target_idx = next((i for i, client in enumerate(self.llmclients) if getattr(client, 'ga_switch_route_id', None) == route_id_or_idx), None) + if target_idx is None and isinstance(route_id_or_idx, int) and 0 <= route_id_or_idx < len(self.llmclients): + target_idx = route_id_or_idx + if target_idx is None: + raise ValueError(f'Unknown route id: {route_id_or_idx}') + self.next_llm(target_idx) + return self.describe_llms()[self.llm_no] + if not isinstance(route_id_or_idx, int): + raise ValueError(f'Legacy mode only supports index switching, got {route_id_or_idx!r}') + self.next_llm(route_id_or_idx) + return self.describe_llms()[self.llm_no] + + def reload_llm_config(self, preserve_history=True): + if self.is_running: + raise RuntimeError('Cannot reload LLM config while agent is running.') + self._reload_clients(initial=False, preserve_history=preserve_history) + return self.describe_llms() + + def describe_llms(self): + result = [] + for idx, client in enumerate(self.llmclients): + backend = client.backend + diag = backend.describe_diagnostics() if hasattr(backend, 'describe_diagnostics') else {} + members = getattr(client, 'ga_switch_members', []) + route_name = getattr(client, 'ga_switch_route_name', getattr(backend, 'name', '')) + backend_class = 
type(backend).__name__ + item = { + 'idx': idx, + 'active': idx == self.llm_no, + 'source': self.config_source, + 'route_id': getattr(client, 'ga_switch_route_id', None), + 'name': route_name, + 'display_name': f"{route_name} [{backend_class}/{backend.name}]", + 'route_kind': getattr(client, 'ga_switch_route_kind', 'single'), + 'backend_class': backend_class, + 'backend_kind': getattr(client, 'ga_switch_backend_kind', getattr(backend, 'backend_kind', None)), + 'provider_id': getattr(backend, 'provider_id', None), + 'provider_name': getattr(backend, 'provider_name', getattr(backend, 'name', None)), + 'model': getattr(backend, 'model', None), + 'api_mode': getattr(backend, 'api_mode', None), + 'native_tools': isinstance(client, NativeToolClient) or 'Native' in backend_class, + 'member_names': [m.get('name', '') if isinstance(m, dict) else str(m) for m in members], + 'active_member_name': getattr(backend, 'active_member_name', getattr(backend, 'name', None)), + 'last_switch_reason': getattr(backend, 'last_switch_reason', ''), + 'spring_back_seconds': getattr(backend, '_spring_sec', None), + } + item.update(diag) + result.append(item) + return result + + def list_llms(self): + return [(item['idx'], item['display_name'], item['active']) for item in self.describe_llms()] + + def get_llm_name(self): + if self.llmclient is None: + return 'No LLM' + item = self.describe_llms()[self.llm_no] + return item['display_name'] def abort(self): if not self.is_running: return diff --git a/api_server.py b/api_server.py new file mode 100644 index 0000000..bb65f0b --- /dev/null +++ b/api_server.py @@ -0,0 +1,148 @@ +""" +GA Switch API Server +Minimal FastAPI server exposing GA backend functionality via REST API +""" +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from typing import Optional +from contextlib import asynccontextmanager +import sys +import os + +# Add GA path +sys.path.insert(0, 
os.path.dirname(os.path.abspath(__file__))) + +from agentmain import GeneraticAgent +from ga_switch import get_service + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Initialize on startup + app.state.agent = GeneraticAgent() + app.state.service = get_service() + yield + # Cleanup on shutdown (if needed) + +app = FastAPI(title="GA Switch API", version="1.0.0", lifespan=lifespan) + +# CORS - restrict to local frontend only +app.add_middleware( + CORSMiddleware, + allow_origins=[ + "http://localhost:*", + "http://127.0.0.1:*", + "tauri://localhost", + ], + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["*"], + allow_credentials=True, +) + +# Models +class RoutePayload(BaseModel): + id: Optional[str] = None + name: str + kind: str + provider_id: Optional[str] = None + member_provider_ids: list[str] = [] + is_default: bool = False + is_enabled: bool = True + config: dict = {"max_retries": 3, "base_delay": 1.5, "spring_back": 300} + +class ProviderPayload(BaseModel): + id: Optional[str] = None + name: str + backend_kind: str + apikey: str + apibase: str + model: str + api_mode: str = "chat_completions" + temperature: float = 1.0 + max_tokens: int = 8192 + timeout: int = 5 + read_timeout: int = 30 + proxy: Optional[str] = None + extra: dict = {} + +# Endpoints +@app.get("/api/health") +def health(): + return {"status": "healthy", "version": "1.0.0"} + +@app.get("/api/snapshot") +def get_snapshot(request: Request): + return request.app.state.service.get_ui_snapshot(request.app.state.agent) + +@app.get("/api/routes") +def list_routes(request: Request): + snapshot = request.app.state.service.get_ui_snapshot(request.app.state.agent) + return snapshot.get("routes", []) + +@app.post("/api/routes") +def create_route(payload: RoutePayload, request: Request): + request.app.state.service.upsert_route(payload.model_dump()) + return {"success": True} + +@app.put("/api/routes/{route_id}") +def update_route(route_id: str, payload: RoutePayload, 
request: Request): + data = payload.model_dump() + data["id"] = route_id + request.app.state.service.upsert_route(data) + return {"success": True} + +@app.delete("/api/routes/{route_id}") +def delete_route(route_id: str, request: Request): + request.app.state.service.delete_route(route_id) + return {"success": True} + +@app.post("/api/routes/{route_id}/activate") +def activate_route(route_id: str, request: Request): + request.app.state.agent.set_active_route(route_id) + return {"success": True} + +@app.get("/api/providers") +def list_providers(request: Request): + snapshot = request.app.state.service.get_ui_snapshot(request.app.state.agent) + return snapshot.get("providers", []) + +@app.post("/api/providers") +def create_provider(payload: ProviderPayload, request: Request): + request.app.state.service.upsert_provider(payload.model_dump()) + return {"success": True} + +@app.put("/api/providers/{provider_id}") +def update_provider(provider_id: str, payload: ProviderPayload, request: Request): + data = payload.model_dump() + data["id"] = provider_id + request.app.state.service.upsert_provider(data) + return {"success": True} + +@app.delete("/api/providers/{provider_id}") +def delete_provider(provider_id: str, request: Request): + request.app.state.service.delete_provider(provider_id) + return {"success": True} + +@app.post("/api/providers/{provider_id}/test") +def test_provider(provider_id: str, request: Request): + result = request.app.state.service.run_model_test(provider_id) + return result + +@app.get("/api/diagnostics") +def get_diagnostics(request: Request): + snapshot = request.app.state.service.get_ui_snapshot(request.app.state.agent) + return snapshot.get("events", []) + +@app.post("/api/reload") +def reload_config(preserve_history: bool = True, request: Request = None): + request.app.state.agent.reload_llm_config(preserve_history=preserve_history) + return {"success": True} + +@app.post("/api/import-legacy") +def import_legacy(path: Optional[str] = None, 
request: Request = None): + request.app.state.service.import_legacy_mykey(path) + return {"success": True} + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="127.0.0.1", port=8765) diff --git a/docs/CC_SWITCH_UI_AUDIT.md b/docs/CC_SWITCH_UI_AUDIT.md new file mode 100644 index 0000000..36044e2 --- /dev/null +++ b/docs/CC_SWITCH_UI_AUDIT.md @@ -0,0 +1,38 @@ +# CC Switch UI Audit For GA Switch + +本地参考基线: + +- Clone 路径:`D:\DEV\harness\cc-switch` +- 主截图:`cc-switch/assets/screenshots/main-en.png` +- 添加 Provider 截图:`cc-switch/assets/screenshots/add-en.png` +- 关键源码: + - `cc-switch/src/App.tsx` + - `cc-switch/src/components/providers/ProviderList.tsx` + - `cc-switch/src/components/providers/ProviderCard.tsx` + - `cc-switch/src/components/providers/ProviderHealthBadge.tsx` + +## Keep + +- 顶部主状态区:品牌、当前主视图、明显的控制入口,信息层级很清楚。 +- 大圆角 provider 卡片:名称、URL、当前使用状态、余额/配额等摘要信息都在一层完成扫描。 +- 状态 chip:`current / health / failover priority` 这类信息直接贴在卡片上,不藏在详情页。 +- “列表为主,编辑单独展开”的节奏是对的,能降低误操作和视线跳跃。 + +## Change + +- 模块数量太多,不适合 GA 第一环照搬。GA 第一环只保留 `Routes / Providers / Diagnostics / Tests / Runtime`。 +- `cc-switch` 的壳偏“多产品总控台”,而 GA 的主价值还是“路由 + 聊天联动”,所以不能让聊天入口退到次要位置。 +- 它的 provider 体系过宽,包含多产品、多协议、多 OAuth 流程;GA 第一环只应该围绕现有 `provider / route / failover / diagnostics` 收口。 + +## GA-Specific Additions + +- 必须显式显示当前 active route,而不只是 active provider。 +- failover 路由必须显示成员顺序、当前 active member、最近切换原因。 +- 最近一次错误要在聊天页和管理页都可见,而不是只在诊断列表里。 +- 主聊天页需要显式切路由和跳转到管理页,避免“配置页”和“使用页”割裂。 + +## Implementation Direction + +- 管理台参考 `cc-switch` 的信息架构和卡片节奏,但不照搬它的全量模块。 +- 第一环以 Streamlit 为宿主,先保证可运行、可查验、可联动。 +- 后续如果要独立壳,再评估把当前三段式工作台迁移到更桌面化的宿主。 diff --git a/docs/GA_SWITCH_FRONTEND_HANDOFF.md b/docs/GA_SWITCH_FRONTEND_HANDOFF.md new file mode 100644 index 0000000..9e0e82d --- /dev/null +++ b/docs/GA_SWITCH_FRONTEND_HANDOFF.md @@ -0,0 +1,504 @@ +# GA Switch Frontend Handoff + +## 1. 文档目的 + +这份文档面向下一位接手 `GA Switch` 前端工作的工程师,目标是把当前工程状态、可运行环境、实现边界、近期收敛方向和后续独立视图方向一次性交代清楚,避免重复踩坑。 + +当前工作将分成两条线并行推进: + +1. 
`GA 仓库内 PR 线` + 继续优化现有 `Qt` 前端,把 `GA Switch` 收口成一个可合入 GenericAgent 主仓库的正式工作台。 +2. `独立仓库探索线` + 后续单独起一个仓库,做更完整、风格更成熟的独立视图,追求更好的视觉品质和交互体验。 + +这两条线共享同一套后端语义,不应该各自发明一套路由系统。 + +--- + +## 2. 项目上下文 + +### 2.1 当前目标 + +`GA Switch` 的核心目标不是“再加一个模型切换按钮”,而是把 GenericAgent 现有的多路由能力产品化: + +- 显式路由 +- 备用链路 / failover +- 健康状态 +- 最近错误 +- 手动测试 +- 软重载 +- 与聊天页联动 + +当前已经证明: + +- 多路由本身可用 +- `failover` 底层可用 +- `mykey.py` 可导入到结构化配置 +- `Qt` 前端已经能吃结构化快照并驱动切换/测试/诊断 + +因此下一阶段的主要问题已经不是“能不能做”,而是“如何把现有能力做成产品级前端”。 + +### 2.2 当前前端策略 + +当前主策略是: + +- `Qt 优先` +- `Streamlit 回退` +- 不在 GA 主仓库里引入 `Tauri / Rust / Electron` +- 不把本地 `pixi` 变成仓库默认前提 + +也就是说: + +- GenericAgent 内部的正式桌面前端方向是 `PySide6 / Qt` +- `Streamlit` 只保留兼容和回退价值 +- 更激进、更独立的技术选型留给未来单独仓库 + +--- + +## 3. 当前工程状态 + +### 3.1 关键模块 + +#### 后端与运行时 + +- [agentmain.py]() + - `GeneraticAgent` 持有 `ga_switch` + - 支持 `set_active_route(...)` + - 支持 `reload_llm_config(...)` + - 支持 `describe_llms()` +- [ga_switch/service.py]() + - 对外服务层 + - 提供 `get_ui_snapshot(agent)` + - 提供 `upsert_provider / upsert_route / run_model_test / import_legacy_mykey` +- [ga_switch/viewmodel.py]() + - 纯 Python 视图模型层 + - 将 `get_ui_snapshot()` 转成前端可直接消费的结构 + - 当前已经输出中文化的 section、summary、overview、empty_state、edit_groups +- [llmcore.py]() + - 底层 session 和诊断 + - 已记录 `last_error_kind / last_error_message / last_ok_at / last_status_code` + +#### Qt 前端 + +- [frontends/qtapp.py]() + - GenericAgent 原生桌面壳 + - 聊天页、历史页、手册页、路由页、设置页 + - 聊天页顶部已经有路由状态条 + - `launch.pyw` 成功进入 Qt 时会落到这里 +- [frontends/qt_switch.py]() + - `RouteCenterPage` + - 当前已经完成第二轮信息架构收口: + - `总览` + - `全部路由` + - `模型服务` + - `诊断记录` + - 已做全中文和渐进式披露的第一版 + +#### Streamlit 保留前端 + +- [frontends/stapp.py]() + - 旧主 UI,保留 fallback 价值 +- [frontends/ga_switch_admin.py]() + - 工程管理台 + - 不再是正式体验主方向 + +### 3.2 启动逻辑 + +- [launch.pyw]() + - 优先探测 `PySide6` + - 能用则进入 `Qt` + - 否则回退到 `Streamlit + pywebview` + +### 3.3 当前已有参考资料 + +- [docs/CC_SWITCH_UI_AUDIT.md]() + - 已对 `cc-switch` 做过一轮 UI 审核 + - 本地 clone 路径:`D:\DEV\harness\cc-switch` + +--- + +## 4. 
本机环境说明 + +### 4.1 当前机器 + +- OS:Windows +- Shell:PowerShell 5.1 +- 工作区:`D:\DEV\harness` +- GenericAgent 仓库:`D:\DEV\harness\generic-agent` + +### 4.2 Python / 运行时现状 + +当前这台机器有两套现实: + +1. 仓库常规 Python 路径 + - `python` 解析到:`C:\Users\苏祎成\AppData\Local\Microsoft\WindowsApps\python.exe` +2. 本地可用的 `pixi` + - `where.exe pixi` -> `D:\EnvironmentCache\pixi_home\bin\pixi.exe` + - `pixi --version` -> `pixi 0.66.0` + +### 4.3 Qt 的本机特殊问题 + +这台机器上,GenericAgent 现有 `.venv` 里的 `PySide6` 不是可靠状态。 + +实际开发时,`Qt` 前端目前是通过下面这条命令稳定启动的: + +```powershell +pixi exec --spec "python=3.11" --spec pyside6 --spec requests python frontends\qtapp.py +``` + +这只是**本机开发 workaround**,不是仓库交付约束。 + +### 4.4 这条 workaround 的工程含义 + +必须明确: + +- 可以在本机继续用 `pixi` 调试 Qt +- 不能把 `pixi.toml`、锁文件、repo 级强依赖写进 GenericAgent PR +- 不能让“仓库默认可运行”依赖 `pixi` + +换句话说: + +- `pixi` 是本机开发环境 +- 不是 GenericAgent 主仓库的产品依赖 + +--- + +## 5. 当前工作约束 + +### 5.1 PR 线约束 + +下面这些约束已经明确,不建议在 GA 主仓库里打破: + +1. 不引入 `Tauri / Rust / Electron` +2. 不把 `pixi` 变成 repo 基础配置 +3. `PySide6` 仍然按“可选桌面依赖”处理 +4. `launch.pyw` 继续保持 `Qt 优先,Streamlit 回退` +5. `Streamlit` 前端必须还能作为 fallback 使用 +6. 不轻易改 `ga_switch` 的数据库 schema +7. UI 必须继续复用现有结构化接口,而不是绕过它直接读 session 细节 + +### 5.2 前端实现约束 + +Qt 前端继续工作时,请遵守下面的技术边界: + +1. 首选使用 [ga_switch/service.py]() 的 `get_ui_snapshot(agent)` +2. 前端状态整理由 [ga_switch/viewmodel.py]() 承担 +3. 具体动作继续走: + - `agent.set_active_route(...)` + - `agent.reload_llm_config(...)` + - `service.import_legacy_mykey(...)` + - `service.run_model_test(...)` + - `service.upsert_provider(...)` + - `service.upsert_route(...)` +4. 不要在 Qt 页里直接拼装 `llmcore` 内部细节 +5. 不要新增一套“只给 Qt 用、后端完全不认识”的路由语义 + +### 5.3 体验约束 + +这一轮前端已经锁定了几个产品方向,后续不应回退: + +1. 全中文用户界面 +2. 路由总览优先 +3. 渐进式披露 +4. 高级设置默认折叠 +5. 诊断原始 JSON 默认折叠 +6. 首屏不再出现工程态三栏工作台和常驻长表单 + +--- + +## 6. 
当前 UI 收敛方向 + +### 6.1 已经完成的方向 + +目前 `Qt Route Center` 已经从工程台形态转成第一版产品态: + +- 一级导航改为: + - `总览` + - `全部路由` + - `模型服务` + - `诊断记录` +- 路由页默认先看摘要,不直接摊开表单 +- 模型服务页默认先看摘要和最近状态 +- 空状态有明确动作入口 +- 聊天页与路由页已经联动 + +### 6.2 还没有完成的部分 + +这版还只能算“结构收拢到位,视觉品质未到位”。 + +当前剩余问题主要在: + +1. 卡片密度和间距还偏粗糙 +2. 按钮层级还不够克制 +3. 深色主题的精致度不够 +4. 列表项的状态表达还不够漂亮 +5. 还没有做真正统一的视觉语言 +6. 宿主壳的老风格与新路由页之间还有轻微割裂 + +### 6.3 PR 线建议收敛目标 + +如果是继续在 GenericAgent 内优化 UI 并开 PR,建议只收敛到下面这个层级: + +1. 让 Qt 版 `路由` 页成为可长期使用的正式工作台 +2. 继续提升视觉精度,但不重做技术栈 +3. 保持与聊天页的高耦合联动 +4. 不追求“独立产品壳”,而追求“GenericAgent 内自然的一部分” + +这条线更适合做: + +- 视觉 polish +- 交互节奏优化 +- 卡片层级 +- 状态 chips +- 列表/详情的动线优化 +- 空状态和错误状态设计 + +不适合做: + +- 全新桌面容器 +- 新的跨进程通信层 +- 前后端分离重构 + +--- + +## 7. 独立仓库方向建议 + +### 7.1 为什么值得单独做 + +GenericAgent 内的 Qt 版更适合“贴近宿主能力的正式工作台”。 + +如果目标变成: + +- 更激进的视觉语言 +- 更完整的独立产品感 +- 更自由的 UI 架构 +- 更接近 `cc-switch` 乃至超过它的桌面体验 + +那单独起仓库是合理的。 + +### 7.2 独立仓库不应该重复造什么 + +独立仓库可以换壳,但不应该再重造下面这些东西: + +1. 路由语义 +2. provider / route / diagnostic event 数据结构 +3. active route / soft reload / model test 的语义 +4. `legacy mykey` 导入语义 + +更好的方式是: + +- 继续复用 `ga_switch` 的领域模型 +- 明确一层稳定的服务契约 +- 独立仓库只重做“视图”和“交互壳” + +### 7.3 独立仓库建议前提 + +在真正独立之前,建议先把下面两件事补齐: + +1. 把 `get_ui_snapshot()` 和视图模型字段进一步稳定下来 +2. 明确一份“前端契约文档”,避免新仓库反向绑死主仓库内部实现 + +换句话说,独立仓库应建立在“后端契约已稳定”的前提上,而不是直接复制现在的 Qt 页面。 + +--- + +## 8. 关键文件导览 + +### 8.1 直接相关文件 + +- [launch.pyw]() + - 入口与 Qt/Streamlit fallback +- [frontends/qtapp.py]() + - 桌面壳 + - 聊天页与路由页联动 +- [frontends/qt_switch.py]() + - Route Center 主体 +- [ga_switch/viewmodel.py]() + - 前端视图模型 +- [ga_switch/service.py]() + - 后端服务入口 +- [agentmain.py]() + - agent 与 ga_switch 的 runtime 绑定 +- [tests/test_ga_switch.py]() + - 结构化配置、视图模型、payload builder、reload 相关测试 + +### 8.2 参考文件 + +- [docs/CC_SWITCH_UI_AUDIT.md]() +- [README.md]() +- [GETTING_STARTED.md]() + +--- + +## 9. 
当前已知问题与注意事项 + +### 9.1 本机环境问题 + +这台机器当前最实际的问题是: + +- GenericAgent 现有 `.venv` 中的 `PySide6` 不可靠 +- 所以本机 Qt 调试走的是 `pixi` + +这不是仓库设计目标,而是本机现实。 + +### 9.2 工作树状态 + +当前仓库是脏工作树,已有较多未提交改动,包括: + +- `ga_switch/` +- `frontends/qt_switch.py` +- `frontends/qtapp.py` +- `launch.pyw` +- `agentmain.py` +- `llmcore.py` +- 文档与测试文件 + +接手时不要假定工作树是干净的,也不要粗暴回滚。 + +### 9.3 Qt 页面还不是最终视觉稿 + +当前 Qt 路由页只能算: + +- 信息架构已进入正确方向 +- 可用性显著好于第一版 +- 但还远未达到“高完成度美术品质” + +因此: + +- 如果是继续开 PR,请聚焦在“收敛和 polish” +- 如果是单独仓库,请把它当成“产品语义样板”,而不是视觉最终稿 + +--- + +## 10. 建议工作流 + +### 10.1 GenericAgent 内继续优化并开 PR + +建议顺序: + +1. 先稳定 `qt_switch.py` 的视觉和交互 +2. 再统一 `qtapp.py` 宿主层的风格和用词 +3. 尽量把前端展示逻辑继续下沉到 `ga_switch/viewmodel.py` +4. 每次改完都回归: + - `tests.test_ga_switch` + - `tests.test_minimax` +5. 保持 `launch.pyw` 的 fallback 逻辑不破 + +### 10.2 单独仓库方向 + +建议顺序: + +1. 先冻结一版服务契约 +2. 明确“宿主模式”和“独立模式”的边界 +3. 再决定新仓库 UI 技术栈 +4. 优先复用现有 `ga_switch` 语义,而不是复制 Qt 页面代码 + +--- + +## 11. 建议验收清单 + +### 11.1 PR 线验收 + +至少保证: + +1. `launch.pyw` 在安装了 `PySide6` 时优先进 Qt +2. 缺 `PySide6` 时仍能回退到 Streamlit +3. 聊天页能明确看到当前路由、当前成员、最近错误 +4. 路由页能完成: + - 创建/编辑路由 + - 创建/编辑模型服务 + - 切换当前路由 + - 连通性测试 + - 软重载 + - 导入 `mykey` +5. 编辑时不应因高级设置未展开而抹掉已有高级值 +6. UI 文案不再混用英文工程词 + +### 11.2 独立仓库线验收 + +至少保证: + +1. 独立视图不重新发明路由语义 +2. 能读稳定契约 +3. 聊天联动语义不丢 +4. 最近错误、健康状态、备用链路语义完整保留 + +--- + +## 12. 已验证命令 + +### 12.1 仓库内测试 + +```powershell +python -m unittest tests.test_ga_switch tests.test_minimax +``` + +本轮最近一次回归结果: + +- `31` 个测试通过 + +### 12.2 本机 Qt 启动 + +```powershell +pixi exec --spec "python=3.11" --spec pyside6 --spec requests python frontends\qtapp.py +``` + +### 12.3 仓库标准入口 + +```powershell +python launch.pyw +``` + +--- + +## 13. 
推荐交接结论 + +如果下一位工程师是专业前端工程师,建议这样切分任务: + +### 阶段 A:GenericAgent 内 PR + +目标: + +- 把 `Qt Route Center` 打磨到“可以合入、可以长期使用”的程度 + +关注点: + +- 视觉 polish +- 交互动线 +- 中文化一致性 +- 状态层级 +- 空状态 / 错误状态体验 + +不要做: + +- 新技术栈迁移 +- 新运行时依赖体系 +- 与宿主脱钩 + +### 阶段 B:独立仓库 + +目标: + +- 在不背离现有后端语义的前提下,做一个更自由、更高完成度的独立视图 + +关注点: + +- 更成熟的 UI 语言 +- 更完整的独立桌面感 +- 更高的视觉上限 + +不要做: + +- 重造路由内核 +- 重造测试 / 诊断 / 软重载语义 + +--- + +## 14. 一句话总结 + +当前 `GA Switch` 已经完成了从“底层能力存在”到“结构化产品工作台”的第一步。 + +下一位前端工程师的真正任务,不是从零发明一个模型切换器,而是: + +- 在 GenericAgent 内,把现有 Qt 工作台收敛成一个成熟、统一、可合入的正式前端; +- 在未来独立仓库里,把同一套语义做成更高上限的独立视图。 diff --git a/frontends/design_tokens.py b/frontends/design_tokens.py new file mode 100644 index 0000000..1cb23e5 --- /dev/null +++ b/frontends/design_tokens.py @@ -0,0 +1,98 @@ +""" +设计 Token 系统 +统一的字体、颜色、间距、圆角定义 +""" +from PySide6.QtGui import QColor + +# ============================================================================ +# 字体系统 +# ============================================================================ + +FONTS = { + "ui": "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Microsoft YaHei', sans-serif", + "code": "Consolas, 'Courier New', monospace", +} + +FONT_SIZES = { + "xs": 11, + "sm": 12, + "md": 13, + "lg": 14, + "xl": 16, + "2xl": 20, +} + +FONT_WEIGHTS = { + "normal": 400, + "medium": 500, + "semibold": 600, + "bold": 700, +} + +# ============================================================================ +# 颜色系统(蓝色主题) +# ============================================================================ + +COLORS = { + # 背景 + "bg_base": QColor(14, 14, 18), + "bg_elevated": QColor(20, 20, 24), + "bg_overlay": QColor(28, 28, 32), + + # 边框 + "border_subtle": QColor(39, 39, 42), + "border_default": QColor(63, 63, 70), + "border_strong": QColor(82, 82, 91), + + # 文字 + "text_primary": "#f4f4f5", + "text_secondary": "#a1a1aa", + "text_tertiary": "#71717a", + + # 品牌色(蓝色) + "brand_50": "#eff6ff", + "brand_100": "#dbeafe", + "brand_500": "#3b82f6", + "brand_600": "#2563eb", + "brand_700": 
"#1d4ed8", + + # 语义色 + "success": "#22c55e", + "warning": "#f59e0b", + "error": "#ef4444", + "info": "#3b82f6", +} + +# ============================================================================ +# 间距系统(4pt 网格) +# ============================================================================ + +SPACING = { + "xs": 4, + "sm": 8, + "md": 12, + "lg": 16, + "xl": 24, + "2xl": 32, + "3xl": 48, +} + +# ============================================================================ +# 圆角系统 +# ============================================================================ + +RADIUS = { + "sm": 4, + "md": 6, + "lg": 8, +} + +# ============================================================================ +# 阴影系统 +# ============================================================================ + +SHADOWS = { + "sm": "0 1px 2px 0 rgba(0, 0, 0, 0.05)", + "md": "0 4px 6px -1px rgba(0, 0, 0, 0.1)", + "lg": "0 10px 15px -3px rgba(0, 0, 0, 0.1)", +} diff --git a/frontends/ga_switch_admin.py b/frontends/ga_switch_admin.py new file mode 100644 index 0000000..214de67 --- /dev/null +++ b/frontends/ga_switch_admin.py @@ -0,0 +1,14 @@ +import os +import sys + +script_dir = os.path.dirname(__file__) +if script_dir not in sys.path: + sys.path.append(script_dir) +repo_dir = os.path.abspath(os.path.join(script_dir, "..")) +if repo_dir not in sys.path: + sys.path.append(repo_dir) + +from ga_switch_ui import render_admin_page, setup_switch_page + +setup_switch_page("GA Switch Admin") +render_admin_page() diff --git a/frontends/ga_switch_ui.py b/frontends/ga_switch_ui.py new file mode 100644 index 0000000..6375a1c --- /dev/null +++ b/frontends/ga_switch_ui.py @@ -0,0 +1,685 @@ +import json +import os +import sys + +import streamlit as st + +script_dir = os.path.dirname(__file__) +repo_dir = os.path.abspath(os.path.join(script_dir, "..")) +if repo_dir not in sys.path: + sys.path.append(repo_dir) + +from ga_switch.models import PROVIDER_BACKEND_KINDS, ROUTE_KINDS +from ga_switch.viewmodel import 
build_ui_viewmodel +from shared_runtime import get_shared_runtime + +NAV_ITEMS = [ + ("routes", "Routes", "R"), + ("providers", "Providers", "P"), + ("diagnostics", "Diagnostics", "D"), + ("tests", "Tests", "T"), + ("runtime", "Runtime", "RT"), +] + + +def setup_switch_page(page_title="GA Switch Admin"): + st.set_page_config(page_title=page_title, layout="wide", initial_sidebar_state="collapsed") + inject_switch_css() + + +def inject_switch_css(): + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + +def _chip(text, tone=""): + klass = f"ga-chip {tone}".strip() + return f'{text}' + + +def _route_caption(route, runtime): + if route["kind"] == "single": + primary = (route["provider"] or {}).get("name") or "No provider" + else: + primary = " -> ".join(member["name"] for member in route["members"]) or "No members" + runtime_bits = [] + if runtime: + runtime_bits.append(runtime.get("backend_class") or "") + runtime_bits.append(runtime.get("model") or "") + runtime_text = " | ".join(bit for bit in runtime_bits if bit) + return primary if not runtime_text else f"{primary} | {runtime_text}" + + +def _provider_caption(provider): + bits = [ + provider.get("backend_kind"), + provider.get("model"), + provider.get("api_mode"), + provider.get("health", {}).get("status"), + ] + return " | ".join(bit for bit in bits if bit) + + +def _ensure_state(snapshot): + if "ga_admin_section" not in st.session_state: + st.session_state.ga_admin_section = "routes" + if snapshot["routes"] and st.session_state.get("ga_admin_selected_route") not in snapshot["routes_by_id"]: + st.session_state.ga_admin_selected_route = snapshot["routes"][0]["id"] + if snapshot["providers"] and st.session_state.get("ga_admin_selected_provider") not in snapshot["providers_by_id"]: + st.session_state.ga_admin_selected_provider = snapshot["providers"][0]["id"] + if snapshot["events"] and st.session_state.get("ga_admin_selected_event") not in {event["id"] for event in snapshot["events"]}: + 
st.session_state.ga_admin_selected_event = snapshot["events"][0]["id"] + + +def _nav_button(section_key, label, icon): + current = st.session_state.ga_admin_section == section_key + if st.button(f"{icon} {label}", key=f"ga_nav_{section_key}", use_container_width=True, type="primary" if current else "secondary"): + st.session_state.ga_admin_section = section_key + st.rerun() + + +def _render_left_nav(snapshot): + with st.container(border=True): + st.markdown("##### Workspace") + for section_key, label, icon in NAV_ITEMS: + _nav_button(section_key, label, icon) + st.markdown("---") + active = snapshot["active_route_summary"] + st.markdown( + f""" +
+

Current Route

+

{active.get("route_name") or "No active route"}

+

{active.get("provider_name") or "No provider"} | {active.get("model") or "No model"}

+ {_chip(active.get("route_kind") or "unknown", "blue")} + {_chip("native tools" if active.get("native_tools") else "text tools", "active" if active.get("native_tools") else "")} +
+ """, + unsafe_allow_html=True, + ) + if hasattr(st, "page_link"): + st.page_link("stapp.py", label="Open Chat", icon="💬") + else: + st.caption("Chat page lives in stapp.py") + + +def _render_routes_list(snapshot): + runtime_by_route = snapshot["runtime_by_route_id"] + for route in snapshot["routes"]: + runtime = runtime_by_route.get(route["id"]) + with st.container(border=True): + cols = st.columns([5, 1.2]) + cols[0].markdown( + f""" +
+

{route['name']}

+

{_route_caption(route, runtime)}

+ {_chip('active', 'active') if route['active'] else ''} + {_chip(route['kind'], 'blue')} + {_chip('default') if route['is_default'] else ''} + {_chip('disabled', 'warn') if not route['is_enabled'] else ''} +
+ """, + unsafe_allow_html=True, + ) + if cols[1].button("Open", key=f"ga_route_open_{route['id']}", use_container_width=True, type="primary" if st.session_state.ga_admin_selected_route == route["id"] else "secondary"): + st.session_state.ga_admin_selected_route = route["id"] + st.session_state.ga_admin_section = "routes" + st.rerun() + + +def _render_providers_list(snapshot): + for provider in snapshot["providers"]: + with st.container(border=True): + cols = st.columns([5, 1.2]) + cols[0].markdown( + f""" +
+

{provider['name']}

+

{_provider_caption(provider)}

+ {_chip('native', 'active') if provider['is_native'] else _chip('text')} + {_chip(provider['health']['status'] or 'unknown', 'active' if provider['health']['status'] == 'healthy' else 'warn' if provider['health']['status'] == 'degraded' else 'error' if provider['health']['status'] == 'failed' else '')} +
+ """, + unsafe_allow_html=True, + ) + if cols[1].button("Open", key=f"ga_provider_open_{provider['id']}", use_container_width=True, type="primary" if st.session_state.ga_admin_selected_provider == provider["id"] else "secondary"): + st.session_state.ga_admin_selected_provider = provider["id"] + st.session_state.ga_admin_section = "providers" + st.rerun() + + +def _render_events_list(snapshot): + for event in snapshot["recent_events"]: + title = event.get("backend_name") or f"Event {event['id']}" + tone = "error" if not event.get("ok") else "active" + subtitle = f"{event['created_at']} | {event.get('error_kind') or 'ok'} | {event.get('message') or ''}" + with st.container(border=True): + cols = st.columns([5, 1.2]) + cols[0].markdown( + f""" +
+

{title}

+

{subtitle[:160]}

+ {_chip('ok', 'active') if event.get('ok') else _chip(event.get('error_kind') or 'error', tone)} +
+ """, + unsafe_allow_html=True, + ) + if cols[1].button("Open", key=f"ga_event_open_{event['id']}", use_container_width=True, type="primary" if st.session_state.ga_admin_selected_event == event["id"] else "secondary"): + st.session_state.ga_admin_selected_event = event["id"] + st.session_state.ga_admin_section = "diagnostics" + st.rerun() + + +def _render_runtime_list(snapshot): + for item in snapshot["runtime"]: + with st.container(border=True): + cols = st.columns([5, 1.2]) + cols[0].markdown( + f""" +
+

{item['name']}

+

{item.get('provider_name') or 'No provider'} | {item.get('model') or 'No model'} | {item.get('backend_class') or ''}

+ {_chip('active', 'active') if item.get('active') else ''} + {_chip(item.get('route_kind') or 'single', 'blue')} + {_chip(item.get('active_member_name') or 'no active member')} +
+ """, + unsafe_allow_html=True, + ) + if cols[1].button("Focus", key=f"ga_runtime_open_{item['idx']}", use_container_width=True): + if item.get("route_id") is not None: + st.session_state.ga_admin_selected_route = item["route_id"] + st.session_state.ga_admin_section = "routes" + st.rerun() + + +def _render_middle_panel(snapshot): + section = st.session_state.ga_admin_section + st.markdown( + f""" +
+

Current Module

+

{section.title()}

+

List on the left, details and actions on the right.

+
+ """, + unsafe_allow_html=True, + ) + if section == "routes": + _render_routes_list(snapshot) + elif section == "providers": + _render_providers_list(snapshot) + elif section == "diagnostics": + _render_events_list(snapshot) + elif section == "tests": + _render_providers_list(snapshot) + else: + _render_runtime_list(snapshot) + + +def _selected_route(snapshot): + return snapshot["routes_by_id"].get(st.session_state.get("ga_admin_selected_route")) + + +def _selected_provider(snapshot): + return snapshot["providers_by_id"].get(st.session_state.get("ga_admin_selected_provider")) + + +def _selected_event(snapshot): + return next((event for event in snapshot["events"] if event["id"] == st.session_state.get("ga_admin_selected_event")), None) + + +def _render_route_detail(service, agent, snapshot): + route = _selected_route(snapshot) + if route is None: + st.info("No route selected.") + return + runtime = snapshot["runtime_by_route_id"].get(route["id"]) + st.markdown( + f""" +
+

Route Detail

+

{route['name']}

+

{_route_caption(route, runtime)}

+ {_chip('active', 'active') if route['active'] else ''} + {_chip(route['kind'], 'blue')} + {_chip(runtime.get('backend_class') if runtime else 'not mounted')} + {_chip(runtime.get('active_member_name') or 'no active member', 'warn' if route['kind'] == 'failover' else '') if runtime else ''} +
+ """, + unsafe_allow_html=True, + ) + action_cols = st.columns([1, 1, 1]) + if action_cols[0].button("Activate route", key=f"ga_activate_route_{route['id']}", use_container_width=True, type="primary"): + try: + agent.set_active_route(route["id"]) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[1].button("Soft reload", key=f"ga_reload_route_{route['id']}", use_container_width=True): + try: + agent.reload_llm_config(preserve_history=True) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[2].button("Delete route", key=f"ga_delete_route_{route['id']}", use_container_width=True): + try: + service.delete_route(route["id"]) + st.rerun() + except Exception as exc: + st.error(str(exc)) + + provider_options = {provider["id"]: provider["name"] for provider in snapshot["providers"]} + with st.form("ga_route_form"): + form_cols = st.columns(4) + route_name = form_cols[0].text_input("Route name", value=route.get("name", "")) + route_kind = form_cols[1].selectbox("Route kind", ROUTE_KINDS, index=ROUTE_KINDS.index(route.get("kind", "single"))) + is_default = form_cols[2].checkbox("Default route", value=bool(route.get("is_default", False))) + is_enabled = form_cols[3].checkbox("Enabled", value=bool(route.get("is_enabled", True))) + if route_kind == "single": + provider_index = next((idx for idx, provider in enumerate(snapshot["providers"]) if provider["id"] == ((route.get("provider") or {}).get("id"))), 0) if snapshot["providers"] else 0 + provider_id = st.selectbox("Provider", options=[provider["id"] for provider in snapshot["providers"]] or [0], index=provider_index, format_func=lambda pid: provider_options.get(pid, "No providers"), disabled=not snapshot["providers"]) + member_provider_ids = [] + else: + provider_id = None + member_provider_ids = st.multiselect("Failover members", options=[provider["id"] for provider in snapshot["providers"]], default=route.get("member_provider_ids", []), format_func=lambda pid: 
provider_options.get(pid, str(pid))) + form_cols = st.columns(3) + mixin_retries = form_cols[0].number_input("Failover max retries", value=int((route.get("config") or {}).get("max_retries", 3)), min_value=0) + mixin_delay = form_cols[1].number_input("Base delay", value=float((route.get("config") or {}).get("base_delay", 1.5)), min_value=0.0) + spring_back = form_cols[2].number_input("Spring back seconds", value=int((route.get("config") or {}).get("spring_back", 300)), min_value=0) + if st.form_submit_button("Save route", use_container_width=True, type="primary"): + try: + service.upsert_route({ + "id": route["id"], + "name": route_name.strip(), + "kind": route_kind, + "provider_id": provider_id, + "member_provider_ids": member_provider_ids, + "is_default": is_default, + "is_enabled": is_enabled, + "config": {"max_retries": int(mixin_retries), "base_delay": float(mixin_delay), "spring_back": int(spring_back)}, + }) + st.rerun() + except Exception as exc: + st.error(str(exc)) + + with st.form("ga_route_create_form"): + st.markdown("##### Create Route") + form_cols = st.columns(3) + create_name = form_cols[0].text_input("New route name", value="") + create_kind = form_cols[1].selectbox("New route kind", ROUTE_KINDS, index=0, key="ga_create_route_kind") + create_default = form_cols[2].checkbox("Set as default", value=False, key="ga_create_route_default") + if create_kind == "single": + create_provider_id = st.selectbox("Provider for new route", options=[provider["id"] for provider in snapshot["providers"]] or [0], format_func=lambda pid: provider_options.get(pid, "No providers"), disabled=not snapshot["providers"], key="ga_create_route_provider") + create_members = [] + else: + create_provider_id = None + create_members = st.multiselect("Members for new failover route", options=[provider["id"] for provider in snapshot["providers"]], format_func=lambda pid: provider_options.get(pid, str(pid)), key="ga_create_route_members") + if st.form_submit_button("Create route", 
use_container_width=True): + if create_name.strip(): + try: + service.upsert_route({ + "name": create_name.strip(), + "kind": create_kind, + "provider_id": create_provider_id, + "member_provider_ids": create_members, + "is_default": create_default, + "is_enabled": True, + "config": {"max_retries": 3, "base_delay": 1.5, "spring_back": 300}, + }) + st.rerun() + except Exception as exc: + st.error(str(exc)) + + +def _render_provider_detail(service, snapshot): + provider = _selected_provider(snapshot) + if provider is None: + st.info("No provider selected.") + return + health = provider["health"] + st.markdown( + f""" +
+

Provider Detail

+

{provider['name']}

+

{_provider_caption(provider)}

+ {_chip(health.get('status') or 'unknown', 'active' if health.get('status') == 'healthy' else 'warn' if health.get('status') == 'degraded' else 'error' if health.get('status') == 'failed' else '')} + {_chip('native', 'active') if provider['is_native'] else _chip('text')} + {_chip(f"latency {health.get('latency_ms') or '-'} ms")} +
+ """, + unsafe_allow_html=True, + ) + action_cols = st.columns([1, 1, 1]) + if action_cols[0].button("Run model test", key=f"ga_test_provider_{provider['id']}", use_container_width=True, type="primary"): + try: + st.session_state.ga_switch_last_test_result = service.run_model_test(provider["id"]) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[1].button("Delete provider", key=f"ga_delete_provider_{provider['id']}", use_container_width=True): + try: + service.delete_provider(provider["id"]) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[2].button("Copy as legacy JSON", key=f"ga_export_provider_{provider['id']}", use_container_width=True): + st.session_state.ga_switch_export_preview = json.dumps(service.export_legacy_config(), ensure_ascii=False, indent=2) + + with st.form("ga_provider_form"): + form_cols = st.columns(3) + name = form_cols[0].text_input("Name", value=provider.get("name", "")) + backend_kind = form_cols[1].selectbox("Backend kind", PROVIDER_BACKEND_KINDS, index=PROVIDER_BACKEND_KINDS.index(provider.get("backend_kind", "oai_text"))) + model = form_cols[2].text_input("Model", value=provider.get("model", "")) + form_cols = st.columns(3) + apikey = form_cols[0].text_input("API key", value=provider.get("apikey", ""), type="password") + apibase = form_cols[1].text_input("API base", value=provider.get("apibase", "")) + api_mode = form_cols[2].selectbox("API mode", ["chat_completions", "responses"], index=["chat_completions", "responses"].index(provider.get("api_mode", "chat_completions"))) + form_cols = st.columns(4) + temperature = form_cols[0].number_input("Temperature", value=float(provider.get("temperature", 1.0))) + max_tokens = form_cols[1].number_input("Max tokens", value=int(provider.get("max_tokens", 8192)), min_value=1) + timeout = form_cols[2].number_input("Connect timeout", value=int(provider.get("timeout", 5)), min_value=1) + read_timeout = form_cols[3].number_input("Read timeout", 
value=int(provider.get("read_timeout", 30)), min_value=1) + proxy = st.text_input("Proxy", value=provider.get("proxy", "") or "") + extra_json = st.text_area("Extra JSON", value=json.dumps(provider.get("extra", {}), ensure_ascii=False, indent=2), height=120) + if st.form_submit_button("Save provider", use_container_width=True, type="primary"): + try: + service.upsert_provider({ + "id": provider["id"], + "name": name.strip(), + "backend_kind": backend_kind, + "apikey": apikey.strip(), + "apibase": apibase.strip(), + "model": model.strip(), + "api_mode": api_mode, + "temperature": float(temperature), + "max_tokens": int(max_tokens), + "timeout": int(timeout), + "read_timeout": int(read_timeout), + "proxy": proxy.strip() or None, + "extra": json.loads(extra_json or "{}"), + }) + st.rerun() + except Exception as exc: + st.error(str(exc)) + + with st.form("ga_provider_create_form"): + st.markdown("##### Create Provider") + form_cols = st.columns(3) + new_name = form_cols[0].text_input("New provider name", value="") + new_kind = form_cols[1].selectbox("New backend kind", PROVIDER_BACKEND_KINDS, index=PROVIDER_BACKEND_KINDS.index("oai_text"), key="ga_new_provider_kind") + new_model = form_cols[2].text_input("New model", value="", key="ga_new_provider_model") + form_cols = st.columns(2) + new_key = form_cols[0].text_input("API key", value="", type="password", key="ga_new_provider_key") + new_base = form_cols[1].text_input("API base", value="", key="ga_new_provider_base") + if st.form_submit_button("Create provider", use_container_width=True): + if new_name.strip() and new_key.strip() and new_base.strip(): + try: + service.upsert_provider({ + "name": new_name.strip(), + "backend_kind": new_kind, + "apikey": new_key.strip(), + "apibase": new_base.strip(), + "model": new_model.strip(), + }) + st.rerun() + except Exception as exc: + st.error(str(exc)) + + if st.session_state.get("ga_switch_last_test_result"): + st.markdown("##### Last Test Result") + 
st.json(st.session_state["ga_switch_last_test_result"]) + if st.session_state.get("ga_switch_export_preview"): + st.markdown("##### Legacy Export Preview") + st.code(st.session_state["ga_switch_export_preview"], language="json") + + +def _render_diagnostics_detail(snapshot): + event = _selected_event(snapshot) + if event is None: + st.info("No diagnostic event selected.") + return + st.markdown( + f""" +
+

Diagnostic Event

+

{event.get('backend_name') or ('Event ' + str(event['id']))}

+

{event.get('message') or 'No message'}

+ {_chip('ok', 'active') if event.get('ok') else _chip(event.get('error_kind') or 'error', 'error')} + {_chip(str(event.get('status_code') or '-'))} + {_chip(event.get('created_at') or '')} +
+ """, + unsafe_allow_html=True, + ) + st.json(event) + + +def _render_runtime_detail(service, agent, snapshot): + active = snapshot["active_route_summary"] + st.markdown( + f""" +
+

Runtime

+

{active.get('route_name') or 'No active route'}

+

{active.get('provider_name') or 'No provider'} | {active.get('model') or 'No model'} | {active.get('backend_class') or 'No backend'}

+ {_chip(active.get('route_kind') or 'unknown', 'blue')} + {_chip(active.get('active_member_name') or 'no active member')} + {_chip(active.get('last_error_kind') or 'healthy', 'error' if active.get('last_error_kind') else 'active')} +
+ """, + unsafe_allow_html=True, + ) + control_cols = st.columns([1, 1, 1, 1]) + structured_value = control_cols[0].checkbox("Use structured config", value=snapshot["use_structured_config"]) + import_path = control_cols[1].text_input("Import mykey path", value="") + preserve_history = control_cols[2].checkbox("Preserve history on reload", value=True) + activate_options = [route["id"] for route in snapshot["routes"]] or [0] + activate_route_id = control_cols[3].selectbox("Activate route", options=activate_options, format_func=lambda rid: snapshot["routes_by_id"].get(rid, {}).get("name", "No routes"), disabled=not snapshot["routes"]) + + action_cols = st.columns([1, 1, 1, 1]) + if action_cols[0].button("Apply mode", use_container_width=True): + try: + service.set_structured_config_enabled(structured_value) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[1].button("Import legacy mykey", use_container_width=True): + try: + service.import_legacy_mykey(import_path or None) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[2].button("Activate selected route", use_container_width=True): + if snapshot["routes"]: + try: + agent.set_active_route(activate_route_id) + st.rerun() + except Exception as exc: + st.error(str(exc)) + if action_cols[3].button("Soft reload runtime", use_container_width=True, type="primary"): + try: + agent.reload_llm_config(preserve_history=preserve_history) + st.rerun() + except Exception as exc: + st.error(str(exc)) + st.dataframe(snapshot["runtime"], use_container_width=True, hide_index=True) + + +def _render_right_panel(service, agent, snapshot): + section = st.session_state.ga_admin_section + if section == "routes": + _render_route_detail(service, agent, snapshot) + elif section == "providers": + _render_provider_detail(service, snapshot) + elif section == "diagnostics": + _render_diagnostics_detail(snapshot) + elif section == "tests": + _render_provider_detail(service, snapshot) + else: + 
_render_runtime_detail(service, agent, snapshot) + + +def render_admin_page(): + service, agent = get_shared_runtime() + if agent.llmclient is None: + st.error("No LLM routes are available. Configure mykey.py or import structured routes first.") + st.stop() + snapshot = service.get_ui_snapshot(agent) + viewmodel = build_ui_viewmodel(snapshot) + _ensure_state(snapshot) + st.markdown( + f""" +
+

{viewmodel['summary']['headline'] or 'GA Switch'}

+

Streamlit workbench shaped after cc-switch: route center, diagnostics, tests, and chat linkage in one place.

+ {_chip(str(viewmodel['stats']['provider_count']) + ' providers', 'blue')} + {_chip(str(viewmodel['stats']['route_count']) + ' routes')} + {_chip('structured config', 'active' if viewmodel['use_structured_config'] else 'warn')} +
+ """, + unsafe_allow_html=True, + ) + left, middle, right = st.columns([0.9, 1.45, 2.05], gap="large") + with left: + _render_left_nav(snapshot) + with middle: + _render_middle_panel(snapshot) + with right: + _render_right_panel(service, agent, snapshot) diff --git a/frontends/pages/1_GA_Switch_Admin.py b/frontends/pages/1_GA_Switch_Admin.py new file mode 100644 index 0000000..e9a6716 --- /dev/null +++ b/frontends/pages/1_GA_Switch_Admin.py @@ -0,0 +1,15 @@ +import os +import sys + +script_dir = os.path.dirname(__file__) +frontends_dir = os.path.abspath(os.path.join(script_dir, "..")) +repo_dir = os.path.abspath(os.path.join(frontends_dir, "..")) +if frontends_dir not in sys.path: + sys.path.append(frontends_dir) +if repo_dir not in sys.path: + sys.path.append(repo_dir) + +from ga_switch_ui import render_admin_page, setup_switch_page + +setup_switch_page("GA Switch Admin") +render_admin_page() diff --git a/frontends/qt_switch.py b/frontends/qt_switch.py new file mode 100644 index 0000000..8ecc66d --- /dev/null +++ b/frontends/qt_switch.py @@ -0,0 +1,1351 @@ +from __future__ import annotations + +import json + +from PySide6.QtCore import Qt, Signal +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QFileDialog, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QLineEdit, + QListWidget, + QListWidgetItem, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSpinBox, + QDoubleSpinBox, + QStackedWidget, + QVBoxLayout, + QWidget, +) + +from ga_switch.models import PROVIDER_BACKEND_KINDS, ROUTE_KINDS +from ga_switch.viewmodel import build_provider_payload, build_route_payload, build_ui_viewmodel + +# 设计令牌 +COLORS = { + # 背景 + "bg_primary": "#0a0a0e", + "bg_secondary": "#12121a", + "bg_tertiary": "#1a1a24", + # 文本 + "text_primary": "#ececf1", + "text_secondary": "#a1a1aa", + "text_muted": "#71717a", + # 边框 + "border_subtle": "#27272a", + "border_default": "#3f3f46", + # 强调色(蓝色系,更专业) + "accent": "#3b82f6", + "accent_hover": "#2563eb", + "accent_bg": "rgba(59, 
130, 246, 0.12)", + "accent_border": "rgba(59, 130, 246, 0.3)", + # 状态色 + "success": "#10b981", + "success_bg": "rgba(16, 185, 129, 0.12)", + "warning": "#f59e0b", + "warning_bg": "rgba(245, 158, 11, 0.12)", + "error": "#ef4444", + "error_bg": "rgba(239, 68, 68, 0.12)", +} + +SPACING = { + "xs": "4px", + "sm": "8px", + "md": "12px", + "lg": "16px", + "xl": "24px", +} + +RADIUS = { + "sm": "6px", + "md": "8px", + "lg": "10px", +} + +ROUTE_KIND_LABELS = { + "single": "单路由", + "failover": "备用链路", +} + +PROVIDER_KIND_LABELS = { + "native_claude": "原生 Claude", + "native_oai": "原生 OpenAI", + "claude_text": "Claude 文本接口", + "oai_text": "OpenAI 文本接口", +} + +API_MODE_LABELS = { + "chat_completions": "聊天补全", + "responses": "响应接口", +} + + +CARD_STYLE = f""" +QFrame {{ + background: {COLORS['bg_secondary']}; + border: 1px solid {COLORS['border_subtle']}; + border-radius: {RADIUS['md']}; +}} +""" + +LIST_STYLE = f""" +QListWidget {{ + background: {COLORS['bg_primary']}; + border: 1px solid {COLORS['border_subtle']}; + outline: none; + color: {COLORS['text_primary']}; + border-radius: {RADIUS['md']}; + padding: 8px; +}} +QListWidget::item {{ + margin: 4px 0; + padding: 10px 12px; + border-radius: {RADIUS['sm']}; +}} +QListWidget::item:hover {{ + background: rgba(62, 62, 75, 0.72); +}} +QListWidget::item:selected {{ + background: {COLORS['accent_bg']}; + color: white; +}} +""" + +INPUT_STYLE = f""" +QLineEdit, QComboBox, QPlainTextEdit, QSpinBox, QDoubleSpinBox {{ + background: {COLORS['bg_primary']}; + color: {COLORS['text_primary']}; + border: 1px solid {COLORS['border_default']}; + border-radius: {RADIUS['sm']}; + padding: 6px 9px; + selection-background-color: {COLORS['accent_bg']}; +}} +QComboBox::drop-down {{ + border: none; + width: 24px; +}} +QComboBox QAbstractItemView {{ + background: #111217; + color: {COLORS['text_primary']}; + border: 1px solid {COLORS['border_default']}; + selection-background-color: {COLORS['accent_bg']}; +}} +QPlainTextEdit {{ + padding: 8px 10px; 
# ---------------------------------------------------------------------------
# Qt stylesheet constants for buttons and section-nav toggles. These are
# runtime strings fed to setStyleSheet and must stay byte-identical.
# ---------------------------------------------------------------------------

BUTTON_STYLE = f"""
QPushButton {{
    background: rgba(39, 39, 42, 0.82);
    color: {COLORS['text_primary']};
    border: 1px solid {COLORS['border_default']};
    border-radius: {RADIUS['sm']};
    padding: 7px 12px;
    font-weight: 600;
}}
QPushButton:hover {{
    background: rgba(63, 63, 70, 0.82);
}}
"""

PRIMARY_BUTTON_STYLE = f"""
QPushButton {{
    background: {COLORS['accent_bg']};
    color: {COLORS['text_primary']};
    border: 1px solid {COLORS['accent_border']};
    border-radius: {RADIUS['sm']};
    padding: 7px 12px;
    font-weight: 600;
}}
QPushButton:hover {{
    background: rgba(59, 130, 246, 0.18);
    border: 1px solid rgba(59, 130, 246, 0.4);
}}
"""

DANGER_BUTTON_STYLE = f"""
QPushButton {{
    background: {COLORS['error_bg']};
    color: #fee2e2;
    border: 1px solid rgba(239, 68, 68, 0.3);
    border-radius: {RADIUS['sm']};
    padding: 7px 12px;
    font-weight: 600;
}}
QPushButton:hover {{
    background: rgba(239, 68, 68, 0.18);
}}
"""

SECTION_STYLE = f"""
QPushButton {{
    background: rgba(20, 20, 24, 0.86);
    color: {COLORS['text_secondary']};
    border: 1px solid {COLORS['border_subtle']};
    border-radius: {RADIUS['md']};
    padding: 9px 16px;
    font-size: 13px;
    font-weight: 700;
}}
QPushButton:hover {{
    background: rgba(45, 45, 56, 0.92);
    color: #f4f4f5;
}}
"""

SECTION_ACTIVE_STYLE = f"""
QPushButton {{
    background: {COLORS['accent_bg']};
    color: {COLORS['text_primary']};
    border: 1px solid {COLORS['accent_border']};
    border-radius: {RADIUS['md']};
    padding: 9px 16px;
    font-size: 13px;
    font-weight: 700;
}}
QPushButton:hover {{
    background: rgba(59, 130, 246, 0.18);
}}
"""


def _clear_layout(layout):
    """Recursively empty *layout*, scheduling owned widgets for deletion.

    Nested child layouts are emptied in turn. Note the child layout object
    itself is not deleted here — only its contents are released.
    """
    while layout.count():
        entry = layout.takeAt(0)
        owned_widget = entry.widget()
        if owned_widget is not None:
            owned_widget.deleteLater()
        nested = entry.layout()
        if nested is not None:
            _clear_layout(nested)
def _card(title: str | None = None):
    """Create a framed card with a vertical layout; returns (frame, layout).

    When *title* is given, a bold heading label is added as the first row.
    """
    frame = QFrame()
    frame.setStyleSheet(CARD_STYLE)
    layout = QVBoxLayout(frame)
    layout.setContentsMargins(16, 16, 16, 16)
    layout.setSpacing(10)
    if title:
        heading = QLabel(title)
        heading.setStyleSheet("color: #f4f4f5; font-size: 14px; font-weight: 700;")
        layout.addWidget(heading)
    return frame, layout


def _title(text):
    """Large, word-wrapped section heading label."""
    heading = QLabel(text)
    heading.setWordWrap(True)
    heading.setStyleSheet(f"color: {COLORS['text_primary']}; font-size: 20px; font-weight: 700;")
    return heading


def _muted(text="", *, wrap=True):
    """Secondary, dimmed caption label."""
    caption = QLabel(text)
    caption.setWordWrap(wrap)
    caption.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 13px;")
    return caption


def _chip(text, tone="neutral"):
    """Small rounded status badge; *tone* selects the color treatment.

    Unknown tones fall back to the neutral palette.
    """
    styles = {
        "active": f"background: {COLORS['success_bg']}; color: #86efac; border: 1px solid rgba(16,185,129,0.3);",
        "warn": f"background: {COLORS['warning_bg']}; color: #fcd34d; border: 1px solid rgba(245,158,11,0.3);",
        "error": f"background: {COLORS['error_bg']}; color: #fca5a5; border: 1px solid rgba(239,68,68,0.3);",
        "blue": f"background: {COLORS['accent_bg']}; color: #93c5fd; border: 1px solid {COLORS['accent_border']};",
        "neutral": f"background: rgba(63,63,70,0.5); color: {COLORS['text_secondary']}; border: 1px solid {COLORS['border_default']};",
    }
    badge = QLabel(text)
    badge.setStyleSheet(
        f"QLabel {{ border-radius: {RADIUS['sm']}; padding: 4px 8px; font-size: 11px; font-weight: 600; "
        + styles.get(tone, styles["neutral"])
        + " }"
    )
    return badge


def _button(text, *, primary=False, danger=False):
    """Styled push button with a pointing-hand cursor; *danger* wins over *primary*."""
    btn = QPushButton(text)
    btn.setCursor(Qt.PointingHandCursor)
    if danger:
        stylesheet = DANGER_BUTTON_STYLE
    elif primary:
        stylesheet = PRIMARY_BUTTON_STYLE
    else:
        stylesheet = BUTTON_STYLE
    btn.setStyleSheet(stylesheet)
    return btn


def _apply_input_style(widget):
    """Apply the shared input stylesheet and return the widget (chainable)."""
    widget.setStyleSheet(INPUT_STYLE)
    return widget


def _make_scroll_page():
    """Build a transparent, resizable scroll page; returns (scroll_area, content_layout)."""
    scroll = QScrollArea()
    scroll.setWidgetResizable(True)
    scroll.setFrameShape(QFrame.NoFrame)
    scroll.setStyleSheet("QScrollArea { background: transparent; border: none; }")
    container = QWidget()
    layout = QVBoxLayout(container)
    layout.setContentsMargins(0, 0, 0, 0)
    layout.setSpacing(12)
    scroll.setWidget(container)
    return scroll, layout
class _MemberOrderList(QListWidget):
    """Checkable, manually ordered list of failover member providers.

    Items can be reordered with move_current(); the checked items, read in
    display order, define the failover chain.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setStyleSheet(LIST_STYLE)
        self.setMaximumHeight(180)

    def load_options(self, providers, selected_ids):
        """Repopulate the list: selected providers first (saved order), then the rest."""
        chosen = list(selected_ids or [])
        by_id = {provider["id"]: provider for provider in providers}
        display = [by_id[pid] for pid in chosen if pid in by_id]
        display.extend(provider for provider in providers if provider["id"] not in chosen)
        self.clear()
        for provider in display:
            entry = QListWidgetItem(f"{provider['name']} · {provider.get('model') or '未设置模型'}")
            entry.setData(Qt.UserRole, provider["id"])
            entry.setFlags(entry.flags() | Qt.ItemIsUserCheckable | Qt.ItemIsSelectable | Qt.ItemIsEnabled)
            entry.setCheckState(Qt.Checked if provider["id"] in chosen else Qt.Unchecked)
            self.addItem(entry)

    def selected_ids(self):
        """Return the ids of checked items, in current display order."""
        return [
            self.item(row).data(Qt.UserRole)
            for row in range(self.count())
            if self.item(row).checkState() == Qt.Checked
        ]

    def move_current(self, delta):
        """Move the current row by *delta* positions; no-op when out of range."""
        row = self.currentRow()
        destination = row + delta
        if row < 0 or destination < 0 or destination >= self.count():
            return
        entry = self.takeItem(row)
        self.insertItem(destination, entry)
        self.setCurrentRow(destination)
self._route_advanced_open = False + self._route_create_advanced_open = False + self._provider_edit_open = False + self._provider_create_open = False + self._provider_advanced_open = False + self._provider_create_advanced_open = False + self._diagnostic_raw_open = False + self._overview_more_open = False + self._last_test_result = None + self._build_ui() + self.refresh_snapshot() + + def _build_ui(self): + root = QVBoxLayout(self) + root.setContentsMargins(16, 16, 16, 16) + root.setSpacing(12) + + summary_card, summary_layout = _card() + self._summary_title = _title("当前模型") + self._summary_meta = _muted("继续使用当前模型") + self._summary_notice = _muted("") + self._summary_notice.hide() + self._summary_chips = QHBoxLayout() + self._summary_chips.setSpacing(8) + self._summary_chips.addStretch() + summary_layout.addWidget(self._summary_title) + summary_layout.addWidget(self._summary_meta) + summary_layout.addLayout(self._summary_chips) + summary_layout.addWidget(self._summary_notice) + root.addWidget(summary_card) + + nav_row = QHBoxLayout() + nav_row.setSpacing(8) + self._page_buttons = {} + for page_id, label in self._page_defs(): + btn = QPushButton(label) + btn.setCursor(Qt.PointingHandCursor) + btn.clicked.connect(lambda _checked=False, pid=page_id: self._set_page(pid)) + nav_row.addWidget(btn) + self._page_buttons[page_id] = btn + nav_row.addStretch() + root.addLayout(nav_row) + + self._stack = QStackedWidget() + self._pages = {} + for page_id, _label in self._page_defs(): + scroll, layout = _make_scroll_page() + self._stack.addWidget(scroll) + self._pages[page_id] = {"widget": scroll, "layout": layout} + root.addWidget(self._stack, 1) + self._update_page_buttons() + + def _page_defs(self): + return ( + ("overview", "总览"), + ("routes", "全部路由"), + ("providers", "模型服务"), + ("diagnostics", "诊断记录"), + ) + + def _set_notice(self, text="", *, error=False): + if not text: + self._summary_notice.hide() + self._summary_notice.setText("") + return + 
self._summary_notice.setText(text) + self._summary_notice.setStyleSheet("color: #fecaca; font-size: 12px;" if error else "color: #bbf7d0; font-size: 12px;") + self._summary_notice.show() + + def _set_page(self, page_id): + if page_id == self.page_id: + return + self.page_id = page_id + self._update_page_buttons() + self._stack.setCurrentIndex(dict((pid, idx) for idx, (pid, _label) in enumerate(self._page_defs()))[page_id]) + + def _update_page_buttons(self): + for pid, btn in self._page_buttons.items(): + btn.setStyleSheet(SECTION_ACTIVE_STYLE if pid == self.page_id else SECTION_STYLE) + + def _run_action(self, func, success="", refresh=True): + try: + result = func() + except Exception as exc: + self._set_notice(str(exc), error=True) + return None + if refresh: + self.refresh_snapshot(success) + elif success: + self._set_notice(success) + return result + + def _ensure_selection(self): + route_ids = [item["id"] for item in self.viewmodel.get("routes", [])] + provider_ids = [item["id"] for item in self.viewmodel.get("providers", [])] + event_ids = [item["id"] for item in self.viewmodel.get("events", [])] + if self._selected_route_id not in route_ids: + self._selected_route_id = route_ids[0] if route_ids else None + self._route_edit_open = False + if self._selected_provider_id not in provider_ids: + self._selected_provider_id = provider_ids[0] if provider_ids else None + self._provider_edit_open = False + if self._selected_event_id not in event_ids: + self._selected_event_id = event_ids[0] if event_ids else None + + def refresh_snapshot(self, notice="", *, error=False): + if notice: + self._set_notice(notice, error=error) + self.snapshot = self.service.get_ui_snapshot(self.agent) + self.viewmodel = build_ui_viewmodel(self.snapshot) + self._ensure_selection() + self._refresh_summary() + self._render_overview_page() + self._render_routes_page() + self._render_providers_page() + self._render_diagnostics_page() + self.runtime_changed.emit(self.snapshot, self.viewmodel) + 
+ def _refresh_summary(self): + summary = self.viewmodel.get("summary", {}) + self._summary_title.setText(summary.get("headline") or "当前模型") + self._summary_meta.setText(summary.get("meta") or "继续使用当前模型") + while self._summary_chips.count(): + item = self._summary_chips.takeAt(0) + widget = item.widget() + if widget is not None: + widget.deleteLater() + tone = summary.get("health_tone") or "neutral" + chips = [ + _chip(summary.get("route_kind_label") or "单路由", "blue"), + _chip(summary.get("health_label") or "状态未知", tone), + ] + if summary.get("active_member_name"): + chips.append(_chip(summary["active_member_name"], "active")) + if summary.get("native_tools"): + chips.append(_chip("原生工具", "blue")) + for widget in chips: + self._summary_chips.addWidget(widget) + self._summary_chips.addStretch() + + def _selected_route(self): + return next((item for item in self.viewmodel.get("routes", []) if item["id"] == self._selected_route_id), None) + + def _selected_provider(self): + return next((item for item in self.viewmodel.get("providers", []) if item["id"] == self._selected_provider_id), None) + + def _selected_event(self): + return next((item for item in self.viewmodel.get("events", []) if item["id"] == self._selected_event_id), None) + + def _render_overview_page(self): + layout = self._pages["overview"]["layout"] + _clear_layout(layout) + empty_state = self.viewmodel.get("empty_state") + overview = self.viewmodel.get("overview", {}) + + if empty_state: + current_card, current_layout = _card("当前模型") + current_layout.addWidget(_title(self.viewmodel["summary"].get("headline") or "当前模型")) + current_layout.addWidget(_muted(self.viewmodel["summary"].get("meta") or "继续使用当前模型")) + keep_btn = _button("继续使用当前模型", primary=True) + keep_btn.clicked.connect(self.request_chat_focus.emit) + current_layout.addWidget(keep_btn) + layout.addWidget(current_card) + + card, card_layout = _card("开始配置路由") + card_layout.addWidget(_title(empty_state["title"])) + 
card_layout.addWidget(_muted(empty_state["message"])) + action_row = QHBoxLayout() + action_row.setSpacing(8) + for action in empty_state["actions"]: + btn = _button(action["label"], primary=action.get("primary", False)) + btn.clicked.connect(lambda _checked=False, aid=action["id"]: self._handle_overview_action(aid, None, None, None)) + action_row.addWidget(btn) + action_row.addStretch() + card_layout.addLayout(action_row) + layout.addWidget(card) + layout.addStretch() + return + + current_card, current_layout = _card("当前路由") + current = overview.get("current_route_card", {}) + current_layout.addWidget(_title(current.get("headline") or "当前路由")) + current_layout.addWidget(_muted(current.get("subtitle") or "暂无说明")) + chip_row = QHBoxLayout() + chip_row.setSpacing(8) + chip_row.addWidget(_chip(current.get("status_label") or "状态未知", current.get("status_tone") or "neutral")) + for badge in current.get("badges", []): + if badge: + chip_row.addWidget(_chip(badge, "blue")) + chip_row.addStretch() + current_layout.addLayout(chip_row) + layout.addWidget(current_card) + + top_row = QHBoxLayout() + top_row.setSpacing(12) + + health_card, health_layout = _card("健康状态") + health = overview.get("health_card", {}) + health_layout.addWidget(_title(health.get("headline") or "状态未知")) + health_layout.addWidget(_muted(health.get("detail") or "暂无诊断信息")) + top_row.addWidget(health_card, 1) + + action_card, action_layout = _card("快捷操作") + route_combo = _apply_input_style(QComboBox()) + for route in self.viewmodel.get("routes", []): + route_combo.addItem(f"{route['title']} · {route['subtitle']}", route["id"]) + current_route_id = self.viewmodel["summary"].get("route_id") + idx = route_combo.findData(current_route_id) + if idx >= 0: + route_combo.setCurrentIndex(idx) + action_layout.addWidget(route_combo) + action_row = QHBoxLayout() + action_row.setSpacing(8) + actions = overview.get("quick_actions", []) + primary_done = False + for action in actions: + is_primary = action.get("primary", 
False) and not primary_done + if is_primary: + primary_done = True + btn = _button(action["label"], primary=is_primary) + btn.clicked.connect( + lambda _checked=False, aid=action["id"], combo=route_combo: self._handle_overview_action( + aid, combo.currentData(), None, None + ) + ) + action_row.addWidget(btn) + action_row.addStretch() + action_layout.addLayout(action_row) + if self._overview_more_open: + more = self._build_runtime_more_panel() + action_layout.addWidget(more) + top_row.addWidget(action_card, 1) + layout.addLayout(top_row) + + summary_card, summary_layout = _card("路由列表摘要") + for route in overview.get("route_summary_items", []): + row = QHBoxLayout() + row.setSpacing(8) + row.addWidget(_chip(route.get("status_label") or "待命", "active" if route.get("active") else "neutral")) + text = QLabel(f"{route['title']} · {route['subtitle']}") + text.setWordWrap(True) + text.setStyleSheet("color: #ececf1; font-size: 13px; font-weight: 600;") + row.addWidget(text, 1) + btn = _button("查看详情") + btn.clicked.connect(lambda _checked=False, rid=route["id"]: self._open_route_detail(rid)) + row.addWidget(btn) + summary_layout.addLayout(row) + all_btn = _button("查看全部路由") + all_btn.clicked.connect(lambda: self._set_page("routes")) + summary_layout.addWidget(all_btn) + layout.addWidget(summary_card) + layout.addStretch() + + def _render_routes_page(self): + layout = self._pages["routes"]["layout"] + _clear_layout(layout) + routes = self.viewmodel.get("routes", []) + + head, head_layout = _card("全部路由") + head_layout.addWidget(_muted("先看摘要,再决定是否展开编辑。")) + toolbar = QHBoxLayout() + toolbar.setSpacing(8) + create_btn = _button("新建路由", primary=True) + create_btn.clicked.connect(self._toggle_route_create) + toolbar.addWidget(create_btn) + toolbar.addStretch() + head_layout.addLayout(toolbar) + layout.addWidget(head) + + if not routes: + empty, empty_layout = _card() + empty_layout.addWidget(_title("还没有路由")) + empty_layout.addWidget(_muted("先导入 mykey 或新建路由。")) + 
layout.addWidget(empty) + if self._route_create_open: + layout.addWidget(self._build_route_form_panel(None, create_mode=True)) + layout.addStretch() + return + + list_card, list_layout = _card("路由列表") + list_widget = QListWidget() + list_widget.setStyleSheet(LIST_STYLE) + list_widget.setMaximumHeight(230) + current_row = 0 + for idx, route in enumerate(routes): + item = QListWidgetItem(f"{route['title']}\n{route['subtitle']} · {route['status_label']} · {route['health_label']}") + item.setData(Qt.UserRole, route["id"]) + list_widget.addItem(item) + if route["id"] == self._selected_route_id: + current_row = idx + list_widget.setCurrentRow(current_row) + list_widget.currentItemChanged.connect(lambda current, _previous: self._select_route_item(current)) + list_layout.addWidget(list_widget) + layout.addWidget(list_card) + + route = self._selected_route() + if route: + detail, detail_layout = _card("路由摘要") + detail_layout.addWidget(_title(route["title"])) + detail_layout.addWidget(_muted(route["subtitle"])) + chips = QHBoxLayout() + chips.setSpacing(8) + chips.addWidget(_chip(route["kind_label"], "blue")) + chips.addWidget(_chip(route["status_label"], "active" if route["active"] else "neutral")) + chips.addWidget(_chip(route["health_label"], route["health_tone"])) + if route.get("active_member_name"): + chips.addWidget(_chip(route["active_member_name"], "active")) + chips.addStretch() + detail_layout.addLayout(chips) + if route.get("last_error_message"): + detail_layout.addWidget(_muted(route["last_error_message"])) + action_row = QHBoxLayout() + action_row.setSpacing(8) + activate_btn = _button("设为当前") + activate_btn.clicked.connect(lambda: self._activate_route(route["id"], route["title"])) + edit_btn = _button("编辑路由") + edit_btn.clicked.connect(self._toggle_route_edit) + delete_btn = _button("删除路由", danger=True) + delete_btn.clicked.connect(lambda: self._delete_route(route["id"], route["title"])) + action_row.addWidget(activate_btn) + action_row.addWidget(edit_btn) + 
action_row.addWidget(delete_btn) + action_row.addStretch() + detail_layout.addLayout(action_row) + layout.addWidget(detail) + if self._route_edit_open: + layout.addWidget(self._build_route_form_panel(route["id"], create_mode=False)) + + if self._route_create_open: + layout.addWidget(self._build_route_form_panel(None, create_mode=True)) + layout.addStretch() + + def _render_providers_page(self): + layout = self._pages["providers"]["layout"] + _clear_layout(layout) + providers = self.viewmodel.get("providers", []) + + head, head_layout = _card("模型服务") + head_layout.addWidget(_muted("默认只显示摘要和高频动作,编辑时再展开完整配置。")) + toolbar = QHBoxLayout() + create_btn = _button("新建模型服务", primary=True) + create_btn.clicked.connect(self._toggle_provider_create) + toolbar.addWidget(create_btn) + toolbar.addStretch() + head_layout.addLayout(toolbar) + layout.addWidget(head) + + if not providers: + empty, empty_layout = _card() + empty_layout.addWidget(_title("还没有模型服务")) + empty_layout.addWidget(_muted("先新建模型服务,再创建路由。")) + layout.addWidget(empty) + if self._provider_create_open: + layout.addWidget(self._build_provider_form_panel(None, create_mode=True)) + layout.addStretch() + return + + list_card, list_layout = _card("模型服务列表") + list_widget = QListWidget() + list_widget.setStyleSheet(LIST_STYLE) + list_widget.setMaximumHeight(230) + current_row = 0 + for idx, provider in enumerate(providers): + item = QListWidgetItem(f"{provider['title']}\n{provider['subtitle']} · {provider['health_label']}") + item.setData(Qt.UserRole, provider["id"]) + list_widget.addItem(item) + if provider["id"] == self._selected_provider_id: + current_row = idx + list_widget.setCurrentRow(current_row) + list_widget.currentItemChanged.connect(lambda current, _previous: self._select_provider_item(current)) + list_layout.addWidget(list_widget) + layout.addWidget(list_card) + + provider = self._selected_provider() + if provider: + detail, detail_layout = _card("模型服务摘要") + detail_layout.addWidget(_title(provider["title"])) 
+ detail_layout.addWidget(_muted(provider["subtitle"])) + chips = QHBoxLayout() + chips.setSpacing(8) + chips.addWidget(_chip(provider["health_label"], provider["health_tone"])) + chips.addWidget(_chip("原生工具" if provider["is_native"] else "文本接口", "blue")) + chips.addWidget(_chip(f"延迟 {provider.get('latency_ms') or '-'} ms")) + chips.addStretch() + detail_layout.addLayout(chips) + if provider.get("last_error"): + detail_layout.addWidget(_muted(provider["last_error"])) + if self._last_test_result and self._last_test_result.get("provider_id") == provider["id"]: + detail_layout.addWidget(_muted(f"最近测试:{self._last_test_result.get('status', '完成')}")) + action_row = QHBoxLayout() + test_btn = _button("连通性测试", primary=True) + test_btn.clicked.connect(lambda: self._run_model_test(provider["id"], provider["title"])) + edit_btn = _button("编辑模型服务") + edit_btn.clicked.connect(self._toggle_provider_edit) + delete_btn = _button("删除模型服务", danger=True) + delete_btn.clicked.connect(lambda: self._delete_provider(provider["id"], provider["title"])) + action_row.addWidget(test_btn) + action_row.addWidget(edit_btn) + action_row.addWidget(delete_btn) + action_row.addStretch() + detail_layout.addLayout(action_row) + layout.addWidget(detail) + if self._provider_edit_open: + layout.addWidget(self._build_provider_form_panel(provider["id"], create_mode=False)) + + if self._provider_create_open: + layout.addWidget(self._build_provider_form_panel(None, create_mode=True)) + layout.addStretch() + + def _render_diagnostics_page(self): + layout = self._pages["diagnostics"]["layout"] + _clear_layout(layout) + events = self.viewmodel.get("events", []) + + head, head_layout = _card("诊断记录") + head_layout.addWidget(_muted("默认只看摘要;需要时再展开原始详情。")) + layout.addWidget(head) + + if not events: + empty, empty_layout = _card() + empty_layout.addWidget(_title("还没有诊断记录")) + empty_layout.addWidget(_muted("调用模型、切换路由或连通性测试后,这里会出现记录。")) + layout.addWidget(empty) + layout.addStretch() + return + + list_card, 
list_layout = _card("事件列表") + list_widget = QListWidget() + list_widget.setStyleSheet(LIST_STYLE) + list_widget.setMaximumHeight(250) + current_row = 0 + for idx, event in enumerate(events): + prefix = "● " if event.get("tone") == "active" else "!" + item = QListWidgetItem(f"{prefix}{event['title']}\n{event['subtitle']}") + item.setData(Qt.UserRole, event["id"]) + list_widget.addItem(item) + if event["id"] == self._selected_event_id: + current_row = idx + list_widget.setCurrentRow(current_row) + list_widget.currentItemChanged.connect(lambda current, _previous: self._select_event_item(current)) + list_layout.addWidget(list_widget) + layout.addWidget(list_card) + + event = self._selected_event() + if event: + detail, detail_layout = _card("记录摘要") + detail_layout.addWidget(_title(event["title"])) + detail_layout.addWidget(_muted(event["subtitle"])) + chips = QHBoxLayout() + chips.setSpacing(8) + chips.addWidget(_chip(event.get("error_kind") or ("成功" if event.get("tone") == "active" else "异常"), "active" if event.get("tone") == "active" else "error")) + if event.get("status_code") is not None: + chips.addWidget(_chip(str(event["status_code"]), "blue")) + if event.get("created_at"): + chips.addWidget(_chip(event["created_at"])) + chips.addStretch() + detail_layout.addLayout(chips) + raw_btn = _button("查看原始详情" if not self._diagnostic_raw_open else "收起原始详情") + raw_btn.clicked.connect(self._toggle_diagnostic_raw) + detail_layout.addWidget(raw_btn) + if self._diagnostic_raw_open: + raw = QPlainTextEdit() + raw.setReadOnly(True) + raw.setMaximumHeight(220) + _apply_input_style(raw) + raw.setPlainText(json.dumps(event["payload"], ensure_ascii=False, indent=2)) + detail_layout.addWidget(raw) + layout.addWidget(detail) + layout.addStretch() + + def _handle_overview_action(self, action_id, route_id, _one, _two): + if action_id == "switch_route": + self._activate_route(route_id, "所选路由") + elif action_id == "soft_reload": + self._run_action(lambda: 
self.agent.reload_llm_config(preserve_history=True), success="已完成软重载。") + elif action_id == "import_legacy": + self._pick_and_import_legacy() + elif action_id == "create_provider": + self._provider_create_open = True + self._set_page("providers") + self._render_providers_page() + elif action_id == "continue_chat": + self.request_chat_focus.emit() + elif action_id == "more_actions": + self._overview_more_open = not self._overview_more_open + self._render_overview_page() + + def _build_runtime_more_panel(self): + panel, layout = _card("更多操作") + form = QFormLayout() + form.setLabelAlignment(Qt.AlignLeft) + form.setHorizontalSpacing(14) + form.setVerticalSpacing(10) + + structured_box = QCheckBox("启用结构化路由") + structured_box.setChecked(bool(self.snapshot.get("use_structured_config"))) + structured_box.setStyleSheet("color: #d4d4d8;") + + preserve_box = QCheckBox("软重载时保留上下文") + preserve_box.setChecked(True) + preserve_box.setStyleSheet("color: #d4d4d8;") + + import_path = _apply_input_style(QLineEdit("")) + browse_btn = _button("选择文件") + path_row = QWidget() + path_layout = QHBoxLayout(path_row) + path_layout.setContentsMargins(0, 0, 0, 0) + path_layout.setSpacing(8) + path_layout.addWidget(import_path, 1) + path_layout.addWidget(browse_btn) + browse_btn.clicked.connect(lambda: self._browse_import_path(import_path)) + + form.addRow("结构化路由", structured_box) + form.addRow("软重载", preserve_box) + form.addRow("导入文件", path_row) + layout.addLayout(form) + + row = QHBoxLayout() + apply_btn = _button("保存模式") + apply_btn.clicked.connect( + lambda: self._run_action( + lambda: self.service.set_structured_config_enabled(structured_box.isChecked()), + success="已更新结构化路由开关。", + ) + ) + import_btn = _button("导入并应用") + import_btn.clicked.connect( + lambda: self._run_action( + lambda: self.service.import_legacy_mykey(import_path.text().strip() or None), + success="已导入旧版配置。", + ) + ) + reload_btn = _button("立即软重载") + reload_btn.clicked.connect( + lambda: self._run_action( + lambda: 
self.agent.reload_llm_config(preserve_history=preserve_box.isChecked()), + success="已完成软重载。", + ) + ) + row.addWidget(apply_btn) + row.addWidget(import_btn) + row.addWidget(reload_btn) + row.addStretch() + layout.addLayout(row) + return panel + + def _build_route_form_panel(self, route_id, *, create_mode): + route = next((item for item in self.viewmodel.get("routes", []) if item["id"] == route_id), None) if route_id is not None else None + payload = self.snapshot.get("routes_by_id", {}).get(route_id) if route_id is not None else None + panel, layout = _card("新建路由" if create_mode else "编辑路由") + form = QFormLayout() + form.setLabelAlignment(Qt.AlignLeft) + form.setHorizontalSpacing(14) + form.setVerticalSpacing(10) + + name_edit = _apply_input_style(QLineEdit((route or {}).get("title", ""))) + kind_combo = _apply_input_style(QComboBox()) + for route_kind in ROUTE_KINDS: + kind_combo.addItem(ROUTE_KIND_LABELS.get(route_kind, route_kind), route_kind) + current_kind = (route or {}).get("kind", "single") + kind_idx = kind_combo.findData(current_kind) + kind_combo.setCurrentIndex(kind_idx if kind_idx >= 0 else 0) + + provider_combo = _apply_input_style(QComboBox()) + provider_combo.addItem("请选择模型服务", None) + for provider in self.snapshot.get("providers", []): + provider_combo.addItem(f"{provider['name']} · {provider.get('model') or '未设置模型'}", provider["id"]) + if route and route.get("provider_id") is not None: + idx = provider_combo.findData(route["provider_id"]) + if idx >= 0: + provider_combo.setCurrentIndex(idx) + + member_list = _MemberOrderList() + member_list.load_options(self.snapshot.get("providers", []), (route or {}).get("member_provider_ids", [])) + member_row = QHBoxLayout() + move_up = _button("上移") + move_up.clicked.connect(lambda: member_list.move_current(-1)) + move_down = _button("下移") + move_down.clicked.connect(lambda: member_list.move_current(1)) + member_row.addWidget(move_up) + member_row.addWidget(move_down) + member_row.addStretch() + + default_box 
= QCheckBox("设为默认路由") + default_box.setChecked(bool((route or {}).get("is_default", False))) + default_box.setStyleSheet("color: #d4d4d8;") + enabled_box = QCheckBox("启用此路由") + enabled_box.setChecked(bool((route or {}).get("enabled", True))) + enabled_box.setStyleSheet("color: #d4d4d8;") + + form.addRow("路由名称", name_edit) + form.addRow("路由类型", kind_combo) + form.addRow("主要模型服务", provider_combo) + flags = QWidget() + flags_layout = QHBoxLayout(flags) + flags_layout.setContentsMargins(0, 0, 0, 0) + flags_layout.addWidget(default_box) + flags_layout.addWidget(enabled_box) + flags_layout.addStretch() + form.addRow("路由状态", flags) + layout.addLayout(form) + + advanced_open = self._route_create_advanced_open if create_mode else self._route_advanced_open + advanced_btn = _button("展开高级设置" if not advanced_open else "收起高级设置") + advanced_btn.clicked.connect(lambda: self._toggle_route_advanced(create_mode)) + layout.addWidget(advanced_btn) + + if advanced_open: + adv, adv_layout = _card("高级设置") + hint = _muted("备用链路的成员顺序会按列表上下顺序生效。") + adv_layout.addWidget(hint) + single_box = QWidget() + single_layout = QVBoxLayout(single_box) + single_layout.setContentsMargins(0, 0, 0, 0) + single_layout.addWidget(_muted("单路由只使用上面的“主要模型服务”。")) + + failover_box = QWidget() + failover_layout = QVBoxLayout(failover_box) + failover_layout.setContentsMargins(0, 0, 0, 0) + failover_layout.setSpacing(8) + failover_layout.addWidget(_muted("勾选成员后可用“上移 / 下移”调整备用链路顺序。")) + failover_layout.addWidget(member_list) + failover_layout.addLayout(member_row) + + retries_spin = _apply_input_style(QSpinBox()) + retries_spin.setRange(0, 100) + retries_spin.setValue(int(((route or {}).get("config") or {}).get("max_retries", 3))) + delay_spin = _apply_input_style(QDoubleSpinBox()) + delay_spin.setRange(0.0, 999.0) + delay_spin.setDecimals(2) + delay_spin.setSingleStep(0.25) + delay_spin.setValue(float(((route or {}).get("config") or {}).get("base_delay", 1.5))) + spring_spin = _apply_input_style(QSpinBox()) + 
spring_spin.setRange(0, 86400) + spring_spin.setValue(int(((route or {}).get("config") or {}).get("spring_back", 300))) + + adv_form = QFormLayout() + adv_form.setLabelAlignment(Qt.AlignLeft) + adv_form.setHorizontalSpacing(14) + adv_form.setVerticalSpacing(10) + adv_form.addRow("备用链路成员", failover_box) + adv_form.addRow("最大重试次数", retries_spin) + adv_form.addRow("基础退避秒数", delay_spin) + adv_form.addRow("回弹时间(秒)", spring_spin) + adv_layout.addLayout(adv_form) + + def submit_values(): + try: + values = { + "name": name_edit.text(), + "kind": kind_combo.currentData(), + "is_default": default_box.isChecked(), + "is_enabled": enabled_box.isChecked(), + "provider_id": provider_combo.currentData(), + "member_provider_ids": member_list.selected_ids(), + "max_retries": retries_spin.value(), + "base_delay": delay_spin.value(), + "spring_back": spring_spin.value(), + } + return build_route_payload(values, route_id=(payload or {}).get("id")) + except Exception as exc: + self._set_notice(str(exc), error=True) + return None + + submit_builder = submit_values + layout.addWidget(adv) + else: + def submit_values(): + try: + existing_config = dict((payload or {}).get("config") or {}) + selected_kind = kind_combo.currentData() or "single" + values = { + "name": name_edit.text(), + "kind": selected_kind, + "is_default": default_box.isChecked(), + "is_enabled": enabled_box.isChecked(), + "provider_id": provider_combo.currentData() if selected_kind == "single" else None, + "member_provider_ids": [] if selected_kind == "single" else list((payload or {}).get("member_provider_ids") or []), + "max_retries": existing_config.get("max_retries", 3), + "base_delay": existing_config.get("base_delay", 1.5), + "spring_back": existing_config.get("spring_back", 300), + } + return build_route_payload(values, route_id=(payload or {}).get("id")) + except Exception as exc: + self._set_notice(str(exc), error=True) + return None + + submit_builder = submit_values + + def sync_kind(): + is_single = 
kind_combo.currentData() == "single" + provider_combo.setEnabled(True) + if not advanced_open: + return + member_list.setVisible(not is_single) + move_up.setVisible(not is_single) + move_down.setVisible(not is_single) + kind_combo.currentIndexChanged.connect(lambda _idx: sync_kind()) + sync_kind() + + actions = QHBoxLayout() + save_btn = _button("创建路由" if create_mode else "保存路由", primary=True) + save_btn.clicked.connect( + lambda: self._submit_route(submit_builder(), create_mode=create_mode) + ) + actions.addWidget(save_btn) + actions.addStretch() + layout.addLayout(actions) + return panel + + def _submit_route(self, payload, *, create_mode): + if not payload: + return + result = self._run_action( + lambda: self.service.upsert_route(payload), + success="已保存路由配置。若需立即生效,请执行软重载。", + refresh=False, + ) + if not result: + return + if create_mode: + self._route_create_open = False + self._route_create_advanced_open = False + else: + self._route_edit_open = False + self._route_advanced_open = False + self.refresh_snapshot() + + def _build_provider_form_panel(self, provider_id, *, create_mode): + provider_vm = next((item for item in self.viewmodel.get("providers", []) if item["id"] == provider_id), None) if provider_id is not None else None + provider = self.snapshot.get("providers_by_id", {}).get(provider_id) if provider_id is not None else None + panel, layout = _card("新建模型服务" if create_mode else "编辑模型服务") + form = QFormLayout() + form.setLabelAlignment(Qt.AlignLeft) + form.setHorizontalSpacing(14) + form.setVerticalSpacing(10) + + name_edit = _apply_input_style(QLineEdit((provider or {}).get("name", ""))) + kind_combo = _apply_input_style(QComboBox()) + for backend_kind in PROVIDER_BACKEND_KINDS: + kind_combo.addItem(PROVIDER_KIND_LABELS.get(backend_kind, backend_kind), backend_kind) + current_backend_kind = (provider or {}).get("backend_kind", "oai_text") + kind_idx = kind_combo.findData(current_backend_kind) + kind_combo.setCurrentIndex(kind_idx if kind_idx >= 0 else 
0) + model_edit = _apply_input_style(QLineEdit((provider or {}).get("model", ""))) + base_edit = _apply_input_style(QLineEdit((provider or {}).get("apibase", ""))) + form.addRow("名称", name_edit) + form.addRow("接口类型", kind_combo) + form.addRow("模型 ID", model_edit) + form.addRow("接口地址", base_edit) + layout.addLayout(form) + + advanced_open = self._provider_create_advanced_open if create_mode else self._provider_advanced_open + advanced_btn = _button("展开高级设置" if not advanced_open else "收起高级设置") + advanced_btn.clicked.connect(lambda: self._toggle_provider_advanced(create_mode)) + layout.addWidget(advanced_btn) + + if advanced_open: + adv, adv_layout = _card("高级设置") + adv_form = QFormLayout() + adv_form.setLabelAlignment(Qt.AlignLeft) + adv_form.setHorizontalSpacing(14) + adv_form.setVerticalSpacing(10) + + apikey_edit = _apply_input_style(QLineEdit((provider or {}).get("apikey", ""))) + apikey_edit.setEchoMode(QLineEdit.Password) + api_mode_combo = _apply_input_style(QComboBox()) + for api_mode in ("chat_completions", "responses"): + api_mode_combo.addItem(API_MODE_LABELS.get(api_mode, api_mode), api_mode) + current_api_mode = (provider or {}).get("api_mode", "chat_completions") + api_mode_idx = api_mode_combo.findData(current_api_mode) + api_mode_combo.setCurrentIndex(api_mode_idx if api_mode_idx >= 0 else 0) + temperature_spin = _apply_input_style(QDoubleSpinBox()) + temperature_spin.setRange(0.0, 2.0) + temperature_spin.setDecimals(2) + temperature_spin.setSingleStep(0.1) + temperature_spin.setValue(float((provider or {}).get("temperature", 1.0))) + max_tokens_spin = _apply_input_style(QSpinBox()) + max_tokens_spin.setRange(1, 1000000) + max_tokens_spin.setValue(int((provider or {}).get("max_tokens", 8192))) + timeout_spin = _apply_input_style(QSpinBox()) + timeout_spin.setRange(1, 300) + timeout_spin.setValue(int((provider or {}).get("timeout", 5))) + read_timeout_spin = _apply_input_style(QSpinBox()) + read_timeout_spin.setRange(1, 600) + 
read_timeout_spin.setValue(int((provider or {}).get("read_timeout", 30))) + proxy_edit = _apply_input_style(QLineEdit((provider or {}).get("proxy", "") or "")) + extra_edit = _apply_input_style(QPlainTextEdit()) + extra_edit.setMaximumHeight(120) + extra_edit.setPlainText(json.dumps((provider or {}).get("extra", {}), ensure_ascii=False, indent=2)) + + adv_form.addRow("密钥", apikey_edit) + adv_form.addRow("请求模式", api_mode_combo) + adv_form.addRow("温度", temperature_spin) + adv_form.addRow("最大词元", max_tokens_spin) + adv_form.addRow("连接超时", timeout_spin) + adv_form.addRow("读取超时", read_timeout_spin) + adv_form.addRow("代理", proxy_edit) + adv_form.addRow("附加 JSON", extra_edit) + adv_layout.addLayout(adv_form) + layout.addWidget(adv) + + def submit_values(): + try: + values = { + "name": name_edit.text(), + "backend_kind": kind_combo.currentData(), + "apikey": apikey_edit.text(), + "apibase": base_edit.text(), + "model": model_edit.text(), + "api_mode": api_mode_combo.currentData(), + "temperature": temperature_spin.value(), + "max_tokens": max_tokens_spin.value(), + "timeout": timeout_spin.value(), + "read_timeout": read_timeout_spin.value(), + "proxy": proxy_edit.text(), + "extra": extra_edit.toPlainText(), + } + return build_provider_payload(values, provider_id=(provider or {}).get("id")) + except Exception as exc: + self._set_notice(str(exc), error=True) + return None + + submit_builder = submit_values + else: + def submit_values(): + try: + existing_provider = provider or {} + values = { + "name": name_edit.text(), + "backend_kind": kind_combo.currentData(), + "apikey": existing_provider.get("apikey", ""), + "apibase": base_edit.text(), + "model": model_edit.text(), + "api_mode": existing_provider.get("api_mode", "chat_completions"), + "temperature": existing_provider.get("temperature", 1.0), + "max_tokens": existing_provider.get("max_tokens", 8192), + "timeout": existing_provider.get("timeout", 5), + "read_timeout": existing_provider.get("read_timeout", 30), + 
"proxy": existing_provider.get("proxy", "") or "", + "extra": json.dumps(existing_provider.get("extra", {}), ensure_ascii=False), + } + return build_provider_payload(values, provider_id=(provider or {}).get("id")) + except Exception as exc: + self._set_notice(str(exc), error=True) + return None + + submit_builder = submit_values + + if provider_vm and provider_vm.get("last_error"): + layout.addWidget(_muted(f"最近错误:{provider_vm['last_error']}")) + + actions = QHBoxLayout() + save_btn = _button("创建模型服务" if create_mode else "保存模型服务", primary=True) + save_btn.clicked.connect(lambda: self._submit_provider(submit_builder(), create_mode=create_mode)) + actions.addWidget(save_btn) + actions.addStretch() + layout.addLayout(actions) + return panel + + def _submit_provider(self, payload, *, create_mode): + if not payload: + return + result = self._run_action( + lambda: self.service.upsert_provider(payload), + success="已保存模型服务。若需立即生效,请执行软重载。", + refresh=False, + ) + if not result: + return + if create_mode: + self._provider_create_open = False + self._provider_create_advanced_open = False + else: + self._provider_edit_open = False + self._provider_advanced_open = False + self.refresh_snapshot() + + def _activate_route(self, route_id, route_name): + self._run_action(lambda: self.agent.set_active_route(route_id), success=f"已切换到 {route_name}。") + + def _delete_route(self, route_id, route_name): + self._run_action(lambda: self.service.delete_route(route_id), success=f"已删除路由 {route_name}。") + + def _delete_provider(self, provider_id, provider_name): + self._run_action(lambda: self.service.delete_provider(provider_id), success=f"已删除模型服务 {provider_name}。") + + def _run_model_test(self, provider_id, provider_name): + def _action(): + result = self.service.run_model_test(provider_id) + result["provider_id"] = provider_id + self._last_test_result = result + return result + + result = self._run_action(_action, success=f"已完成 {provider_name} 的连通性测试。") + if result: + 
self._set_notice(f"连通性测试结果:{result.get('status', '完成')}") + + def _pick_and_import_legacy(self): + path, _ = QFileDialog.getOpenFileName( + self, + "选择旧版配置文件", + "", + "配置文件 (*.py *.json);;所有文件 (*)", + ) + if path: + self._run_action(lambda: self.service.import_legacy_mykey(path), success="已导入旧版配置。") + + def _browse_import_path(self, line_edit): + path, _ = QFileDialog.getOpenFileName( + self, + "选择旧版配置文件", + "", + "配置文件 (*.py *.json);;所有文件 (*)", + ) + if path: + line_edit.setText(path) + + def _open_route_detail(self, route_id): + self._selected_route_id = route_id + self._set_page("routes") + self._render_routes_page() + + def _toggle_route_edit(self): + self._route_create_open = False + self._route_create_advanced_open = False + self._route_edit_open = not self._route_edit_open + if not self._route_edit_open: + self._route_advanced_open = False + self._render_routes_page() + + def _toggle_route_create(self): + self._route_edit_open = False + self._route_advanced_open = False + self._route_create_open = not self._route_create_open + if not self._route_create_open: + self._route_create_advanced_open = False + self._render_routes_page() + + def _toggle_route_advanced(self, create_mode): + if create_mode: + self._route_create_advanced_open = not self._route_create_advanced_open + else: + self._route_advanced_open = not self._route_advanced_open + self._render_routes_page() + + def _toggle_provider_edit(self): + self._provider_create_open = False + self._provider_create_advanced_open = False + self._provider_edit_open = not self._provider_edit_open + if not self._provider_edit_open: + self._provider_advanced_open = False + self._render_providers_page() + + def _toggle_provider_create(self): + self._provider_edit_open = False + self._provider_advanced_open = False + self._provider_create_open = not self._provider_create_open + if not self._provider_create_open: + self._provider_create_advanced_open = False + self._render_providers_page() + + def 
_toggle_provider_advanced(self, create_mode): + if create_mode: + self._provider_create_advanced_open = not self._provider_create_advanced_open + else: + self._provider_advanced_open = not self._provider_advanced_open + self._render_providers_page() + + def _toggle_diagnostic_raw(self): + self._diagnostic_raw_open = not self._diagnostic_raw_open + self._render_diagnostics_page() + + def _select_route_item(self, current): + if current is None: + return + self._selected_route_id = current.data(Qt.UserRole) + self._route_edit_open = False + self._route_advanced_open = False + self._render_routes_page() + + def _select_provider_item(self, current): + if current is None: + return + self._selected_provider_id = current.data(Qt.UserRole) + self._provider_edit_open = False + self._provider_advanced_open = False + self._render_providers_page() + + def _select_event_item(self, current): + if current is None: + return + self._selected_event_id = current.data(Qt.UserRole) + self._diagnostic_raw_open = False + self._render_diagnostics_page() diff --git a/frontends/qtapp.py b/frontends/qtapp.py index ce69c34..fd15a3c 100644 --- a/frontends/qtapp.py +++ b/frontends/qtapp.py @@ -13,7 +13,7 @@ from PySide6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, - QScrollArea, QFrame, QTextEdit, QStackedWidget, + QScrollArea, QFrame, QTextEdit, QStackedWidget, QComboBox, QListWidget, QListWidgetItem, QSizePolicy, QFileDialog, QSplitter, QTextBrowser, QApplication, QMessageBox, ) @@ -28,6 +28,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from agentmain import GeneraticAgent +from qt_switch import RouteCenterPage +from frontends.design_tokens import FONTS, FONT_SIZES, FONT_WEIGHTS, COLORS, SPACING, RADIUS # ══════════════════════════════════════════════════════════════════════ @@ -249,6 +251,53 @@ def mouseReleaseEvent(self, event): self._dragged = False self._drag_origin_global = None + def contextMenuEvent(self, event): + 
"""Right-click menu for force stop and quit.""" + from PySide6.QtWidgets import QMenu + menu = QMenu(self) + menu.setStyleSheet(""" + QMenu { + background: rgba(20, 20, 24, 250); + border: 1px solid rgba(59, 130, 246, 0.3); + border-radius: 6px; + padding: 4px; + } + QMenu::item { + padding: 6px 20px; + color: #ececf1; + border-radius: 4px; + } + QMenu::item:selected { + background: rgba(59, 130, 246, 0.15); + } + """) + + stop_action = menu.addAction("⏹ 强制终止 Agent") + hide_action = menu.addAction("隐藏面板") + quit_action = menu.addAction("退出程序") + + action = menu.exec(event.globalPos()) + if action == stop_action: + self._force_stop_agent() + elif action == hide_action: + self.chat_panel.hide() + elif action == quit_action: + QApplication.quit() + + def _force_stop_agent(self): + """Force stop the running agent.""" + agent = getattr(self.chat_panel, "agent", None) + if agent: + # Stop agent if it has a stop method + if hasattr(agent, "stop"): + agent.stop() + # Clear streaming queue + if hasattr(self.chat_panel, "_display_queue"): + self.chat_panel._display_queue = None + # Reset streaming state + if hasattr(self.chat_panel, "_streaming_row"): + self.chat_panel._streaming_row = None + # ── Toggle panel ────────────────────────────────────── def _toggle(self): from PySide6.QtCore import QDateTime @@ -291,19 +340,19 @@ def _position_panel(self): MAX_INLINE_CHARS = 6000 C = { - "bg": QColor(14, 14, 18), - "panel": QColor(20, 20, 24, 248), - "border": QColor(45, 45, 50), - "accent": "#7c3aed", - "text": "#e4e4e7", - "muted": "#71717a", - "user_g0": QColor(79, 70, 229), - "user_g1": QColor(124, 58, 237), - "asst_bg": QColor(39, 39, 42, 210), - "asst_bdr": QColor(63, 63, 70), - "send_g0": QColor(220, 38, 38), - "send_g1": QColor(239, 68, 68), - "green": "#22c55e", + "bg": COLORS["bg_base"], + "panel": COLORS["bg_elevated"], + "border": COLORS["border_default"], + "accent": COLORS["brand_500"], + "text": COLORS["text_primary"], + "muted": COLORS["text_tertiary"], + 
"user_g0": QColor(59, 130, 246), + "user_g1": QColor(96, 165, 250), + "asst_bg": COLORS["bg_overlay"], + "asst_bdr": COLORS["border_default"], + "send_g0": QColor(59, 130, 246), + "send_g1": QColor(96, 165, 250), + "green": COLORS["success"], } SCROLLBAR_STYLE = """ @@ -320,6 +369,7 @@ def _position_panel(self): _SVG_CHAT = '' _SVG_CLOCK = '' _SVG_BOOK = '' +_SVG_ROUTE = '' _SVG_GEAR = '' _SVG_CLIP = '' _SVG_STOP = '' @@ -334,27 +384,27 @@ def _position_panel(self): _SVG_SEND = '' _SVG_PLUS = '' -_MD_CSS = """ -body { color: #e4e4e7; font-family: "Arial", "Microsoft YaHei", sans-serif; font-size: 13px; line-height: 1.6; font-weight: 400; } -h1 { color: #f4f4f5; font-size: 20px; font-weight: 700; border-bottom: 1px solid #3f3f46; padding-bottom: 4px; margin-top: 16px; } -h2 { color: #f4f4f5; font-size: 17px; font-weight: 700; border-bottom: 1px solid #3f3f46; padding-bottom: 3px; margin-top: 14px; } -h3 { color: #f4f4f5; font-size: 15px; font-weight: 600; margin-top: 12px; } -h4,h5,h6 { color: #d4d4d8; font-size: 13px; font-weight: 600; margin-top: 10px; } -code { background: rgba(63,63,70,0.6); color: #c4b5fd; padding: 1px 4px; border-radius: 3px; - font-family: Consolas, "Courier New", monospace; font-size: 12px; } -pre { background: rgba(24,24,30,0.95); border: 1px solid #3f3f46; border-radius: 6px; - padding: 10px 12px; margin: 8px 0; } -pre code { background: transparent; padding: 0; color: #d4d4d8; } -a { color: #818cf8; text-decoration: none; } -a:hover { text-decoration: underline; } -blockquote { border-left: 3px solid #7c3aed; margin: 8px 0 8px 0; padding: 4px 0 4px 12px; color: #a1a1aa; } -table { border-collapse: collapse; margin: 8px 0; } -th, td { border: 1px solid #3f3f46; padding: 5px 10px; } -th { background: rgba(63,63,70,0.35); color: #d4d4d8; font-weight: 700; } -hr { border: none; border-top: 1px solid #3f3f46; margin: 12px 0; } -ul, ol { padding-left: 22px; margin: 4px 0; } -li { margin: 2px 0; } -p { margin: 6px 0; } +_MD_CSS = f""" +body {{ 
color: {COLORS['text_primary']}; font-family: {FONTS['ui']}; font-size: {FONT_SIZES['md']}px; line-height: 1.6; font-weight: {FONT_WEIGHTS['normal']}; }} +h1 {{ color: {COLORS['text_primary']}; font-size: {FONT_SIZES['2xl']}px; font-weight: {FONT_WEIGHTS['bold']}; border-bottom: 1px solid {COLORS['border_default'].name()}; padding-bottom: 4px; margin-top: 16px; }} +h2 {{ color: {COLORS['text_primary']}; font-size: {FONT_SIZES['xl']}px; font-weight: {FONT_WEIGHTS['bold']}; border-bottom: 1px solid {COLORS['border_default'].name()}; padding-bottom: 3px; margin-top: 14px; }} +h3 {{ color: {COLORS['text_primary']}; font-size: {FONT_SIZES['lg']}px; font-weight: {FONT_WEIGHTS['semibold']}; margin-top: 12px; }} +h4,h5,h6 {{ color: {COLORS['text_secondary']}; font-size: {FONT_SIZES['md']}px; font-weight: {FONT_WEIGHTS['semibold']}; margin-top: 10px; }} +code {{ background: rgba(63,63,70,0.6); color: #c4b5fd; padding: 1px 4px; border-radius: {RADIUS['sm']}px; + font-family: {FONTS['code']}; font-size: {FONT_SIZES['sm']}px; }} +pre {{ background: rgba(24,24,30,0.95); border: 1px solid {COLORS['border_default'].name()}; border-radius: {RADIUS['md']}px; + padding: 10px 12px; margin: 8px 0; }} +pre code {{ background: transparent; padding: 0; color: {COLORS['text_secondary']}; }} +a {{ color: #818cf8; text-decoration: none; }} +a:hover {{ text-decoration: underline; }} +blockquote {{ border-left: 3px solid {COLORS['brand_500']}; margin: 8px 0 8px 0; padding: 4px 0 4px 12px; color: {COLORS['text_secondary']}; }} +table {{ border-collapse: collapse; margin: 8px 0; }} +th, td {{ border: 1px solid {COLORS['border_default'].name()}; padding: 5px 10px; }} +th {{ background: rgba(63,63,70,0.35); color: {COLORS['text_secondary']}; font-weight: {FONT_WEIGHTS['bold']}; }} +hr {{ border: none; border-top: 1px solid {COLORS['border_default'].name()}; margin: 12px 0; }} +ul, ol {{ padding-left: 22px; margin: 4px 0; }} +li {{ margin: 2px 0; }} +p {{ margin: 6px 0; }} """ @@ -531,8 +581,8 @@ 
class _StreamingBadge(QLabel): def __init__(self, parent=None): super().__init__("处理中…", parent) self.setStyleSheet( - "QLabel { background: rgba(124,58,237,0.18); color: #c4b5fd;" - " border: 1px solid rgba(124,58,237,0.35); border-radius: 9px;" + "QLabel { background: rgba(59,130,246,0.15); color: #60a5fa;" + " border: 1px solid rgba(59,130,246,0.3); border-radius: 9px;" " padding: 1px 8px; font-size: 11px; }" ) self.hide() @@ -714,7 +764,7 @@ class _TabButton(QPushButton): background: rgba(63,63,70,0.6); color: {text}; }} QPushButton:checked {{ - background: #7c3aed; color: white; + background: #3b82f6; color: white; }} """.format(muted=C["muted"], text=C["text"]) @@ -733,14 +783,23 @@ def _action_btn(label: str, color: str, icon: QIcon | None = None) -> QPushButto btn.setFixedHeight(36) btn.setStyleSheet(f""" QPushButton {{ - background: rgba(35,35,40,0.8); color: {C['text']}; - border: 1px solid {C['border'].name()}; - border-left: 3px solid {color}; - border-radius: 8px; padding: 0 14px; - font-size: 13px; font-weight: 700; text-align: left; + background: {COLORS['bg_elevated'].name()}; + color: {COLORS['text_primary']}; + border: 1px solid {COLORS['border_default'].name()}; + border-radius: {RADIUS['sm']}px; + padding: 0 {SPACING['md']}px; + font-size: {FONT_SIZES['md']}px; + font-weight: {FONT_WEIGHTS['semibold']}; + text-align: left; + }} + QPushButton:hover {{ + background: {COLORS['bg_overlay'].name()}; + }} + QPushButton:checked {{ + background: rgba(59,130,246,0.15); + border-color: {color}; + color: {color}; }} - QPushButton:hover {{ background: rgba(55,55,62,0.9); }} - QPushButton:checked {{ color: {color}; background: rgba(35,35,40,0.95); }} """) return btn @@ -752,13 +811,16 @@ class ChatPanel(QWidget): def __init__(self, agent): super().__init__() self.agent = agent + self._switch_service = getattr(agent, "ga_switch", None) + self._switch_snapshot = {} + self._switch_vm = {} + self._route_combo_updating = False # session state self._messages: 
list[dict] = [] self._session = {"id": _make_session_id(), "title": "新对话", "messages": []} self._history: list[dict] = _load_history() self._pending_files: list[dict] = [] # {'name','type','raw'} - self._settings_health_checked = False # streaming state self._display_queue: Optional[_queue.Queue] = None @@ -776,32 +838,41 @@ def __init__(self, agent): Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.Tool ) self.setAttribute(Qt.WA_TranslucentBackground) + self.setMinimumSize(400, 500) self.resize(530, 700) # drag state (title bar) self._drag_pos: Optional[QPoint] = None + # resize grip + from PySide6.QtWidgets import QSizeGrip + self._resize_grip = QSizeGrip(self) + self._resize_grip.setStyleSheet("QSizeGrip { background: transparent; width: 16px; height: 16px; }") + self._build_ui() + self._refresh_switch_state() def paintEvent(self, _event): p = QPainter(self) p.setRenderHint(QPainter.Antialiasing) path = QPainterPath() path.addRoundedRect(0.5, 0.5, self.width() - 1.0, self.height() - 1.0, - 20.0, 20.0) + float(RADIUS["lg"]), float(RADIUS["lg"])) grad = QLinearGradient(0, 0, 0, self.height()) grad.setColorAt(0.0, QColor(20, 20, 28, 228)) grad.setColorAt(1.0, QColor(10, 10, 14, 242)) p.fillPath(path, grad) - p.setPen(QPen(QColor(99, 102, 241, 80), 1.0)) + p.setPen(QPen(QColor(59, 130, 246, 80), 1.0)) p.drawPath(path) def resizeEvent(self, event): - path = QPainterPath() - path.addRoundedRect(0, 0, float(self.width()), float(self.height()), - 20.0, 20.0) - self.setMask(QRegion(path.toFillPolygon().toPolygon())) + # 不使用 setMask,让 QSizeGrip 正常工作 super().resizeEvent(event) + # Position resize grip at bottom-right + if hasattr(self, '_resize_grip'): + self._resize_grip.move(self.width() - 16, self.height() - 16) + self._resize_grip.raise_() + self._resize_grip.show() # ── UI construction ─────────────────────────────────────────────────────── def _build_ui(self): @@ -819,7 +890,8 @@ def _build_ui(self): self._stack.addWidget(self._build_chat_page()) # 0 
self._stack.addWidget(self._build_history_page()) # 1 self._stack.addWidget(self._build_sop_page()) # 2 - self._stack.addWidget(self._build_settings_page())# 3 + self._stack.addWidget(self._build_route_center_page()) # 3 + self._stack.addWidget(self._build_settings_page())# 4 root.addWidget(self._stack) # Now that _stack exists, activate the first tab @@ -856,12 +928,23 @@ def _build_titlebar(self) -> QWidget: ly.addStretch() + # Minimize button + minimize = QPushButton("−") + minimize.setFixedSize(26, 26) + minimize.setStyleSheet(""" + QPushButton { background: rgba(63,63,70,0.6); color: #a1a1aa; + border: none; border-radius: 6px; font-size: 18px; font-weight: bold; } + QPushButton:hover { background: rgba(63,63,70,0.9); color: white; } + """) + minimize.clicked.connect(self.showMinimized) + ly.addWidget(minimize) + # Close button close = QPushButton("×") close.setFixedSize(26, 26) close.setStyleSheet(""" QPushButton { background: rgba(63,63,70,0.6); color: #a1a1aa; - border: none; border-radius: 13px; font-size: 15px; font-weight: bold; } + border: none; border-radius: 6px; font-size: 15px; font-weight: bold; } QPushButton:hover { background: rgba(220,38,38,0.85); color: white; } """) close.clicked.connect(self.hide) @@ -898,7 +981,8 @@ def _build_tabbar(self) -> QWidget: tab_defs = [ (_SVG_CHAT, "对话"), (_SVG_CLOCK, "历史"), - (_SVG_BOOK, "SOP"), + (_SVG_BOOK, "手册"), + (_SVG_ROUTE, "路由"), (_SVG_GEAR, "设置"), ] for i, (svg, text) in enumerate(tab_defs): @@ -916,10 +1000,10 @@ def _build_tabbar(self) -> QWidget: new_btn.setIconSize(QSize(12, 12)) new_btn.setFixedHeight(27) new_btn.setStyleSheet(f""" - QPushButton {{ background: rgba(124,58,237,0.18); color: #a78bfa; - border: 1px solid rgba(124,58,237,0.3); border-radius: 7px; - padding: 0 10px; font-size: 12px; font-weight: 700; }} - QPushButton:hover {{ background: rgba(124,58,237,0.35); color: white; }} + QPushButton {{ background: rgba(59,130,246,0.15); color: #60a5fa; + border: 1px solid rgba(59,130,246,0.3); 
border-radius: {RADIUS['md']}px; + padding: 0 10px; font-size: {FONT_SIZES['sm']}px; font-weight: {FONT_WEIGHTS['bold']}; }} + QPushButton:hover {{ background: rgba(59,130,246,0.25); color: white; }} """) new_btn.clicked.connect(self._new_session) ly.addWidget(new_btn) @@ -936,10 +1020,180 @@ def _switch_tab(self, idx: int): if idx == 2: self._refresh_sop() if idx == 3: - self._refresh_model_rows_style() - if not self._settings_health_checked: - self._start_health_checks() - self._settings_health_checked = True + self._route_center.refresh_snapshot() + if idx == 4: + self._refresh_session_controls() + + def _build_route_center_page(self) -> QWidget: + self._route_center = RouteCenterPage(self.agent, self._switch_service or self.agent.ga_switch) + self._route_center.runtime_changed.connect(self._on_route_center_runtime_changed) + self._route_center.request_chat_focus.connect(lambda: self._switch_tab(0)) + return self._route_center + + def _build_chat_route_bar(self) -> QWidget: + wrap = QWidget() + wrap.setStyleSheet("background: rgba(10,10,14,0.52);") + ly = QHBoxLayout(wrap) + ly.setContentsMargins(18, 12, 18, 12) + ly.setSpacing(12) + + info = QVBoxLayout() + info.setSpacing(3) + self._chat_route_title = QLabel("当前模型") + self._chat_route_title.setStyleSheet("color: #f4f4f5; font-size: 13px; font-weight: 700;") + self._chat_route_meta = QLabel("继续使用当前模型") + self._chat_route_meta.setStyleSheet("color: #a1a1aa; font-size: 12px;") + self._chat_route_error = QLabel("") + self._chat_route_error.setWordWrap(True) + self._chat_route_error.setStyleSheet("color: #fca5a5; font-size: 11px;") + info.addWidget(self._chat_route_title) + info.addWidget(self._chat_route_meta) + info.addWidget(self._chat_route_error) + ly.addLayout(info, 1) + + controls = QVBoxLayout() + controls.setSpacing(6) + self._chat_route_select = QComboBox() + self._chat_route_select.setStyleSheet(""" + QComboBox { + background: rgba(24,24,30,0.92); + color: #f4f4f5; + border: 1px solid #3f3f46; + 
border-radius: 8px; + padding: 6px 10px; + min-width: 230px; + } + QComboBox::drop-down { border: none; width: 22px; } + QComboBox QAbstractItemView { + background: #111217; + color: #e4e4e7; + border: 1px solid #3f3f46; + selection-background-color: rgba(59,130,246,0.25); + } + """) + self._chat_route_select.currentIndexChanged.connect(self._on_chat_route_selected) + controls.addWidget(self._chat_route_select) + + route_btn_row = QHBoxLayout() + route_btn_row.setSpacing(6) + self._chat_route_member = _Badge("member") + self._chat_route_member.hide() + route_btn_row.addWidget(self._chat_route_member) + route_btn_row.addStretch() + open_route_btn = QPushButton("打开路由页") + open_route_btn.setCursor(QCursor(Qt.PointingHandCursor)) + open_route_btn.setStyleSheet(f""" + QPushButton {{ background: rgba(59,130,246,0.15); color: #60a5fa; + border: 1px solid rgba(59,130,246,0.3); border-radius: {RADIUS['lg']}px; + padding: 5px 10px; font-size: {FONT_SIZES['sm']}px; font-weight: {FONT_WEIGHTS['bold']}; }} + QPushButton:hover {{ background: rgba(59,130,246,0.25); color: white; }} + """) + open_route_btn.clicked.connect(lambda: self._switch_tab(3)) + route_btn_row.addWidget(open_route_btn) + controls.addLayout(route_btn_row) + ly.addLayout(controls, 0) + return wrap + + def _available_route_targets(self): + routes = self._switch_snapshot.get("routes") or [] + if routes and self.agent.config_source == "store": + return [ + { + "value": route["id"], + "label": f"{route['name']} | {((route.get('provider') or {}).get('name') or 'failover')}", + } + for route in routes + ] + described = self.agent.describe_llms() if hasattr(self.agent, "describe_llms") else [] + return [ + { + "value": item["idx"], + "label": item["display_name"], + } + for item in described + ] + + def _refresh_switch_state(self): + if self._switch_service is not None: + self._switch_snapshot = self._switch_service.get_ui_snapshot(self.agent) + try: + from ga_switch.viewmodel import build_ui_viewmodel + 
self._switch_vm = build_ui_viewmodel(self._switch_snapshot) + except Exception: + self._switch_vm = {} + else: + self._switch_snapshot = {} + self._switch_vm = {} + self._refresh_route_badges() + self._refresh_chat_route_controls() + if hasattr(self, "_session_runtime_info"): + self._refresh_session_controls() + + def _refresh_route_badges(self): + summary = (self._switch_vm or {}).get("summary") or {} + headline = summary.get("headline") or self._model_name() + meta = summary.get("meta") or "继续使用当前模型" + self._model_badge.setText(headline) + self._model_badge.setToolTip(meta) + if hasattr(self, "_chat_route_title"): + self._chat_route_title.setText(headline) + self._chat_route_meta.setText(meta) + has_error = bool(summary.get("last_error_message")) + err = summary.get("last_error_message") or "最近没有错误" + self._chat_route_error.setText(err) + self._chat_route_error.setStyleSheet( + "color: #fca5a5; font-size: 11px;" if has_error else "color: #71717a; font-size: 11px;" + ) + active_member = summary.get("active_member_name") + if active_member: + self._chat_route_member.setText(active_member) + self._chat_route_member.show() + else: + self._chat_route_member.hide() + + def _refresh_chat_route_controls(self): + if not hasattr(self, "_chat_route_select"): + return + options = self._available_route_targets() + summary = (self._switch_vm or {}).get("summary") or {} + current_value = summary.get("route_id") + if current_value is None: + active_runtime = next((item for item in self.agent.describe_llms() if item.get("active")), None) if hasattr(self.agent, "describe_llms") else None + current_value = (active_runtime or {}).get("idx") + + self._route_combo_updating = True + self._chat_route_select.blockSignals(True) + self._chat_route_select.clear() + selected_idx = 0 + for idx, item in enumerate(options): + self._chat_route_select.addItem(item["label"], item["value"]) + if item["value"] == current_value: + selected_idx = idx + if options: + 
self._chat_route_select.setCurrentIndex(selected_idx) + self._chat_route_select.setEnabled(bool(options)) + self._chat_route_select.blockSignals(False) + self._route_combo_updating = False + + def _switch_route_target(self, target): + if target is None: + return + try: + self.agent.set_active_route(target) + except Exception as exc: + self._add_system_notice(f"切换路由失败:{exc}") + return + self._refresh_switch_state() + summary = (self._switch_vm or {}).get("summary") or {} + self._add_system_notice(f"已切换至 {summary.get('headline') or target},对话上下文已保留") + + def _on_chat_route_selected(self, _idx): + if self._route_combo_updating: + return + self._switch_route_target(self._chat_route_select.currentData()) + + def _on_route_center_runtime_changed(self, _snapshot, _viewmodel): + self._refresh_switch_state() # ── chat page ───────────────────────────────────────────────────────────── def _build_chat_page(self) -> QWidget: @@ -949,6 +1203,9 @@ def _build_chat_page(self) -> QWidget: ly.setContentsMargins(0, 0, 0, 0) ly.setSpacing(0) + ly.addWidget(self._build_chat_route_bar()) + ly.addWidget(_Separator()) + # ── message scroll area ── self._scroll = QScrollArea() self._scroll.setWidgetResizable(True) @@ -996,7 +1253,7 @@ def _build_input_area(self) -> QWidget: border-radius: 16px; }} QWidget#inputCard:focus-within {{ - border-color: rgba(124,58,237,0.55); + border-color: rgba(59,130,246,0.55); }} """) card.setObjectName("inputCard") @@ -1011,7 +1268,7 @@ def _build_input_area(self) -> QWidget: QTextEdit {{ background: transparent; color: {C['text']}; border: none; padding: 0; font-size: 14px; - selection-background-color: rgba(124,58,237,0.4); + selection-background-color: rgba(59,130,246,0.25); }} """) self._input.installEventFilter(self) @@ -1090,9 +1347,9 @@ def _build_history_page(self) -> QWidget: padding: 8px 12px; margin: 2px 0; }} QListWidget::item:hover {{ background: rgba(55,55,65,0.8); - border-color: rgba(124,58,237,0.4); }} - QListWidget::item:selected {{ 
background: rgba(124,58,237,0.25); - border-color: rgba(124,58,237,0.6); }} + border-color: rgba(59,130,246,0.4); }} + QListWidget::item:selected {{ background: rgba(59,130,246,0.25); + border-color: rgba(59,130,246,0.6); }} {SCROLLBAR_STYLE} """) self._hist_list.itemDoubleClicked.connect(self._restore_selected) @@ -1116,7 +1373,7 @@ def _build_sop_page(self) -> QWidget: QListWidget::item {{ color: {C['muted']}; padding: 7px 10px; border-radius: 4px; margin: 1px 4px; }} QListWidget::item:hover {{ background: rgba(55,55,65,0.7); color: {C['text']}; }} - QListWidget::item:selected {{ background: rgba(124,58,237,0.28); color: white; }} + QListWidget::item:selected {{ background: rgba(59,130,246,0.25); color: white; }} {SCROLLBAR_STYLE} """) self._sop_list.currentItemChanged.connect(self._load_sop) @@ -1143,31 +1400,25 @@ def _build_settings_page(self) -> QWidget: page.setStyleSheet("background: transparent;") ly = QVBoxLayout(page) ly.setContentsMargins(16, 16, 16, 16) - ly.setSpacing(8) + ly.setSpacing(10) - lbl = QLabel("控制面板") + lbl = QLabel("会话控制") lbl.setStyleSheet("color: #f4f4f5; font-weight: 600; font-size: 14px;") ly.addWidget(lbl) - self._model_info = QLabel(f"当前模型:{self._model_name()} (#{self.agent.llm_no})") - self._model_info.setStyleSheet(f"color: {C['muted']}; font-size: 12px;") - ly.addWidget(self._model_info) - ly.addSpacing(4) - - model_hdr = QLabel("模型列表") - model_hdr.setStyleSheet("color: #d4d4d8; font-weight: 600; font-size: 13px;") - ly.addWidget(model_hdr) + self._session_runtime_info = QLabel("") + self._session_runtime_info.setWordWrap(True) + self._session_runtime_info.setStyleSheet(f"color: {C['muted']}; font-size: 12px;") + ly.addWidget(self._session_runtime_info) - self._model_rows_container = QWidget() - self._model_rows_container.setStyleSheet("background: transparent;") - self._model_rows_layout = QVBoxLayout(self._model_rows_container) - self._model_rows_layout.setContentsMargins(0, 0, 0, 0) - self._model_rows_layout.setSpacing(3) - 
ly.addWidget(self._model_rows_container) + hint = QLabel("路由切换、模型服务编辑、诊断和连通性测试都在“路由”页中。这里保留会话级操作。") + hint.setWordWrap(True) + hint.setStyleSheet("color: #a1a1aa; font-size: 12px;") + ly.addWidget(hint) - self._model_row_widgets: list[dict] = [] - self._health_results: dict[int, bool | None] = {} - self._build_model_rows() + open_route_btn = _action_btn("打开路由页", "#3b82f6", _svg_icon("route", _SVG_ROUTE)) + open_route_btn.clicked.connect(lambda: self._switch_tab(3)) + ly.addWidget(open_route_btn) ly.addSpacing(6) @@ -1185,7 +1436,7 @@ def _build_settings_page(self) -> QWidget: sep.setStyleSheet("color: #f4f4f5; font-weight: 600; font-size: 13px;") ly.addWidget(sep) - self._auto_btn = _action_btn("开启自主行动 (idle > 30 min 自动触发)", "#f59e0b", + self._auto_btn = _action_btn("开启自主行动(空闲超过 30 分钟后自动触发)", "#f59e0b", _svg_icon("bolt", _SVG_BOLT)) self._auto_btn.setCheckable(True) self._auto_btn.clicked.connect(self._do_toggle_auto) @@ -1199,6 +1450,16 @@ def _build_settings_page(self) -> QWidget: ly.addStretch() return page + def _refresh_session_controls(self): + summary = (self._switch_vm or {}).get("summary") or {} + headline = summary.get("headline") or self._model_name() + meta = summary.get("meta") or "继续使用当前模型" + member = summary.get("active_member_name") or "无" + last_error = summary.get("last_error_message") or "最近没有错误" + self._session_runtime_info.setText( + f"当前路由:{headline}\n{meta}\n当前成员:{member}\n最近错误:{last_error}" + ) + # ── model list ──────────────────────────────────────────────────────────── _MODEL_ROW_STYLE = ( "QPushButton { background: rgba(39,39,42,0.7); color: #e4e4e7;" @@ -1207,10 +1468,10 @@ def _build_settings_page(self) -> QWidget: " QPushButton:hover { background: rgba(63,63,70,0.8); }" ) _MODEL_ROW_ACTIVE = ( - "QPushButton { background: rgba(124,58,237,0.25); color: #c4b5fd;" - " border: 1px solid rgba(124,58,237,0.5); border-radius: 8px;" + "QPushButton { background: rgba(59,130,246,0.2); color: #60a5fa;" + " border: 1px solid rgba(59,130,246,0.4); 
border-radius: 8px;" " padding: 6px 10px; font-size: 12px; font-weight: 700; text-align: left; }" - " QPushButton:hover { background: rgba(124,58,237,0.35); }" + " QPushButton:hover { background: rgba(59,130,246,0.3); }" ) def _build_model_rows(self): @@ -1264,11 +1525,13 @@ def _do_switch_to(self, idx: int): if idx == self.agent.llm_no: return self.agent.next_llm(n=idx) - name = self._model_name() - self._model_badge.setText(name) - self._model_info.setText(f"当前模型:{name} (#{self.agent.llm_no})") + self._refresh_switch_state() + name = (self._switch_vm or {}).get("summary", {}).get("headline") or self._model_name() + if hasattr(self, "_model_info"): + self._model_info.setText(f"当前模型:{name} (#{self.agent.llm_no})") self._add_system_notice(f"已切换至 {name},对话上下文已保留") - self._refresh_model_rows_style() + if hasattr(self, "_model_row_widgets"): + self._refresh_model_rows_style() def _start_health_checks(self): self._health_results.clear() @@ -1322,9 +1585,9 @@ def _on_text_changed(self): def _attach_files(self): paths, _ = QFileDialog.getOpenFileNames( self, "选择附件", "", - "All Files (*);;" - "Images (*.png *.jpg *.jpeg *.gif *.webp *.bmp);;" - "Text (*.txt *.md *.py *.json *.csv *.yaml *.yml *.log *.js *.ts *.sql)", + "所有文件 (*);;" + "图片 (*.png *.jpg *.jpeg *.gif *.webp *.bmp);;" + "文本与代码 (*.txt *.md *.py *.json *.csv *.yaml *.yml *.log *.js *.ts *.sql)", ) for path in paths: name = os.path.basename(path) @@ -1395,6 +1658,10 @@ def _on_send_btn_click(self): self._handle_send() def _handle_send(self): + if self.agent.llmclient is None: + self._add_system_notice("当前没有可用路由,请先到“路由”页导入 mykey,或配置模型服务与路由。") + self._switch_tab(3) + return text = self._input.toPlainText().strip() files = self._pending_files.copy() if not text and not files: @@ -1453,6 +1720,7 @@ def _poll_queue(self): self._update_token_usage() self._scroll_bottom() self._auto_save() + self._refresh_switch_state() break except _queue.Empty: pass @@ -1545,7 +1813,7 @@ def _update_token_usage(self): if in_tokens == 0 
and out_tokens == 0: self._token_lbl.setText("") else: - self._token_lbl.setText(f"| 会话上下文消耗: 入 {in_tokens} 出 {out_tokens} tokens") + self._token_lbl.setText(f"| 会话上下文消耗:输入 {in_tokens},输出 {out_tokens},按词元估算") # ── SOP ──────────────────────────────────────────────────────────────────── def _refresh_sop(self): @@ -1650,7 +1918,7 @@ def _new_session(self): def _do_toggle_auto(self): self.autonomous_enabled = not self.autonomous_enabled self._auto_btn.setChecked(self.autonomous_enabled) - lbl = "暂停自主行动" if self.autonomous_enabled else "开启自主行动 (idle > 30 min 自动触发)" + lbl = "暂停自主行动" if self.autonomous_enabled else "开启自主行动(空闲超过 30 分钟后自动触发)" self._auto_btn.setText(lbl) def _do_trigger_auto(self): @@ -1693,14 +1961,13 @@ def main(): # ── Agent initialisation ────────────────────────────── agent = GeneraticAgent() + threading.Thread(target=agent.run, daemon=True).start() if agent.llmclient is None: QMessageBox.critical( None, "未配置 LLM", "未在 mykey.py 中发现任何可用的 LLM 接口配置,\n程序将在无 LLM 模式下运行。", ) - else: - threading.Thread(target=agent.run, daemon=True).start() # ── Windows ─────────────────────────────────────────── panel = ChatPanel(agent) diff --git a/frontends/shared_runtime.py b/frontends/shared_runtime.py new file mode 100644 index 0000000..e57b58e --- /dev/null +++ b/frontends/shared_runtime.py @@ -0,0 +1,29 @@ +import os +import sys +import threading + +script_dir = os.path.dirname(__file__) +repo_dir = os.path.abspath(os.path.join(script_dir, "..")) +if repo_dir not in sys.path: + sys.path.append(repo_dir) + +from agentmain import GeneraticAgent +from ga_switch import get_service + +_runtime_lock = threading.Lock() +_runtime = None +_worker_started = False + + +def get_shared_runtime(): + global _runtime, _worker_started + with _runtime_lock: + if _runtime is None: + service = get_service() + agent = GeneraticAgent() + _runtime = (service, agent) + service, agent = _runtime + if agent.llmclient is not None and not _worker_started: + threading.Thread(target=agent.run, 
daemon=True, name="ga-shared-agent").start() + _worker_started = True + return service, agent diff --git a/frontends/stapp.py b/frontends/stapp.py index 2fe9afa..dbc7bf6 100644 --- a/frontends/stapp.py +++ b/frontends/stapp.py @@ -1,197 +1,418 @@ -import os, sys, subprocess -from urllib.request import urlopen +import json +import os +import queue +import re +import subprocess +import sys +import threading +import time from urllib.parse import quote -if sys.stdout is None: sys.stdout = open(os.devnull, "w") -if sys.stderr is None: sys.stderr = open(os.devnull, "w") -try: sys.stdout.reconfigure(errors='replace') -except: pass -try: sys.stderr.reconfigure(errors='replace') -except: pass +from urllib.request import urlopen + +if sys.stdout is None: + sys.stdout = open(os.devnull, "w") +if sys.stderr is None: + sys.stderr = open(os.devnull, "w") +try: + sys.stdout.reconfigure(errors="replace") +except Exception: + pass +try: + sys.stderr.reconfigure(errors="replace") +except Exception: + pass + script_dir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(script_dir, '..'))) +repo_dir = os.path.abspath(os.path.join(script_dir, "..")) +if repo_dir not in sys.path: + sys.path.append(repo_dir) +if script_dir not in sys.path: + sys.path.append(script_dir) import streamlit as st -import time, json, re, threading, queue -from agentmain import GeneraticAgent -st.set_page_config(page_title="Cowork", layout="wide") +from shared_runtime import get_shared_runtime + +st.set_page_config(page_title="Cowork", layout="wide", initial_sidebar_state="expanded") + +service, agent = get_shared_runtime() +if agent.llmclient is None: + st.error("No LLM routes are available. 
Configure mykey.py or import structured routes first.") + st.stop() + +st.markdown( + """ + + """, + unsafe_allow_html=True, +) -@st.cache_resource -def init(): - agent = GeneraticAgent() - if agent.llmclient is None: - st.error("⚠️ 未配置任何可用的 LLM 接口,请设置mykey.py。") - st.stop() - else: threading.Thread(target=agent.run, daemon=True).start() - return agent +if "autonomous_enabled" not in st.session_state: + st.session_state.autonomous_enabled = False +if "messages" not in st.session_state: + st.session_state.messages = [] -agent = init() -st.title("🖥️ Cowork") +def get_snapshot(): + return service.get_ui_snapshot(agent) + + +def route_label(route): + if route["kind"] == "single": + provider = (route["provider"] or {}).get("name") or "No provider" + return f"{route['name']} | {provider}" + members = " -> ".join(member["name"] for member in route["members"]) or "No members" + return f"{route['name']} | {members}" + + +def activate_route(route_id): + try: + agent.set_active_route(route_id) + st.toast(f"Switched route to {route_id}") + st.rerun() + except Exception as exc: + st.error(str(exc)) + + +def render_route_banner(): + snapshot = get_snapshot() + active = snapshot["active_route_summary"] + routes = snapshot["routes"] + route_ids = [route["id"] for route in routes] + if not route_ids: + return + active_route_id = snapshot["active_route_id"] if snapshot["active_route_id"] in route_ids else route_ids[0] + route_error = active.get("last_error_message") or "No recent error" + st.markdown( + f""" +
+

{active.get('route_name') or 'No active route'}

+

{active.get('provider_name') or 'No provider'} | {active.get('model') or 'No model'} | {active.get('backend_class') or 'No backend'}

+ {active.get('route_kind') or 'unknown'} + {'native tools' if active.get('native_tools') else 'text tools'} + {active.get('active_member_name') or 'no active member'} + {active.get('last_error_kind') or 'healthy'} +
+ """, + unsafe_allow_html=True, + ) + cols = st.columns([2.2, 1.1, 1.1, 1.3]) + selected_route = cols[0].selectbox( + "Route", + options=route_ids, + index=route_ids.index(active_route_id), + format_func=lambda rid: route_label(snapshot["routes_by_id"][rid]), + key="chat_route_banner_select", + ) + if selected_route != active_route_id: + activate_route(selected_route) + cols[1].caption(f"API Mode: {active.get('api_mode') or 'n/a'}") + cols[1].caption(f"Active Member: {active.get('active_member_name') or 'n/a'}") + cols[2].caption(f"Last OK: {active.get('last_ok_at') or 'n/a'}") + cols[2].caption(f"Last Error: {active.get('last_error_at') or 'n/a'}") + if hasattr(st, "page_link"): + cols[3].page_link("pages/1_GA_Switch_Admin.py", label="Open GA Switch Admin", icon="🧭") + cols[3].caption(route_error[:120]) -if 'autonomous_enabled' not in st.session_state: st.session_state.autonomous_enabled = False @st.fragment def render_sidebar(): - current_idx = agent.llm_no - st.caption(f"LLM Core: {current_idx}: {agent.get_llm_name()}", help="点击切换备用链路") - last_reply_time = st.session_state.get('last_reply_time', 0) - if last_reply_time > 0: - st.caption(f"空闲时间:{int(time.time()) - last_reply_time}秒", help="当超过30分钟未收到回复时,系统会自动任务") - if st.button("切换备用链路"): - agent.next_llm(); st.rerun(scope="fragment") - if st.button("强行停止任务"): - agent.abort(); st.toast("已发送停止信号"); st.rerun() - if st.button("重新注入工具"): - agent.llmclient.last_tools = '' + snapshot = get_snapshot() + active = snapshot["active_route_summary"] + route_ids = [route["id"] for route in snapshot["routes"]] + st.caption(f"Current route: {active.get('route_name') or 'n/a'}") + if route_ids: + selected_route = st.selectbox( + "Switch route", + options=route_ids, + index=route_ids.index(snapshot["active_route_id"]) if snapshot["active_route_id"] in route_ids else 0, + format_func=lambda rid: route_label(snapshot["routes_by_id"][rid]), + key="chat_route_sidebar_select", + ) + if selected_route != snapshot["active_route_id"]: + 
activate_route(selected_route) + st.caption(f"Provider: {active.get('provider_name') or 'n/a'}") + st.caption(f"Model: {active.get('model') or 'n/a'}") + st.caption(f"Backend: {active.get('backend_class') or 'n/a'}") + st.caption(f"Active member: {active.get('active_member_name') or 'n/a'}") + if active.get("last_error_message"): + st.warning(active["last_error_message"][:160]) + if hasattr(st, "page_link"): + st.page_link("pages/1_GA_Switch_Admin.py", label="Open GA Switch Admin", icon="🧭") + if st.button("Abort current task", use_container_width=True): + agent.abort() + st.toast("Abort signal sent") + st.rerun() + if st.button("Re-inject tool schema", use_container_width=True): + if hasattr(agent.llmclient, "last_tools"): + agent.llmclient.last_tools = "" try: - hist_path = os.path.join(script_dir, '..', 'assets', 'tool_usable_history.json') - with open(hist_path, 'r', encoding='utf-8') as f: tool_hist = json.load(f) + hist_path = os.path.join(script_dir, "..", "assets", "tool_usable_history.json") + with open(hist_path, "r", encoding="utf-8") as f: + tool_hist = json.load(f) agent.llmclient.backend.history.extend(tool_hist) - st.toast(f"已重新注入工具,追加了 {len(tool_hist)} 条示范记录") - except Exception as e: st.toast(f"注入工具示范失败: {e}") - if st.button("🐱 桌面宠物"): - kwargs = {'creationflags': 0x08} if sys.platform == 'win32' else {} - pet_script = os.path.join(script_dir, 'desktop_pet_v2.pyw') - if not os.path.exists(pet_script): pet_script = os.path.join(script_dir, 'desktop_pet.pyw') + st.toast(f"Injected {len(tool_hist)} tool history entries") + except Exception as exc: + st.toast(f"Tool history injection failed: {exc}") + if st.button("Desktop pet", use_container_width=True): + kwargs = {"creationflags": 0x08} if sys.platform == "win32" else {} + pet_script = os.path.join(script_dir, "desktop_pet_v2.pyw") + if not os.path.exists(pet_script): + pet_script = os.path.join(script_dir, "desktop_pet.pyw") subprocess.Popen([sys.executable, pet_script], **kwargs) - def 
_pet_req(q): + + def pet_request(query): def _do(): - try: urlopen(f'http://127.0.0.1:41983/?{q}', timeout=2) - except Exception: pass + try: + urlopen(f"http://127.0.0.1:41983/?{query}", timeout=2) + except Exception: + pass + threading.Thread(target=_do, daemon=True).start() - agent._pet_req = _pet_req - if not hasattr(agent, '_turn_end_hooks'): agent._turn_end_hooks = {} - def _pet_hook(ctx): - parts = [f"Turn {ctx.get('turn','?')}"] - if ctx.get('summary'): parts.append(ctx['summary']) - if ctx.get('exit_reason'): parts.append('任务已完成') - _pet_req(f'msg={quote(chr(10).join(parts))}') - if ctx.get('exit_reason'): _pet_req('state=idle') - agent._turn_end_hooks['pet'] = _pet_hook - st.toast("桌面宠物已启动") - + + agent._pet_req = pet_request + if not hasattr(agent, "_turn_end_hooks"): + agent._turn_end_hooks = {} + + def pet_hook(ctx): + parts = [f"Turn {ctx.get('turn', '?')}"] + if ctx.get("summary"): + parts.append(ctx["summary"]) + if ctx.get("exit_reason"): + parts.append("Task completed") + pet_request(f"msg={quote(chr(10).join(parts))}") + if ctx.get("exit_reason"): + pet_request("state=idle") + + agent._turn_end_hooks["pet"] = pet_hook + st.toast("Desktop pet started") st.divider() - if st.button("开始空闲自主行动"): + if st.button("Start autonomous idle mode", use_container_width=True): st.session_state.last_reply_time = int(time.time()) - 1800 - st.toast("已将上次回复时间设为1800秒前"); st.rerun() + st.toast("Idle timer moved back by 1800 seconds") + st.rerun() if st.session_state.autonomous_enabled: - if st.button("⏸️ 禁止自主行动"): + if st.button("Disable autonomous mode", use_container_width=True): st.session_state.autonomous_enabled = False - st.toast("⏸️ 已禁止自主行动"); st.rerun() - st.caption("🟢 自主行动运行中,会在你离开它30分钟后自动进行") + st.toast("Autonomous mode disabled") + st.rerun() + st.caption("Autonomous mode will trigger after 30 minutes of inactivity.") else: - if st.button("▶️ 允许自主行动", type="primary"): + if st.button("Enable autonomous mode", type="primary", use_container_width=True): 
st.session_state.autonomous_enabled = True - st.toast("✅ 已允许自主行动"); st.rerun() - st.caption("🔴 自主行动已停止") -with st.sidebar: render_sidebar() + st.toast("Autonomous mode enabled") + st.rerun() + st.caption("Autonomous mode is currently disabled.") + + +with st.sidebar: + render_sidebar() + +st.title("Cowork") +render_route_banner() + def fold_turns(text): - """Return list of segments: [{'type':'text','content':...}, {'type':'fold','title':...,'content':...}]""" - parts = re.split(r'(\**LLM Running \(Turn \d+\) \.\.\.\*\**)', text) - if len(parts) < 4: return [{'type': 'text', 'content': text}] + parts = re.split(r"(\**LLM Running \(Turn \d+\) \.\.\.\*\**)", text) + if len(parts) < 4: + return [{"type": "text", "content": text}] segments = [] - if parts[0].strip(): segments.append({'type': 'text', 'content': parts[0]}) + if parts[0].strip(): + segments.append({"type": "text", "content": parts[0]}) turns = [] - for i in range(1, len(parts), 2): - marker = parts[i] - content = parts[i+1] if i+1 < len(parts) else '' + for idx in range(1, len(parts), 2): + marker = parts[idx] + content = parts[idx + 1] if idx + 1 < len(parts) else "" turns.append((marker, content)) for idx, (marker, content) in enumerate(turns): if idx < len(turns) - 1: - _c = re.sub(r'```.*?```|.*?', '', content, flags=re.DOTALL) - matches = re.findall(r'\s*((?:(?!).)*?)\s*', _c, re.DOTALL) + content_no_think = re.sub(r"```.*?```|.*?", "", content, flags=re.DOTALL) + matches = re.findall(r"\s*((?:(?!).)*?)\s*", content_no_think, re.DOTALL) if matches: - title = matches[0].strip() - title = title.split('\n')[0] - if len(title) > 50: title = title[:50] + '...' - else: title = marker.strip('*') - segments.append({'type': 'fold', 'title': title, 'content': content}) - else: segments.append({'type': 'text', 'content': marker + content}) + title = matches[0].strip().split("\n")[0] + if len(title) > 50: + title = title[:50] + "..." 
+ else: + title = marker.strip("*") + segments.append({"type": "fold", "title": title, "content": content}) + else: + segments.append({"type": "text", "content": marker + content}) return segments -def render_segments(segments, suffix=''): - # 整块重画:调用方用 slot.container() 包裹,保证 DOM 路径稳定、跨 rerun 对齐(消除"灰色重影")。 - # heartbeat 空转时 segments 不变 → Streamlit 后端 diff 无变化 → 前端零闪烁; - # 但 container/markdown 本身是 API 调用,StopException 仍会被抛出(abort 照常起作用)。 - for seg in segments: - if seg['type'] == 'fold': - with st.expander(seg['title'], expanded=False): st.markdown(seg['content']) + + +def render_segments(segments, suffix=""): + for segment in segments: + if segment["type"] == "fold": + with st.expander(segment["title"], expanded=False): + st.markdown(segment["content"]) else: - st.markdown(seg['content'] + suffix) + st.markdown(segment["content"] + suffix) + def agent_backend_stream(prompt): display_queue = agent.put_task(prompt, source="user") - response = '' + response = "" try: while True: - try: item = display_queue.get(timeout=1) + try: + item = display_queue.get(timeout=1) except queue.Empty: - yield response # heartbeat: let outer st.markdown() run → Streamlit checks StopException + yield response continue - if 'next' in item: - response = item['next']; yield response - if 'done' in item: - yield item['done']; break - finally: agent.abort() + if "next" in item: + response = item["next"] + yield response + if "done" in item: + yield item["done"] + break + finally: + agent.abort() + -if "messages" not in st.session_state: st.session_state.messages = [] for msg in st.session_state.messages: with st.chat_message(msg["role"]): - # 用 slot=st.empty() + with slot.container(): ... 
的外壳,DOM 路径和流式渲染完全一致,跨 rerun 对齐 slot = st.empty() with slot.container(): - if msg["role"] == "assistant": render_segments(fold_turns(msg["content"])) - else: st.markdown(msg["content"]) + if msg["role"] == "assistant": + render_segments(fold_turns(msg["content"])) + else: + st.markdown(msg["content"]) -# Scroll-height ghost fix: during streaming, expander open/close mid-animation can leave -# phantom height → scrollbar long but can't scroll to bottom. Periodically detect & reflow. try: - from streamlit import iframe as _st_iframe # 1.56+ + from streamlit import iframe as _st_iframe + _embed_html = lambda html, **kw: _st_iframe(html, **{k: max(v, 1) if isinstance(v, int) else v for k, v in kw.items()}) except (ImportError, AttributeError): - from streamlit.components.v1 import html as _embed_html # ≤1.55 -_js_scroll_fix = ("!function(){var p=window.parent;if(p.__sfx)return;p.__sfx=1;" + from streamlit.components.v1 import html as _embed_html + +_js_scroll_fix = ( + "!function(){var p=window.parent;if(p.__sfx)return;p.__sfx=1;" "var d=p.document;setInterval(function(){" "var m=d.querySelector('section.main');if(!m)return;" "var b=m.querySelector('.block-container');if(!b)return;" - "if(m.scrollHeight>b.scrollHeight+150){" - "m.style.overflow='hidden';void m.offsetHeight;m.style.overflow=''}" - "},3000)}()") -# IME composition fix (macOS only) - prevents Enter from submitting during CJK input -_js_ime_fix = ("" if os.name == 'nt' else - "!function(){if(window.parent.__imeFix)return;window.parent.__imeFix=1;" + "if(m.scrollHeight>b.scrollHeight+150){m.style.overflow='hidden';void m.offsetHeight;m.style.overflow=''}" + "},3000)}()" +) +_js_ime_fix = ( + "" + if os.name == "nt" + else "!function(){if(window.parent.__imeFix)return;window.parent.__imeFix=1;" "var d=window.parent.document,c=0;" "d.addEventListener('compositionstart',()=>c=1,!0);" "d.addEventListener('compositionend',()=>c=0,!0);" - "function 
f(){d.querySelectorAll('textarea[data-testid=stChatInputTextArea]')" - ".forEach(t=>{t.__imeFix||(t.__imeFix=1,t.addEventListener('keydown',e=>{" - "e.key==='Enter'&&!e.shiftKey&&(e.isComposing||c||e.keyCode===229)&&" - "(e.stopImmediatePropagation(),e.preventDefault())},!0))})}" - "f();new MutationObserver(f).observe(d.body,{childList:1,subtree:1})}()") -_embed_html(f'', height=0) - -if prompt := st.chat_input("any task?"): + "function f(){d.querySelectorAll('textarea[data-testid=stChatInputTextArea]').forEach(function(t){" + "if(t.__imeFix)return;t.__imeFix=1;t.addEventListener('keydown',function(e){" + "if(e.key==='Enter'&&!e.shiftKey&&(e.isComposing||c||e.keyCode===229)){e.stopImmediatePropagation();e.preventDefault();}" + "},!0);});}" + "f();new MutationObserver(f).observe(d.body,{childList:1,subtree:1})}()" +) +_embed_html(f"", height=0) + +if prompt := st.chat_input("Any task?"): st.session_state.messages.append({"role": "user", "content": prompt}) - if hasattr(agent, '_pet_req') and not prompt.startswith('/'): agent._pet_req('state=walk') - with st.chat_message("user"): st.markdown(prompt) + if hasattr(agent, "_pet_req") and not prompt.startswith("/"): + agent._pet_req("state=walk") + with st.chat_message("user"): + st.markdown(prompt) with st.chat_message("assistant"): - frozen = 0; live = st.empty(); response = '' - CURSOR = ' ▌' + frozen = 0 + live = st.empty() + response = "" + cursor = " ▍" for response in agent_backend_stream(prompt): - segs = fold_turns(response) - n_done = max(0, len(segs) - 1) - while frozen < n_done: - with live.container(): render_segments([segs[frozen]]) - live = st.empty(); frozen += 1 - with live.container(): render_segments([segs[-1]], suffix=CURSOR) # live 区域 - segs = fold_turns(response) - for i in range(frozen, len(segs)): - with live.container(): render_segments([segs[i]]) - if i < len(segs) - 1: live = st.empty() + segments = fold_turns(response) + completed = max(0, len(segments) - 1) + while frozen < completed: + with 
live.container(): + render_segments([segments[frozen]]) + live = st.empty() + frozen += 1 + with live.container(): + render_segments([segments[-1]], suffix=cursor) + segments = fold_turns(response) + for idx in range(frozen, len(segments)): + with live.container(): + render_segments([segments[idx]]) + if idx < len(segments) - 1: + live = st.empty() st.session_state.messages.append({"role": "assistant", "content": response}) st.session_state.last_reply_time = int(time.time()) if st.session_state.autonomous_enabled: - st.markdown(f"""""", unsafe_allow_html=True) + st.markdown( + f"""""", + unsafe_allow_html=True, + ) diff --git a/frontends/stapp2.py b/frontends/stapp2.py index 1d7968f..b9770eb 100644 --- a/frontends/stapp2.py +++ b/frontends/stapp2.py @@ -979,7 +979,7 @@ def render_sidebar(): st.rerun() st.divider() if st.button("重新注入System Prompt"): - agent.llmclient.last_tools = '' + if hasattr(agent.llmclient, 'last_tools'): agent.llmclient.last_tools = '' st.toast("下次将重新注入System Prompt") with st.sidebar: render_sidebar() diff --git a/ga_switch/__init__.py b/ga_switch/__init__.py new file mode 100644 index 0000000..7d81462 --- /dev/null +++ b/ga_switch/__init__.py @@ -0,0 +1,14 @@ +import os +from functools import lru_cache + + +def get_default_db_path(): + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(root_dir, ".ga-switch", "ga-switch.db") + + +@lru_cache(maxsize=8) +def get_service(db_path=None): + from .service import GASwitchService + + return GASwitchService(db_path=db_path or get_default_db_path()) diff --git a/ga_switch/diagnostics.py b/ga_switch/diagnostics.py new file mode 100644 index 0000000..42fb5a9 --- /dev/null +++ b/ga_switch/diagnostics.py @@ -0,0 +1,48 @@ +from datetime import datetime, timezone + + +ERROR_KINDS = ( + "auth", + "quota", + "rate_limit", + "timeout", + "network", + "server", + "bad_request", + "model_not_found", + "unsupported_param", + "unknown", +) + + +def utcnow_iso() -> str: + 
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def normalize_message(message, limit=2000) -> str: + text = "" if message is None else str(message).strip() + return text[:limit] + + +def classify_error(*, status_code=None, message="", body="", exc_type="") -> str: + status = None if status_code is None else int(status_code) + hay = " ".join(x for x in (str(message or ""), str(body or ""), str(exc_type or "")) if x).lower() + + if status in (401, 403): + return "auth" + if status == 404 or "model_not_found" in hay or "model not found" in hay or "no such model" in hay: + return "model_not_found" + if "unsupported_param" in hay or ("unsupported" in hay and any(k in hay for k in ("param", "reasoning_effort", "reasoning.effort", "api_mode"))): + return "unsupported_param" + if status == 400: + return "unsupported_param" if "unsupported" in hay else "bad_request" + if status == 429: + quota_tokens = ("insufficient_quota", "quota", "credit", "billing", "余额", "配额") + return "quota" if any(token in hay for token in quota_tokens) else "rate_limit" + if status is not None and status >= 500: + return "server" + if any(token in hay for token in ("timeout", "timed out", "readtimeout", "connecttimeout")): + return "timeout" + if any(token in hay for token in ("connectionerror", "proxyerror", "sslerror", "name or service not known", "connection reset", "dns", "proxy", "connection refused")): + return "network" + return "unknown" diff --git a/ga_switch/models.py b/ga_switch/models.py new file mode 100644 index 0000000..3f23907 --- /dev/null +++ b/ga_switch/models.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass, field +from typing import Any + + +PROVIDER_BACKEND_KINDS = ( + "native_claude", + "native_oai", + "claude_text", + "oai_text", +) + +ROUTE_KINDS = ("single", "failover") + + +def is_native_backend_kind(kind: str) -> bool: + return str(kind or "").startswith("native_") + + +def backend_family(kind: str) -> str: + kind = 
str(kind or "") + if "claude" in kind: + return "claude" + return "oai" + + +@dataclass +class ProviderModel: + id: int | None = None + name: str = "" + backend_kind: str = "oai_text" + apikey: str = "" + apibase: str = "" + model: str = "" + api_mode: str = "chat_completions" + temperature: float = 1.0 + max_tokens: int = 8192 + context_win: int = 24000 + proxy: str | None = None + timeout: int = 5 + read_timeout: int = 30 + max_retries: int = 1 + reasoning_effort: str | None = None + thinking_type: str | None = None + thinking_budget_tokens: int | None = None + stream: bool = True + is_enabled: bool = True + extra: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RouteModel: + id: int | None = None + name: str = "" + kind: str = "single" + provider_id: int | None = None + member_provider_ids: list[int] = field(default_factory=list) + is_enabled: bool = True + is_default: bool = False + config: dict[str, Any] = field(default_factory=dict) diff --git a/ga_switch/service.py b/ga_switch/service.py new file mode 100644 index 0000000..3430674 --- /dev/null +++ b/ga_switch/service.py @@ -0,0 +1,401 @@ +import importlib.util +import json +import os + +from .models import PROVIDER_BACKEND_KINDS, is_native_backend_kind +from .store import GASwitchStore +from .testing import ModelTester + + +class GASwitchService: + def __init__(self, db_path): + self.db_path = db_path + self.store = GASwitchStore(db_path) + self.tester = ModelTester(self) + + def list_providers(self): + return self.store.list_providers(enabled_only=False) + + def upsert_provider(self, provider): + return self.store.upsert_provider(provider) + + def delete_provider(self, provider_id): + return self.store.delete_provider(provider_id) + + def list_routes(self): + return self.store.list_routes(enabled_only=False) + + def upsert_route(self, route): + return self.store.upsert_route(route) + + def delete_route(self, route_id): + return self.store.delete_route(route_id) + + def 
set_active_route(self, route_id): + return self.store.set_active_route(route_id) + + def set_structured_config_enabled(self, enabled): + self.store.set_setting("use_structured_config", bool(enabled)) + + def use_structured_config(self): + return bool(self.store.get_setting("use_structured_config", False)) + + def has_usable_routes(self): + return any(route["is_enabled"] for route in self.store.list_routes(enabled_only=True)) + + def get_active_route_id(self): + return self.store.get_setting("active_route_id") + + def _legacy_payload_to_provider(self, var_name, cfg): + lower_name = str(var_name).lower() + if "mixin" in lower_name: + return None + if "native" in lower_name and "claude" in lower_name: + backend_kind = "native_claude" + elif "native" in lower_name and "oai" in lower_name: + backend_kind = "native_oai" + elif "claude" in lower_name: + backend_kind = "claude_text" + elif "oai" in lower_name: + backend_kind = "oai_text" + else: + return None + return { + "name": cfg.get("name") or var_name, + "backend_kind": backend_kind, + "apikey": cfg.get("apikey", ""), + "apibase": cfg.get("apibase", ""), + "model": cfg.get("model", ""), + "api_mode": cfg.get("api_mode", "chat_completions"), + "temperature": cfg.get("temperature", 1.0), + "max_tokens": cfg.get("max_tokens", 8192), + "context_win": cfg.get("context_win", 24000), + "proxy": cfg.get("proxy"), + "timeout": cfg.get("timeout", cfg.get("connect_timeout", 5)), + "read_timeout": cfg.get("read_timeout", 30), + "max_retries": cfg.get("max_retries", 1), + "reasoning_effort": cfg.get("reasoning_effort"), + "thinking_type": cfg.get("thinking_type"), + "thinking_budget_tokens": cfg.get("thinking_budget_tokens"), + "stream": cfg.get("stream", True), + "is_enabled": True, + "extra": {"legacy_var_name": var_name}, + } + + def _load_legacy_config(self, path=None): + if not path: + repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + py_path = os.path.join(repo_dir, "mykey.py") + json_path = 
os.path.join(repo_dir, "mykey.json") + path = py_path if os.path.exists(py_path) else json_path + if not path or not os.path.exists(path): + raise FileNotFoundError("Legacy config not found.") + if path.lower().endswith(".json"): + with open(path, "r", encoding="utf-8") as f: + payload = json.load(f) + else: + spec = importlib.util.spec_from_file_location("ga_switch_legacy_mykey", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + payload = {k: v for k, v in vars(module).items() if not k.startswith("_")} + return payload, path + + def import_legacy_mykey(self, path=None): + payload, source_path = self._load_legacy_config(path) + ordered_providers = [] + providers_by_name = {} + for var_name, cfg in payload.items(): + if not isinstance(cfg, dict): + continue + provider = self._legacy_payload_to_provider(var_name, cfg) + if not provider: + continue + saved = self.store.upsert_provider(provider) + providers_by_name[saved["name"]] = saved + ordered_providers.append(saved) + self.store.upsert_route({ + "name": saved["name"], + "kind": "single", + "provider_id": saved["id"], + "is_enabled": True, + "is_default": False, + }) + for var_name, cfg in payload.items(): + if not (isinstance(cfg, dict) and "mixin" in str(var_name).lower()): + continue + llm_nos = cfg.get("llm_nos") or [] + member_ids = [] + for ref in llm_nos: + if isinstance(ref, int): + if ref < 0 or ref >= len(ordered_providers): + raise ValueError(f"Invalid mixin index {ref} in {var_name}") + member_ids.append(ordered_providers[ref]["id"]) + else: + ref_name = str(ref).strip() + provider = providers_by_name.get(ref_name) + if provider is None: + raise ValueError(f"Mixin {var_name} references unknown provider name {ref_name}") + member_ids.append(provider["id"]) + self.store.upsert_route({ + "name": cfg.get("name") or var_name, + "kind": "failover", + "member_provider_ids": member_ids, + "is_enabled": True, + "is_default": False, + "config": { + "max_retries": 
cfg.get("max_retries", 3), + "base_delay": cfg.get("base_delay", 1.5), + "spring_back": cfg.get("spring_back", 300), + }, + }) + routes = self.store.list_routes(enabled_only=False) + if routes: + self.store.set_setting("active_route_id", routes[0]["id"]) + self.store.set_setting("use_structured_config", True) + return { + "source_path": source_path, + "providers": self.store.list_providers(enabled_only=False), + "routes": self.store.list_routes(enabled_only=False), + } + + def export_legacy_config(self): + payload = {} + for provider in self.store.list_providers(enabled_only=False): + payload[provider["name"]] = { + "name": provider["name"], + "apikey": provider["apikey"], + "apibase": provider["apibase"], + "model": provider["model"], + "api_mode": provider["api_mode"], + "temperature": provider["temperature"], + "max_tokens": provider["max_tokens"], + "context_win": provider["context_win"], + "proxy": provider["proxy"], + "timeout": provider["timeout"], + "read_timeout": provider["read_timeout"], + "max_retries": provider["max_retries"], + "reasoning_effort": provider["reasoning_effort"], + "thinking_type": provider["thinking_type"], + "thinking_budget_tokens": provider["thinking_budget_tokens"], + "stream": provider["stream"], + } + for route in self.store.list_routes(enabled_only=False): + if route["kind"] != "failover": + continue + payload[f"mixin_{route['name']}"] = { + "name": route["name"], + "llm_nos": [member["name"] for member in route["members"]], + "max_retries": route["config"].get("max_retries", 3), + "base_delay": route["config"].get("base_delay", 1.5), + "spring_back": route["config"].get("spring_back", 300), + } + return payload + + def _build_provider_cfg(self, provider, override=None): + cfg = { + "name": provider["name"], + "apikey": provider["apikey"], + "apibase": provider["apibase"], + "model": provider.get("model") or "", + "api_mode": provider.get("api_mode") or "chat_completions", + "temperature": provider.get("temperature", 1.0), + 
"max_tokens": provider.get("max_tokens", 8192), + "context_win": provider.get("context_win", 24000), + "proxy": provider.get("proxy"), + "timeout": provider.get("timeout", 5), + "read_timeout": provider.get("read_timeout", 30), + "max_retries": provider.get("max_retries", 1), + "reasoning_effort": provider.get("reasoning_effort"), + "thinking_type": provider.get("thinking_type"), + "thinking_budget_tokens": provider.get("thinking_budget_tokens"), + "stream": provider.get("stream", True), + } + if override: + cfg.update({k: v for k, v in override.items() if v is not None}) + return cfg + + def _diagnostic_recorder(self, provider, route_id, route_name): + def _record(event): + self.store.append_diagnostic_event( + provider_id=provider["id"], + route_id=route_id, + backend_name=event.get("backend_name") or provider["name"], + ok=event.get("ok", False), + error_kind=event.get("error_kind"), + message=event.get("message", ""), + status_code=event.get("status_code"), + extra=event.get("extra") or {"route_name": route_name}, + ) + if event.get("ok"): + self.store.update_provider_health( + provider["id"], + status="healthy", + latency_ms=(event.get("extra") or {}).get("latency_ms"), + ttfb_ms=(event.get("extra") or {}).get("ttfb_ms"), + last_error="", + ) + else: + self.store.update_provider_health( + provider["id"], + status="failed", + last_error=event.get("message", ""), + ) + return _record + + def build_client_from_provider(self, provider, *, route_id=None, route_name=None, route_kind="single", for_testing=False, override=None): + from llmcore import LLMSession, ToolClient, ClaudeSession, NativeToolClient, NativeClaudeSession, NativeOAISession + + cfg = self._build_provider_cfg(provider, override=override) + backend_kind = provider["backend_kind"] + if backend_kind not in PROVIDER_BACKEND_KINDS: + raise ValueError(f"Unsupported backend_kind: {backend_kind}") + if backend_kind == "native_claude": + backend = NativeClaudeSession(cfg=cfg) + client = 
NativeToolClient(backend) + elif backend_kind == "native_oai": + backend = NativeOAISession(cfg=cfg) + client = NativeToolClient(backend) + elif backend_kind == "claude_text": + backend = ClaudeSession(cfg=cfg) + client = ToolClient(backend) + else: + backend = LLMSession(cfg=cfg) + client = ToolClient(backend) + backend.provider_id = provider["id"] + backend.provider_name = provider["name"] + backend.route_id = route_id + backend.route_name = route_name or provider["name"] + backend.route_kind = route_kind + backend.backend_kind = backend_kind + backend._diagnostic_recorder = self._diagnostic_recorder(provider, route_id, route_name or provider["name"]) + backend._ga_switch_testing = bool(for_testing) + client.ga_switch_provider = provider + client.ga_switch_route_id = route_id + client.ga_switch_route_name = route_name or provider["name"] + client.ga_switch_route_kind = route_kind + client.ga_switch_backend_kind = backend_kind + return client + + def build_client_for_route(self, route): + from llmcore import MixinSession, ToolClient, NativeToolClient + + if route["kind"] == "single": + provider = route["provider"] + if provider is None: + raise ValueError(f"Single route {route['name']} is missing provider.") + client = self.build_client_from_provider( + provider, + route_id=route["id"], + route_name=route["name"], + route_kind=route["kind"], + ) + client.ga_switch_members = [provider] + return client + members = [ + self.build_client_from_provider( + provider, + route_id=route["id"], + route_name=route["name"], + route_kind=route["kind"], + ) + for provider in route["members"] + ] + mixin_cfg = { + "llm_nos": [member.backend.name for member in members], + "max_retries": route["config"].get("max_retries", 3), + "base_delay": route["config"].get("base_delay", 1.5), + "spring_back": route["config"].get("spring_back", 300), + } + mixin = MixinSession(members, mixin_cfg) + mixin.name = route["name"] + mixin.route_id = route["id"] + mixin.route_name = route["name"] + 
mixin.route_kind = route["kind"] + mixin.ga_switch_members = route["members"] + if is_native_backend_kind(route["members"][0]["backend_kind"]): + client = NativeToolClient(mixin) + else: + client = ToolClient(mixin) + client.ga_switch_provider = None + client.ga_switch_route_id = route["id"] + client.ga_switch_route_name = route["name"] + client.ga_switch_route_kind = route["kind"] + client.ga_switch_backend_kind = "mixin" + client.ga_switch_members = route["members"] + return client + + def build_clients_from_store(self): + routes = self.store.list_routes(enabled_only=True) + clients = [self.build_client_for_route(route) for route in routes] + active_route_id = self.store.get_setting("active_route_id") + if not clients: + return [], {"active_route_id": active_route_id, "routes": routes, "source": "store", "active_index": 0} + active_index = next((i for i, route in enumerate(routes) if route["id"] == active_route_id), 0) + return clients, {"active_route_id": active_route_id, "routes": routes, "source": "store", "active_index": active_index} + + def run_model_test(self, provider_id): + return self.tester.run(provider_id) + + def reload_agent(self, agent, preserve_history=True): + return agent.reload_llm_config(preserve_history=preserve_history) + + def get_ui_snapshot(self, agent=None): + providers = self.store.list_providers(enabled_only=False) + routes = self.store.list_routes(enabled_only=False) + events = self.store.list_diagnostic_events(limit=100) + runtime = agent.describe_llms() if agent is not None and hasattr(agent, "describe_llms") else [] + active_route_id = self.get_active_route_id() + active_route = next((route for route in routes if route["id"] == active_route_id), None) + runtime_by_route_id = {item["route_id"]: item for item in runtime if item.get("route_id") is not None} + active_runtime = runtime_by_route_id.get(active_route_id) or next((item for item in runtime if item.get("active")), None) + active_route_summary = { + "route_id": 
active_route_id, + "route_name": (active_route or {}).get("name"), + "route_kind": (active_route or {}).get("kind"), + "provider_name": (active_runtime or {}).get("provider_name"), + "model": (active_runtime or {}).get("model"), + "backend_class": (active_runtime or {}).get("backend_class"), + "backend_kind": (active_runtime or {}).get("backend_kind"), + "api_mode": (active_runtime or {}).get("api_mode"), + "native_tools": bool((active_runtime or {}).get("native_tools")), + "member_names": list((active_runtime or {}).get("member_names", [])), + "active_member_name": (active_runtime or {}).get("active_member_name"), + "last_error_kind": (active_runtime or {}).get("last_error_kind"), + "last_error_message": (active_runtime or {}).get("last_error_message"), + "last_error_at": (active_runtime or {}).get("last_error_at"), + "last_ok_at": (active_runtime or {}).get("last_ok_at"), + "last_status_code": (active_runtime or {}).get("last_status_code"), + "last_switch_reason": (active_runtime or {}).get("last_switch_reason"), + } + return { + "use_structured_config": self.use_structured_config(), + "active_route_id": active_route_id, + "active_route": active_route, + "active_runtime": active_runtime, + "active_route_summary": active_route_summary, + "providers": providers, + "providers_by_id": {provider["id"]: provider for provider in providers}, + "routes": routes, + "routes_by_id": {route["id"]: route for route in routes}, + "runtime": runtime, + "runtime_by_route_id": runtime_by_route_id, + "events": events, + "recent_events": events[:20], + "stats": { + "provider_count": len(providers), + "route_count": len(routes), + "runtime_count": len(runtime), + }, + } + + def get_runtime_diagnostics(self, agent=None): + snapshot = self.get_ui_snapshot(agent) + return { + "use_structured_config": snapshot["use_structured_config"], + "active_route_id": snapshot["active_route_id"], + "providers": snapshot["providers"], + "routes": snapshot["routes"], + "events": snapshot["events"], + 
"runtime": snapshot["runtime"], + "active_route_summary": snapshot["active_route_summary"], + } diff --git a/ga_switch/store.py b/ga_switch/store.py new file mode 100644 index 0000000..7605d4b --- /dev/null +++ b/ga_switch/store.py @@ -0,0 +1,509 @@ +import json +import os +import sqlite3 + +from .diagnostics import utcnow_iso +from .models import PROVIDER_BACKEND_KINDS, ROUTE_KINDS, backend_family, is_native_backend_kind + + +class GASwitchStore: + def __init__(self, db_path): + self.db_path = db_path + os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True) + self._init_db() + self._ensure_default_test_configs() + + def _connect(self): + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + return conn + + @staticmethod + def _loads(payload, default): + if not payload: + return default + try: + return json.loads(payload) + except Exception: + return default + + @staticmethod + def _dumps(payload): + return json.dumps(payload or {}, ensure_ascii=False, separators=(",", ":")) + + def _row_to_provider(self, row, health_by_id): + if row is None: + return None + provider = dict(row) + provider["extra"] = self._loads(provider.pop("extra_json", None), {}) + provider["stream"] = bool(provider.get("stream", 1)) + provider["is_enabled"] = bool(provider.get("is_enabled", 1)) + provider["is_native"] = is_native_backend_kind(provider["backend_kind"]) + provider["backend_family"] = backend_family(provider["backend_kind"]) + provider["health"] = health_by_id.get(provider["id"], { + "provider_id": provider["id"], + "status": "unknown", + "latency_ms": None, + "ttfb_ms": None, + "last_checked_at": None, + "last_error": "", + }) + return provider + + def _health_map(self, conn): + rows = conn.execute("SELECT * FROM provider_health").fetchall() + return {row["provider_id"]: dict(row) for row in rows} + + def _init_db(self): + with self._connect() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT 
EXISTS providers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + backend_kind TEXT NOT NULL, + apikey TEXT NOT NULL, + apibase TEXT NOT NULL, + model TEXT, + api_mode TEXT DEFAULT 'chat_completions', + temperature REAL DEFAULT 1.0, + max_tokens INTEGER DEFAULT 8192, + context_win INTEGER DEFAULT 24000, + proxy TEXT, + timeout INTEGER DEFAULT 5, + read_timeout INTEGER DEFAULT 30, + max_retries INTEGER DEFAULT 1, + reasoning_effort TEXT, + thinking_type TEXT, + thinking_budget_tokens INTEGER, + stream INTEGER DEFAULT 1, + is_enabled INTEGER DEFAULT 1, + extra_json TEXT DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS routes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + kind TEXT NOT NULL, + provider_id INTEGER, + is_enabled INTEGER DEFAULT 1, + is_default INTEGER DEFAULT 0, + config_json TEXT DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(provider_id) REFERENCES providers(id) + ); + CREATE TABLE IF NOT EXISTS route_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + route_id INTEGER NOT NULL, + provider_id INTEGER NOT NULL, + position INTEGER NOT NULL, + FOREIGN KEY(route_id) REFERENCES routes(id) ON DELETE CASCADE, + FOREIGN KEY(provider_id) REFERENCES providers(id), + UNIQUE(route_id, position) + ); + CREATE TABLE IF NOT EXISTS provider_health ( + provider_id INTEGER PRIMARY KEY, + status TEXT NOT NULL, + latency_ms INTEGER, + ttfb_ms INTEGER, + last_checked_at TEXT, + last_error TEXT, + FOREIGN KEY(provider_id) REFERENCES providers(id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS model_test_config ( + backend_family TEXT PRIMARY KEY, + test_model TEXT, + api_mode TEXT, + reasoning_effort TEXT, + prompt TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS app_settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS diagnostic_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_id 
INTEGER, + route_id INTEGER, + backend_name TEXT, + ok INTEGER NOT NULL DEFAULT 0, + error_kind TEXT, + message TEXT, + status_code INTEGER, + created_at TEXT NOT NULL, + extra_json TEXT DEFAULT '{}', + FOREIGN KEY(provider_id) REFERENCES providers(id) ON DELETE SET NULL, + FOREIGN KEY(route_id) REFERENCES routes(id) ON DELETE SET NULL + ); + """ + ) + + def _ensure_default_test_configs(self): + defaults = { + "claude": { + "test_model": "claude-3-5-haiku-latest", + "api_mode": "chat_completions", + "reasoning_effort": None, + "prompt": "Reply with exactly: pong", + }, + "oai": { + "test_model": "gpt-4.1-mini", + "api_mode": "chat_completions", + "reasoning_effort": "low", + "prompt": "Reply with exactly: pong", + }, + } + with self._connect() as conn: + for family, payload in defaults.items(): + conn.execute( + """ + INSERT INTO model_test_config (backend_family, test_model, api_mode, reasoning_effort, prompt) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(backend_family) DO NOTHING + """, + (family, payload["test_model"], payload["api_mode"], payload["reasoning_effort"], payload["prompt"]), + ) + + def get_setting(self, key, default=None): + with self._connect() as conn: + row = conn.execute("SELECT value FROM app_settings WHERE key = ?", (key,)).fetchone() + if row is None: + return default + try: + return json.loads(row["value"]) + except Exception: + return row["value"] + + def set_setting(self, key, value): + payload = json.dumps(value, ensure_ascii=False) + with self._connect() as conn: + conn.execute( + """ + INSERT INTO app_settings (key, value) VALUES (?, ?) 
+ ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (key, payload), + ) + + def list_providers(self, enabled_only=False): + sql = "SELECT * FROM providers" + if enabled_only: + sql += " WHERE is_enabled = 1" + sql += " ORDER BY name COLLATE NOCASE" + with self._connect() as conn: + health_by_id = self._health_map(conn) + rows = conn.execute(sql).fetchall() + return [self._row_to_provider(row, health_by_id) for row in rows] + + def get_provider(self, provider_id): + with self._connect() as conn: + row = conn.execute("SELECT * FROM providers WHERE id = ?", (provider_id,)).fetchone() + return self._row_to_provider(row, self._health_map(conn)) + + def get_provider_by_name(self, name): + with self._connect() as conn: + row = conn.execute("SELECT * FROM providers WHERE name = ?", (name,)).fetchone() + return self._row_to_provider(row, self._health_map(conn)) + + def upsert_provider(self, provider): + backend_kind = str(provider.get("backend_kind") or "").strip() + if backend_kind not in PROVIDER_BACKEND_KINDS: + raise ValueError(f"Unsupported backend_kind: {backend_kind}") + name = str(provider.get("name") or "").strip() + apikey = str(provider.get("apikey") or "").strip() + apibase = str(provider.get("apibase") or "").strip() + if not name or not apikey or not apibase: + raise ValueError("Provider requires name, apikey and apibase.") + now = utcnow_iso() + payload = ( + name, + backend_kind, + apikey, + apibase.rstrip("/"), + provider.get("model") or "", + provider.get("api_mode") or "chat_completions", + float(provider.get("temperature", 1.0)), + int(provider.get("max_tokens", 8192)), + int(provider.get("context_win", 24000)), + provider.get("proxy"), + int(provider.get("timeout", 5)), + int(provider.get("read_timeout", 30)), + int(provider.get("max_retries", 1)), + provider.get("reasoning_effort"), + provider.get("thinking_type"), + provider.get("thinking_budget_tokens"), + 1 if provider.get("stream", True) else 0, + 1 if provider.get("is_enabled", True) 
else 0, + self._dumps(provider.get("extra")), + now, + now, + ) + with self._connect() as conn: + if provider.get("id"): + conn.execute( + """ + UPDATE providers + SET name=?, backend_kind=?, apikey=?, apibase=?, model=?, api_mode=?, temperature=?, max_tokens=?, + context_win=?, proxy=?, timeout=?, read_timeout=?, max_retries=?, reasoning_effort=?, + thinking_type=?, thinking_budget_tokens=?, stream=?, is_enabled=?, extra_json=?, updated_at=? + WHERE id = ? + """, + payload[:19] + (now, int(provider["id"])), + ) + provider_id = int(provider["id"]) + else: + cursor = conn.execute( + """ + INSERT INTO providers + (name, backend_kind, apikey, apibase, model, api_mode, temperature, max_tokens, context_win, + proxy, timeout, read_timeout, max_retries, reasoning_effort, thinking_type, thinking_budget_tokens, + stream, is_enabled, extra_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + payload, + ) + provider_id = int(cursor.lastrowid) + conn.execute( + """ + INSERT INTO provider_health (provider_id, status, latency_ms, ttfb_ms, last_checked_at, last_error) + VALUES (?, 'unknown', NULL, NULL, NULL, '') + ON CONFLICT(provider_id) DO NOTHING + """, + (provider_id,), + ) + return self.get_provider(provider_id) + + def delete_provider(self, provider_id): + with self._connect() as conn: + ref = conn.execute( + """ + SELECT routes.name FROM routes + LEFT JOIN route_members ON route_members.route_id = routes.id + WHERE routes.provider_id = ? OR route_members.provider_id = ? 
+ LIMIT 1 + """, + (provider_id, provider_id), + ).fetchone() + if ref: + raise ValueError(f"Provider is still used by route {ref['name']}.") + conn.execute("DELETE FROM providers WHERE id = ?", (provider_id,)) + + def list_routes(self, enabled_only=False): + sql = "SELECT * FROM routes" + if enabled_only: + sql += " WHERE is_enabled = 1" + sql += " ORDER BY is_default DESC, id ASC" + with self._connect() as conn: + routes = [dict(row) for row in conn.execute(sql).fetchall()] + providers = {provider["id"]: provider for provider in self.list_providers(enabled_only=False)} + member_rows = conn.execute( + """ + SELECT route_id, provider_id, position + FROM route_members + ORDER BY route_id ASC, position ASC + """ + ).fetchall() + active_route_id = self.get_setting("active_route_id") + members_by_route = {} + for row in member_rows: + members_by_route.setdefault(row["route_id"], []).append(providers.get(row["provider_id"])) + result = [] + for route in routes: + route["config"] = self._loads(route.pop("config_json", None), {}) + route["is_enabled"] = bool(route.get("is_enabled", 1)) + route["is_default"] = bool(route.get("is_default", 0)) + route["provider"] = providers.get(route.get("provider_id")) + route["members"] = [member for member in members_by_route.get(route["id"], []) if member] + route["member_provider_ids"] = [member["id"] for member in route["members"]] + route["active"] = active_route_id == route["id"] + result.append(route) + return result + + def get_route(self, route_id): + return next((route for route in self.list_routes(enabled_only=False) if route["id"] == route_id), None) + + def _ensure_route_defaults(self, conn, route_id=None, make_default=False): + any_default = conn.execute("SELECT id FROM routes WHERE is_default = 1 LIMIT 1").fetchone() + if make_default or not any_default: + conn.execute("UPDATE routes SET is_default = 0") + if route_id: + conn.execute("UPDATE routes SET is_default = 1 WHERE id = ?", (route_id,)) + + def 
_validate_failover_members(self, providers): + if len(providers) < 2: + raise ValueError("Failover route requires at least two providers.") + native_groups = {is_native_backend_kind(provider["backend_kind"]) for provider in providers} + if len(native_groups) != 1: + kinds = [provider["backend_kind"] for provider in providers] + raise ValueError(f"Failover route cannot mix native and non-native providers: {kinds}") + + def upsert_route(self, route): + kind = str(route.get("kind") or "").strip() + if kind not in ROUTE_KINDS: + raise ValueError(f"Unsupported route kind: {kind}") + name = str(route.get("name") or "").strip() + if not name: + raise ValueError("Route requires a name.") + now = utcnow_iso() + with self._connect() as conn: + if kind == "single": + provider_id = int(route.get("provider_id") or 0) + provider = self.get_provider(provider_id) + if not provider: + raise ValueError(f"Single route provider not found: {provider_id}") + member_provider_ids = [] + else: + provider_id = None + member_provider_ids = [int(pid) for pid in (route.get("member_provider_ids") or [])] + providers = [self.get_provider(pid) for pid in member_provider_ids] + if not all(providers): + raise ValueError("Failover route references missing providers.") + self._validate_failover_members(providers) + is_enabled = 1 if route.get("is_enabled", True) else 0 + is_default = 1 if route.get("is_default", False) else 0 + config_json = self._dumps(route.get("config")) + if route.get("id"): + conn.execute( + """ + UPDATE routes + SET name = ?, kind = ?, provider_id = ?, is_enabled = ?, config_json = ?, updated_at = ? + WHERE id = ? 
+ """, + (name, kind, provider_id, is_enabled, config_json, now, int(route["id"])), + ) + route_id = int(route["id"]) + conn.execute("DELETE FROM route_members WHERE route_id = ?", (route_id,)) + else: + cursor = conn.execute( + """ + INSERT INTO routes (name, kind, provider_id, is_enabled, is_default, config_json, created_at, updated_at) + VALUES (?, ?, ?, ?, 0, ?, ?, ?) + """, + (name, kind, provider_id, is_enabled, config_json, now, now), + ) + route_id = int(cursor.lastrowid) + for position, member_provider_id in enumerate(member_provider_ids): + conn.execute( + "INSERT INTO route_members (route_id, provider_id, position) VALUES (?, ?, ?)", + (route_id, member_provider_id, position), + ) + self._ensure_route_defaults(conn, route_id=route_id, make_default=bool(is_default)) + active_row = conn.execute("SELECT value FROM app_settings WHERE key = 'active_route_id'").fetchone() + if active_row is None: + conn.execute( + """ + INSERT INTO app_settings (key, value) VALUES ('active_route_id', ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (json.dumps(route_id, ensure_ascii=False),), + ) + conn.execute( + """ + INSERT INTO app_settings (key, value) VALUES ('use_structured_config', ?) 
+ ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (json.dumps(True, ensure_ascii=False),), + ) + return self.get_route(route_id) + + def delete_route(self, route_id): + with self._connect() as conn: + conn.execute("DELETE FROM routes WHERE id = ?", (route_id,)) + next_route = conn.execute("SELECT id FROM routes ORDER BY is_default DESC, id ASC LIMIT 1").fetchone() + self.set_setting("active_route_id", next_route["id"] if next_route else None) + + def set_active_route(self, route_id): + route = self.get_route(route_id) + if route is None: + raise ValueError(f"Route not found: {route_id}") + if not route["is_enabled"]: + raise ValueError(f"Route is disabled: {route['name']}") + self.set_setting("active_route_id", route["id"]) + self.set_setting("use_structured_config", True) + return route + + def get_test_config(self, family): + with self._connect() as conn: + row = conn.execute("SELECT * FROM model_test_config WHERE backend_family = ?", (family,)).fetchone() + return dict(row) if row else None + + def set_test_config(self, family, payload): + with self._connect() as conn: + conn.execute( + """ + INSERT INTO model_test_config (backend_family, test_model, api_mode, reasoning_effort, prompt) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(backend_family) DO UPDATE SET + test_model = excluded.test_model, + api_mode = excluded.api_mode, + reasoning_effort = excluded.reasoning_effort, + prompt = excluded.prompt + """, + ( + family, + payload.get("test_model"), + payload.get("api_mode"), + payload.get("reasoning_effort"), + payload.get("prompt") or "Reply with exactly: pong", + ), + ) + return self.get_test_config(family) + + def update_provider_health(self, provider_id, *, status, latency_ms=None, ttfb_ms=None, last_error="", checked_at=None): + with self._connect() as conn: + conn.execute( + """ + INSERT INTO provider_health (provider_id, status, latency_ms, ttfb_ms, last_checked_at, last_error) + VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(provider_id) DO UPDATE SET + status = excluded.status, + latency_ms = excluded.latency_ms, + ttfb_ms = excluded.ttfb_ms, + last_checked_at = excluded.last_checked_at, + last_error = excluded.last_error + """, + (provider_id, status, latency_ms, ttfb_ms, checked_at or utcnow_iso(), last_error or ""), + ) + + def append_diagnostic_event(self, *, provider_id=None, route_id=None, backend_name="", ok=False, error_kind=None, message="", status_code=None, extra=None): + created_at = utcnow_iso() + with self._connect() as conn: + conn.execute( + """ + INSERT INTO diagnostic_events + (provider_id, route_id, backend_name, ok, error_kind, message, status_code, created_at, extra_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + provider_id, + route_id, + backend_name or "", + 1 if ok else 0, + error_kind, + message or "", + status_code, + created_at, + self._dumps(extra), + ), + ) + + def list_diagnostic_events(self, limit=50): + with self._connect() as conn: + rows = conn.execute( + """ + SELECT * FROM diagnostic_events + ORDER BY id DESC + LIMIT ? 
+ """, + (int(limit),), + ).fetchall() + events = [] + for row in rows: + event = dict(row) + event["ok"] = bool(event["ok"]) + event["extra"] = self._loads(event.pop("extra_json", None), {}) + events.append(event) + return events diff --git a/ga_switch/testing.py b/ga_switch/testing.py new file mode 100644 index 0000000..d9f4943 --- /dev/null +++ b/ga_switch/testing.py @@ -0,0 +1,67 @@ +import time + + +class ModelTester: + def __init__(self, service): + self.service = service + + def run(self, provider_id): + provider = self.service.store.get_provider(provider_id) + if not provider: + raise ValueError(f"Provider not found: {provider_id}") + family = provider["backend_family"] + test_cfg = self.service.store.get_test_config(family) or {} + client = self.service.build_client_from_provider( + provider, + route_id=None, + route_name=f"test:{provider['name']}", + route_kind="single", + for_testing=True, + override={ + "model": test_cfg.get("test_model") or provider.get("model"), + "api_mode": test_cfg.get("api_mode") or provider.get("api_mode"), + "reasoning_effort": test_cfg.get("reasoning_effort") if family == "oai" else provider.get("reasoning_effort"), + }, + ) + prompt = test_cfg.get("prompt") or "Reply with exactly: pong" + started = time.perf_counter() + first_chunk_at = None + raw_text = "" + response = None + gen = client.chat([{"role": "user", "content": prompt}], tools=None) + try: + while True: + chunk = next(gen) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + raw_text += chunk + except StopIteration as stop: + response = stop.value + finished = time.perf_counter() + backend = client.backend + latency_ms = int((finished - started) * 1000) + ttfb_ms = int(((first_chunk_at or finished) - started) * 1000) + last_error = getattr(backend, "last_error_message", "") or "" + error_kind = getattr(backend, "last_error_kind", None) + success = not (raw_text.startswith("Error:") or last_error) + status = "healthy" if success else "failed" + if 
success and latency_ms >= 15000: + status = "degraded" + self.service.store.update_provider_health( + provider_id, + status=status, + latency_ms=latency_ms, + ttfb_ms=ttfb_ms, + last_error=last_error, + ) + return { + "provider_id": provider_id, + "provider_name": provider["name"], + "status": status, + "latency_ms": latency_ms, + "ttfb_ms": ttfb_ms, + "last_error": last_error, + "error_kind": error_kind, + "raw_text": raw_text[:500], + "response_repr": repr(response) if response is not None else "", + } diff --git a/ga_switch/viewmodel.py b/ga_switch/viewmodel.py new file mode 100644 index 0000000..0367a22 --- /dev/null +++ b/ga_switch/viewmodel.py @@ -0,0 +1,363 @@ +import json + +from .models import is_native_backend_kind + + +SECTIONS = ( + ("overview", "总览"), + ("routes", "全部路由"), + ("providers", "模型服务"), + ("diagnostics", "诊断记录"), +) + + +def _bool(value): + if isinstance(value, str): + return value.strip().lower() in ("1", "true", "yes", "on") + return bool(value) + + +def _int(value, default=0): + if value in (None, ""): + return default + return int(value) + + +def _float(value, default=0.0): + if value in (None, ""): + return default + return float(value) + + +def _json(value, default=None): + if value in (None, ""): + return {} if default is None else default + if isinstance(value, dict): + return value + return json.loads(value) + + +def _caption(parts): + return " · ".join(str(part) for part in parts if part) + + +def _route_kind_label(kind): + return "备用链路" if kind == "failover" else "单路由" + + +def _route_subtitle(route, runtime): + if route.get("kind") == "single": + provider = (route.get("provider") or {}).get("name") or "未绑定模型服务" + return provider + members = " -> ".join(member.get("name", "") for member in route.get("members", [])) + return members or "未配置备用链路成员" + + +def _health_tone(last_error_kind, status=None): + if last_error_kind: + return "error" + if status == "healthy": + return "active" + if status in ("degraded", "failed"): + return 
"warn" if status == "degraded" else "error" + return "neutral" + + +def _health_label(last_error_kind, status=None): + if last_error_kind: + return f"最近错误:{last_error_kind}" + if status == "healthy": + return "状态正常" + if status == "degraded": + return "状态波动" + if status == "failed": + return "状态异常" + return "状态未知" + + +def _route_status_label(route, active=False): + if active: + return "当前生效" + if not route.get("is_enabled", True): + return "已停用" + return "待命" + + +def _overview_actions(has_routes): + base = [ + {"id": "import_legacy", "label": "导入 mykey", "primary": not has_routes}, + {"id": "create_provider", "label": "新建模型服务", "primary": False}, + {"id": "continue_chat", "label": "继续使用当前模型", "primary": False}, + ] + if has_routes: + return [ + {"id": "switch_route", "label": "切换到所选路由", "primary": True}, + {"id": "soft_reload", "label": "软重载", "primary": False}, + {"id": "import_legacy", "label": "导入 mykey", "primary": False}, + {"id": "more_actions", "label": "更多操作", "primary": False}, + ] + return base + + +def _route_edit_groups(): + return [ + { + "id": "basic", + "label": "基础信息", + "expanded": True, + "fields": ["name", "kind", "provider_id", "member_provider_ids", "is_default", "is_enabled"], + }, + { + "id": "advanced", + "label": "高级设置", + "expanded": False, + "fields": ["member_order", "max_retries", "base_delay", "spring_back"], + }, + ] + + +def _provider_edit_groups(): + return [ + { + "id": "basic", + "label": "基础信息", + "expanded": True, + "fields": ["name", "backend_kind", "model", "apibase"], + }, + { + "id": "advanced", + "label": "高级设置", + "expanded": False, + "fields": ["apikey", "api_mode", "temperature", "max_tokens", "timeout", "read_timeout", "proxy", "extra"], + }, + ] + + +def build_ui_viewmodel(snapshot): + routes = snapshot.get("routes", []) + providers = snapshot.get("providers", []) + runtime = snapshot.get("runtime", []) + events = snapshot.get("events", []) + runtime_by_route_id = snapshot.get("runtime_by_route_id", {}) + 
active_summary = dict(snapshot.get("active_route_summary") or {}) + active_runtime = snapshot.get("active_runtime") or next((item for item in runtime if item.get("active")), None) or {} + + summary = { + "route_id": active_summary.get("route_id") if active_summary.get("route_id") is not None else active_runtime.get("route_id"), + "route_name": active_summary.get("route_name") or active_runtime.get("name") or "当前模型", + "route_kind": active_summary.get("route_kind") or active_runtime.get("route_kind") or "single", + "route_kind_label": _route_kind_label(active_summary.get("route_kind") or active_runtime.get("route_kind") or "single"), + "provider_name": active_summary.get("provider_name") or active_runtime.get("provider_name") or "未配置模型服务", + "model": active_summary.get("model") or active_runtime.get("model") or "未指定模型", + "backend_class": active_summary.get("backend_class") or active_runtime.get("backend_class") or "未指定后端", + "backend_kind": active_summary.get("backend_kind") or active_runtime.get("backend_kind"), + "api_mode": active_summary.get("api_mode") or active_runtime.get("api_mode"), + "native_tools": bool(active_summary.get("native_tools") if active_summary.get("native_tools") is not None else active_runtime.get("native_tools")), + "active_member_name": active_summary.get("active_member_name") or active_runtime.get("active_member_name"), + "member_names": list(active_summary.get("member_names") or active_runtime.get("member_names") or []), + "last_error_kind": active_summary.get("last_error_kind") or active_runtime.get("last_error_kind"), + "last_error_message": active_summary.get("last_error_message") or active_runtime.get("last_error_message") or "", + "last_status_code": active_summary.get("last_status_code") or active_runtime.get("last_status_code"), + "last_switch_reason": active_summary.get("last_switch_reason") or active_runtime.get("last_switch_reason") or "", + "last_ok_at": active_summary.get("last_ok_at") or active_runtime.get("last_ok_at"), + 
"last_error_at": active_summary.get("last_error_at") or active_runtime.get("last_error_at"), + "headline": active_summary.get("route_name") or active_runtime.get("name") or "当前模型", + "meta": _caption([ + active_summary.get("provider_name") or active_runtime.get("provider_name"), + active_summary.get("model") or active_runtime.get("model"), + active_summary.get("backend_class") or active_runtime.get("backend_class"), + ]) or "继续使用当前模型", + } + summary["health_label"] = _health_label(summary["last_error_kind"]) + summary["health_tone"] = _health_tone(summary["last_error_kind"]) + + route_items = [] + for route in routes: + runtime_item = runtime_by_route_id.get(route["id"], {}) + member_ids = [member["id"] for member in route.get("members", [])] + active = route["id"] == snapshot.get("active_route_id") + last_error_kind = runtime_item.get("last_error_kind") + route_items.append({ + "id": route["id"], + "name": route["name"], + "title": route["name"], + "kind": route["kind"], + "kind_label": _route_kind_label(route["kind"]), + "active": active, + "enabled": bool(route.get("is_enabled", True)), + "status_label": _route_status_label(route, active=active), + "is_default": bool(route.get("is_default", False)), + "provider_id": ((route.get("provider") or {}).get("id")), + "provider_name": ((route.get("provider") or {}).get("name")), + "member_provider_ids": member_ids, + "member_names": [member.get("name", "") for member in route.get("members", [])], + "subtitle": _route_subtitle(route, runtime_item), + "model": runtime_item.get("model"), + "backend_class": runtime_item.get("backend_class"), + "backend_kind": runtime_item.get("backend_kind"), + "api_mode": runtime_item.get("api_mode"), + "native_tools": bool(runtime_item.get("native_tools")), + "active_member_name": runtime_item.get("active_member_name"), + "last_error_kind": last_error_kind, + "last_error_message": runtime_item.get("last_error_message") or "", + "last_switch_reason": runtime_item.get("last_switch_reason") 
or "", + "health_label": _health_label(last_error_kind), + "health_tone": _health_tone(last_error_kind), + "config": dict(route.get("config") or {}), + "edit_groups": _route_edit_groups(), + }) + + provider_items = [] + for provider in providers: + health = dict(provider.get("health") or {}) + health_status = health.get("status") or "unknown" + provider_items.append({ + "id": provider["id"], + "name": provider["name"], + "title": provider["name"], + "backend_kind": provider["backend_kind"], + "is_native": is_native_backend_kind(provider["backend_kind"]), + "model": provider.get("model") or "", + "api_mode": provider.get("api_mode") or "chat_completions", + "subtitle": _caption([provider.get("model"), provider.get("backend_kind"), provider.get("api_mode")]) or "未配置模型 ID", + "health_status": health_status, + "health_label": _health_label(None, status=health_status), + "health_tone": _health_tone(None, status=health_status), + "latency_ms": health.get("latency_ms"), + "ttfb_ms": health.get("ttfb_ms"), + "last_error": health.get("last_error") or "", + "payload": provider, + "edit_groups": _provider_edit_groups(), + }) + + event_items = [] + for event in events: + ok = bool(event.get("ok")) + tone = "active" if ok else "error" + event_items.append({ + "id": event["id"], + "route_id": event.get("route_id"), + "provider_id": event.get("provider_id"), + "title": event.get("backend_name") or f"记录 {event['id']}", + "subtitle": event.get("message") or "没有详细消息", + "created_at": event.get("created_at") or "", + "tone": tone, + "status_code": event.get("status_code"), + "error_kind": event.get("error_kind"), + "raw_label": "查看原始详情", + "payload": event, + }) + + runtime_items = [] + for item in runtime: + runtime_items.append({ + "id": item.get("route_id") if item.get("route_id") is not None else item.get("idx"), + "route_id": item.get("route_id"), + "idx": item.get("idx"), + "name": item.get("name") or item.get("display_name") or "运行时", + "title": item.get("name") or 
item.get("display_name") or "运行时", + "subtitle": _caption([item.get("provider_name"), item.get("model"), item.get("backend_class")]) or "暂无运行时信息", + "active": bool(item.get("active")), + "active_member_name": item.get("active_member_name"), + "last_error_kind": item.get("last_error_kind"), + "last_error_message": item.get("last_error_message") or "", + "payload": item, + }) + + has_routes = bool(routes) + empty_state = None + if not has_routes: + empty_state = { + "title": "还没有结构化路由", + "message": "可以先导入 mykey,或者新建模型服务。当前仍可继续使用现有模型。", + "actions": [ + {"id": "import_legacy", "label": "导入 mykey", "primary": True}, + {"id": "create_provider", "label": "新建模型服务", "primary": False}, + {"id": "continue_chat", "label": "继续使用当前模型", "primary": False}, + ], + } + + overview = { + "current_route_card": { + "title": "当前路由", + "headline": summary["headline"], + "subtitle": summary["meta"], + "status_label": summary["health_label"], + "status_tone": summary["health_tone"], + "badges": [ + summary["route_kind_label"], + summary["active_member_name"] or "", + "原生工具" if summary["native_tools"] else "文本接口", + ], + }, + "health_card": { + "title": "健康状态", + "headline": summary["health_label"], + "detail": summary["last_error_message"] or "最近没有记录到错误。", + "tone": summary["health_tone"], + }, + "quick_actions": _overview_actions(has_routes), + "route_summary_items": route_items[:5], + "has_routes": has_routes, + } + + return { + "sections": [{"id": section_id, "label": label} for section_id, label in SECTIONS], + "summary": summary, + "overview": overview, + "empty_state": empty_state, + "routes": route_items, + "providers": provider_items, + "events": event_items, + "runtime": runtime_items, + "stats": { + "provider_count": len(provider_items), + "route_count": len(route_items), + "runtime_count": len(runtime_items), + }, + "use_structured_config": bool(snapshot.get("use_structured_config")), + } + + +def build_provider_payload(values, provider_id=None): + payload = { + "name": 
str(values.get("name", "")).strip(), + "backend_kind": str(values.get("backend_kind", "oai_text")).strip() or "oai_text", + "apikey": str(values.get("apikey", "")).strip(), + "apibase": str(values.get("apibase", "")).strip(), + "model": str(values.get("model", "")).strip(), + "api_mode": str(values.get("api_mode", "chat_completions")).strip() or "chat_completions", + "temperature": _float(values.get("temperature"), 1.0), + "max_tokens": _int(values.get("max_tokens"), 8192), + "timeout": _int(values.get("timeout"), 5), + "read_timeout": _int(values.get("read_timeout"), 30), + "proxy": str(values.get("proxy", "")).strip() or None, + "extra": _json(values.get("extra"), default={}), + } + if provider_id is not None: + payload["id"] = provider_id + return payload + + +def build_route_payload(values, route_id=None): + kind = str(values.get("kind", "single")).strip() or "single" + payload = { + "name": str(values.get("name", "")).strip(), + "kind": kind, + "is_default": _bool(values.get("is_default")), + "is_enabled": _bool(values.get("is_enabled", True)), + "provider_id": values.get("provider_id"), + "member_provider_ids": list(values.get("member_provider_ids") or []), + "config": { + "max_retries": _int(values.get("max_retries"), 3), + "base_delay": _float(values.get("base_delay"), 1.5), + "spring_back": _int(values.get("spring_back"), 300), + }, + } + if kind == "single": + payload["member_provider_ids"] = [] + else: + payload["provider_id"] = None + if route_id is not None: + payload["id"] = route_id + return payload diff --git a/launch.pyw b/launch.pyw index 808316e..5f1a2f0 100644 --- a/launch.pyw +++ b/launch.pyw @@ -1,4 +1,4 @@ -import webview, threading, subprocess, sys, time, os, ctypes, atexit, socket, random +import threading, subprocess, sys, time, os, ctypes, atexit, socket, random WINDOW_WIDTH, WINDOW_HEIGHT, RIGHT_PADDING, TOP_PADDING = 600, 900, 0, 100 @@ -22,6 +22,27 @@ def start_streamlit(port): proc = subprocess.Popen(cmd) atexit.register(proc.kill) 
+def probe_qt_frontend(): + try: + check = subprocess.run( + [sys.executable, "-c", "from PySide6 import QtWidgets; print('qt-ok')"], + capture_output=True, + text=True, + timeout=20, + ) + except Exception as exc: + return False, str(exc) + return check.returncode == 0, (check.stderr or check.stdout or "").strip() + +def start_qt_frontend(): + qt_entry = os.path.join(frontends_dir, "qtapp.py") + qt_proc = subprocess.Popen([sys.executable, qt_entry]) + atexit.register(lambda: qt_proc.poll() is None and qt_proc.kill()) + time.sleep(3) + if qt_proc.poll() is None: + return qt_proc, None + return None, f"Qt frontend exited early with code {qt_proc.returncode}." + def inject(text): window.evaluate_js(f""" const textarea = document.querySelector('textarea[data-testid="stChatInputTextArea"]'); @@ -74,9 +95,6 @@ if __name__ == '__main__': parser.add_argument('--sched', action='store_true', help='启动计划任务调度器') parser.add_argument('--llm_no', type=int, default=0, help='LLM编号') args = parser.parse_args() - port = str(find_free_port()) if args.port == '0' else args.port - print(f'[Launch] Using port {port}') - threading.Thread(target=start_streamlit, args=(port,), daemon=True).start() if args.tg: tgproc = subprocess.Popen([sys.executable, os.path.join(frontends_dir, "tgapp.py")], creationflags=subprocess.CREATE_NO_WINDOW if os.name=='nt' else 0) @@ -114,12 +132,27 @@ if __name__ == '__main__': print('[Launch] Task Scheduler started (duplicate prevented by scheduler port lock)') else: print('[Launch] Task Scheduler not enabled (--sched)') + qt_ok, qt_probe_msg = probe_qt_frontend() + if qt_ok: + qt_proc, qt_error = start_qt_frontend() + if qt_proc is not None: + print('[Launch] Qt frontend started') + sys.exit(qt_proc.wait()) + print(f'[Launch] Qt frontend failed, falling back to Streamlit: {qt_error}') + else: + print(f'[Launch] Qt frontend unavailable, falling back to Streamlit: {qt_probe_msg}') + + port = str(find_free_port()) if args.port == '0' else args.port + 
print(f'[Launch] Using port {port}') + threading.Thread(target=start_streamlit, args=(port,), daemon=True).start() + monitor_thread = threading.Thread(target=idle_monitor, daemon=True) monitor_thread.start() if os.name == 'nt': screen_width = get_screen_width() x_pos = screen_width - WINDOW_WIDTH - RIGHT_PADDING else: x_pos = 100 + import webview time.sleep(2) window = webview.create_window( title='GenericAgent', url=f'http://localhost:{port}', diff --git a/llmcore.py b/llmcore.py index a8887fb..981a027 100644 --- a/llmcore.py +++ b/llmcore.py @@ -1,5 +1,6 @@ import os, json, re, time, requests, sys, threading, urllib3, base64, mimetypes, uuid from datetime import datetime +from ga_switch.diagnostics import classify_error, normalize_message, utcnow_iso urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) _RESP_CACHE_KEY = str(uuid.uuid4()) @@ -35,8 +36,10 @@ def _trunc(text): return text for i, msg in enumerate(messages): if i >= len(messages) - keep_recent: break - c = msg['content'] - if isinstance(c, str): msg['content'] = _trunc(c) + key = 'content' if 'content' in msg else ('prompt' if 'prompt' in msg else None) + if key is None: continue + c = msg[key] + if isinstance(c, str): msg[key] = _trunc(c) elif isinstance(c, list): for b in c: if not isinstance(b, dict): continue @@ -253,7 +256,7 @@ def _stamp_oai_cache_markers(messages, model): def _openai_stream(api_base, api_key, messages, model, api_mode='chat_completions', *, temperature=0.5, max_tokens=None, tools=None, reasoning_effort=None, - max_retries=0, connect_timeout=10, read_timeout=300, proxies=None): + max_retries=0, connect_timeout=10, read_timeout=300, proxies=None, session=None): """Shared OpenAI-compatible streaming request with retry. 
Yields text chunks, returns list[content_block].""" ml = model.lower() if 'kimi' in ml or 'moonshot' in ml: temperature = 1 @@ -287,6 +290,8 @@ def _delay(resp, attempt): try: ra = float((resp.headers or {}).get("retry-after")) except: ra = None return max(0.5, ra if ra is not None else min(30.0, 1.5 * (2 ** attempt))) + started_at = time.perf_counter() + first_chunk_at = None for attempt in range(max_retries + 1): streamed = False try: @@ -306,8 +311,23 @@ def _delay(resp, attempt): e._err_body = err_body; raise gen = _parse_openai_sse(r.iter_lines(), api_mode) try: - while True: streamed = True; yield next(gen) + while True: + chunk = next(gen) + streamed = True + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk except StopIteration as e: + if session is not None: + session._record_success( + status_code=r.status_code, + message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + "api_mode": api_mode, + }, + ) return e.value or [] except requests.HTTPError as e: resp = getattr(e, "response", None); status = getattr(resp, "status_code", None) @@ -321,6 +341,8 @@ def _delay(resp, attempt): try: h = resp.headers or {}; rid = h.get("x-request-id","") or h.get("request-id",""); ra = h.get("retry-after",""); ct = h.get("content-type","") except: pass err = f"Error: HTTP {status} {e}; content_type: {ct or ''}; retry_after: {ra or ''}; request_id: {rid or ''}; body: {body or ''}" + if session is not None: + session._record_error(err, status_code=status, body=body, exc_type=type(e).__name__, extra={"request_id": rid, "retry_after": ra, "content_type": ct, "api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] except (requests.Timeout, requests.ConnectionError) as e: if attempt < max_retries and not streamed: @@ -328,9 +350,13 @@ def _delay(resp, attempt): print(f"[LLM Retry] {type(e).__name__}, retry in {d:.1f}s 
({attempt+1}/{max_retries+1})") time.sleep(d); continue err = f"Error: {type(e).__name__}: {e}" + if session is not None: + session._record_error(err, exc_type=type(e).__name__, extra={"api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] except Exception as e: err = f"Error: {e}" + if session is not None: + session._record_error(err, exc_type=type(e).__name__, extra={"api_mode": api_mode}) yield err; return [{"type": "text", "text": err}] def _to_responses_input(messages): @@ -424,7 +450,7 @@ def __init__(self, cfg): self.max_retries = max(0, int(cfg.get('max_retries', 1))) self.stream = cfg.get('stream', True) default_ct, default_rt = (5, 30) if self.stream else (10, 240) - self.connect_timeout = max(1, int(cfg.get('timeout', default_ct))) + self.connect_timeout = max(1, int(cfg.get('timeout', cfg.get('connect_timeout', default_ct)))) self.read_timeout = max(5, int(cfg.get('read_timeout', default_rt))) def _enum(key, valid): v = cfg.get(key); v = None if v is None else str(v).strip().lower() @@ -436,6 +462,72 @@ def _enum(key, valid): self.api_mode = 'responses' if mode in ('responses', 'response') else 'chat_completions' self.temperature = cfg.get('temperature', 1) self.max_tokens = cfg.get('max_tokens', 8192) + self.last_error_kind = None + self.last_error_message = '' + self.last_error_at = None + self.last_ok_at = None + self.last_status_code = None + self.last_latency_ms = None + self.last_ttfb_ms = None + self.route_id = cfg.get('route_id') + self.route_name = cfg.get('route_name') + self.route_kind = cfg.get('route_kind') + self.provider_id = cfg.get('provider_id') + self.provider_name = cfg.get('provider_name', self.name) + self.backend_kind = cfg.get('backend_kind') + self._diagnostic_recorder = cfg.get('_diagnostic_recorder') + + def _record_success(self, status_code=200, message='OK', extra=None): + extra = dict(extra or {}) + self.last_status_code = status_code + self.last_ok_at = utcnow_iso() + self.last_error_kind = None + 
self.last_error_message = '' + if 'latency_ms' in extra: + self.last_latency_ms = extra.get('latency_ms') + if 'ttfb_ms' in extra: + self.last_ttfb_ms = extra.get('ttfb_ms') + if callable(self._diagnostic_recorder): + self._diagnostic_recorder({ + 'provider_id': self.provider_id, + 'route_id': self.route_id, + 'backend_name': self.name, + 'ok': True, + 'error_kind': None, + 'message': normalize_message(message), + 'status_code': status_code, + 'extra': extra, + }) + + def _record_error(self, message, *, status_code=None, body='', exc_type='', error_kind=None, extra=None): + kind = error_kind or classify_error(status_code=status_code, message=message, body=body, exc_type=exc_type) + self.last_error_kind = kind + self.last_error_message = normalize_message(message) + self.last_error_at = utcnow_iso() + self.last_status_code = status_code + if callable(self._diagnostic_recorder): + self._diagnostic_recorder({ + 'provider_id': self.provider_id, + 'route_id': self.route_id, + 'backend_name': self.name, + 'ok': False, + 'error_kind': kind, + 'message': self.last_error_message, + 'status_code': status_code, + 'extra': dict(extra or {}, body=normalize_message(body, 1200), exc_type=exc_type or None), + }) + + def describe_diagnostics(self): + return { + 'last_error_kind': self.last_error_kind, + 'last_error_message': self.last_error_message, + 'last_error_at': self.last_error_at, + 'last_ok_at': self.last_ok_at, + 'last_status_code': self.last_status_code, + 'last_latency_ms': self.last_latency_ms, + 'last_ttfb_ms': self.last_ttfb_ms, + } + def _apply_claude_thinking(self, payload): if self.thinking_type: thinking = {"type": self.thinking_type} @@ -474,11 +566,35 @@ def raw_ask(self, messages): if self.temperature != 1: payload["temperature"] = self.temperature self._apply_claude_thinking(payload) if self.system: payload["system"] = [{"type": "text", "text": self.system, "cache_control": {"type": "persistent"}}] + started_at = time.perf_counter() + first_chunk_at = None 
try: with requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=(self.connect_timeout, self.read_timeout)) as r: - if r.status_code != 200: raise Exception(f"HTTP {r.status_code} {r.content.decode('utf-8', errors='replace')[:500]}") - return (yield from _parse_claude_sse(r.iter_lines())) or [] + if r.status_code != 200: + body = r.content.decode('utf-8', errors='replace')[:500] + err = f"HTTP {r.status_code} {body}" + self._record_error(err, status_code=r.status_code, body=body, extra={"api_base": self.api_base}) + yield (err_text := f"Error: {err}") + return [{"type": "text", "text": err_text}] + gen = _parse_claude_sse(r.iter_lines()) + try: + while True: + chunk = next(gen) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk + except StopIteration as e: + self._record_success( + status_code=r.status_code, + message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + }, + ) + return e.value or [] except Exception as e: + self._record_error(str(e), exc_type=type(e).__name__, extra={"api_base": self.api_base}) yield (err := f"Error: {e}") return [{"type": "text", "text": err}] def make_messages(self, raw_list): @@ -493,7 +609,7 @@ def raw_ask(self, messages): return (yield from _openai_stream(self.api_base, self.api_key, messages, self.model, self.api_mode, temperature=self.temperature, reasoning_effort=self.reasoning_effort, max_tokens=self.max_tokens, max_retries=self.max_retries, - connect_timeout=self.connect_timeout, read_timeout=self.read_timeout, proxies=self.proxies)) + connect_timeout=self.connect_timeout, read_timeout=self.read_timeout, proxies=self.proxies, session=self)) def make_messages(self, raw_list): return _msgs_claude2oai(raw_list) def _fix_messages(messages): @@ -550,10 +666,34 @@ def raw_ask(self, messages): for idx in user_idxs[-2:]: messages[idx] = 
{**messages[idx], "content": list(messages[idx]["content"])} messages[idx]["content"][-1] = dict(messages[idx]["content"][-1], cache_control={"type": "ephemeral"}) + started_at = time.perf_counter() + first_chunk_at = None try: with requests.post(auto_make_url(self.api_base, "messages")+'?beta=true', headers=headers, json=payload, stream=self.stream, timeout=(self.connect_timeout, self.read_timeout)) as resp: - if resp.status_code != 200: raise Exception(f"HTTP {resp.status_code} {resp.content.decode('utf-8', errors='replace')[:500]}") - if self.stream: return (yield from _parse_claude_sse(resp.iter_lines())) or [] + if resp.status_code != 200: + body = resp.content.decode('utf-8', errors='replace')[:500] + err = f"HTTP {resp.status_code} {body}" + self._record_error(err, status_code=resp.status_code, body=body, extra={"api_base": self.api_base}) + yield (err_text := f"Error: {err}") + return [{"type": "text", "text": err_text}] + if self.stream: + gen = _parse_claude_sse(resp.iter_lines()) + try: + while True: + chunk = next(gen) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + yield chunk + except StopIteration as e: + self._record_success( + status_code=resp.status_code, + message="OK", + extra={ + "latency_ms": int((time.perf_counter() - started_at) * 1000), + "ttfb_ms": int((((first_chunk_at or time.perf_counter()) - started_at) * 1000)), + }, + ) + return e.value or [] else: data = resp.json(); content_blocks = data.get("content", []) usage = data.get("usage", {}) @@ -561,8 +701,14 @@ def raw_ask(self, messages): for b in content_blocks: if b.get("type") == "text": yield b.get("text", "") elif b.get("type") == "thinking": yield "" + self._record_success( + status_code=resp.status_code, + message="OK", + extra={"latency_ms": int((time.perf_counter() - started_at) * 1000), "ttfb_ms": 0}, + ) return content_blocks except Exception as e: + self._record_error(str(e), exc_type=type(e).__name__, extra={"api_base": self.api_base}) yield (err := 
f"Error: {e}") return [{"type": "text", "text": err}] @@ -603,7 +749,7 @@ def raw_ask(self, messages): temperature=self.temperature, max_tokens=self.max_tokens, tools=self.tools, reasoning_effort=self.reasoning_effort, max_retries=self.max_retries, connect_timeout=self.connect_timeout, - read_timeout=self.read_timeout, proxies=self.proxies)) + read_timeout=self.read_timeout, proxies=self.proxies, session=self)) def openai_tools_to_claude(tools): """[{type:'function', function:{name,description,parameters}}] → [{name,description,input_schema}].""" @@ -821,6 +967,16 @@ def __init__(self, all_sessions, cfg): self._sessions[0].raw_ask = self._raw_ask self.model = getattr(self._sessions[0], 'model', None) self._cur_idx, self._switched_at = 0, 0.0 + self.active_session_index = 0 + self.active_member_name = self._sessions[0].name + self.last_switch_reason = '' + self.last_error_kind = None + self.last_error_message = '' + self.last_error_at = None + self.last_ok_at = None + self.last_status_code = None + self.last_latency_ms = None + self.last_ttfb_ms = None def __getattr__(self, name): return getattr(self._sessions[0], name) _BROADCAST_ATTRS = frozenset({'system', 'tools', 'temperature', 'max_tokens', 'reasoning_effort'}) def __setattr__(self, name, value): @@ -831,8 +987,29 @@ def __setattr__(self, name, value): else: object.__setattr__(self, name, value) @property def primary(self): return self._sessions[0] + def _sync_diagnostics_from(self, session): + for attr in ('last_error_kind', 'last_error_message', 'last_error_at', 'last_ok_at', 'last_status_code', 'last_latency_ms', 'last_ttfb_ms'): + setattr(self, attr, getattr(session, attr, None)) + def describe_diagnostics(self): + return { + 'last_error_kind': self.last_error_kind, + 'last_error_message': self.last_error_message, + 'last_error_at': self.last_error_at, + 'last_ok_at': self.last_ok_at, + 'last_status_code': self.last_status_code, + 'last_latency_ms': self.last_latency_ms, + 'last_ttfb_ms': 
self.last_ttfb_ms, + 'active_session_index': self.active_session_index, + 'active_member_name': self.active_member_name, + 'last_switch_reason': self.last_switch_reason, + 'spring_back_seconds': self._spring_sec, + } def _pick(self): - if self._cur_idx and time.time() - self._switched_at > self._spring_sec: self._cur_idx = 0 + if self._cur_idx and time.time() - self._switched_at > self._spring_sec: + self._cur_idx = 0 + self.active_session_index = 0 + self.active_member_name = self._sessions[0].name + self.last_switch_reason = 'spring_back' return self._cur_idx def _raw_ask(self, *args, **kwargs): base, n = self._pick(), len(self._sessions) @@ -850,8 +1027,18 @@ def _raw_ask(self, *args, **kwargs): except StopIteration as e: return_val = e.value or [] is_err = test_error(last_chunk) if not is_err: - if attempt > 0: self._cur_idx = idx; self._switched_at = time.time() + if attempt > 0: + self._cur_idx = idx + self._switched_at = time.time() + self.last_switch_reason = f'fallback_success:{self._sessions[idx].name}' + self.active_session_index = idx + self.active_member_name = self._sessions[idx].name + self._sync_diagnostics_from(self._sessions[idx]) return return_val + self.active_session_index = idx + self.active_member_name = self._sessions[idx].name + self.last_switch_reason = (last_chunk or '')[:240] + self._sync_diagnostics_from(self._sessions[idx]) if attempt >= self._retries: yield last_chunk; return return_val nxt = (base + attempt + 1) % n @@ -915,6 +1102,12 @@ def chat(self, messages, tools=None): while True: chunk = next(gen); yield chunk except StopIteration as e: resp = e.value + if resp and hasattr(resp, 'content') and isinstance(resp.content, str) and not getattr(resp, 'thinking', ''): + think_pattern = r"<think>(.*?)</think>" + think_match = re.search(think_pattern, resp.content, re.DOTALL) + if think_match: + resp.thinking = think_match.group(1).strip() + resp.content = re.sub(think_pattern, "", resp.content, flags=re.DOTALL).strip() if resp: 
_write_llm_log('Response', resp.raw) if resp and hasattr(resp, 'tool_calls') and resp.tool_calls: self._pending_tool_ids = [tc.id for tc in resp.tool_calls] - return resp \ No newline at end of file + return resp diff --git a/requirements-api.txt b/requirements-api.txt new file mode 100644 index 0000000..32e2442 --- /dev/null +++ b/requirements-api.txt @@ -0,0 +1,3 @@ +fastapi>=0.100.0 +uvicorn>=0.20.0 +pydantic>=2.0.0 diff --git a/start_api_server.bat b/start_api_server.bat new file mode 100644 index 0000000..0ad5459 --- /dev/null +++ b/start_api_server.bat @@ -0,0 +1,6 @@ +@echo off +REM GA Switch API Server Startup Script +echo Starting GA Switch API Server... +cd /d "%~dp0" +python api_server.py +pause diff --git a/tests/test_ga_switch.py b/tests/test_ga_switch.py new file mode 100644 index 0000000..b8a05cb --- /dev/null +++ b/tests/test_ga_switch.py @@ -0,0 +1,373 @@ +import json +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class GASwitchTestCase(unittest.TestCase): + def make_service(self): + from ga_switch.service import GASwitchService + + tmp_dir = tempfile.mkdtemp(prefix="ga-switch-test-") + return GASwitchService(os.path.join(tmp_dir, "ga-switch.db")) + + def make_oai_provider(self, service, name="p1", backend_kind="oai_text", model="gpt-4.1-mini"): + return service.upsert_provider({ + "name": name, + "backend_kind": backend_kind, + "apikey": "test-key", + "apibase": "https://api.example.com/v1", + "model": model, + }) + + +class TestGASwitchImport(GASwitchTestCase): + def test_import_legacy_mykey_json_to_structured_routes(self): + service = self.make_service() + tmp_dir = tempfile.mkdtemp(prefix="ga-switch-import-") + legacy_path = os.path.join(tmp_dir, "mykey.json") + with open(legacy_path, "w", encoding="utf-8") as f: + json.dump({ + "oai_config": { + "name": "kimi-primary", + "apikey": "k1", + "apibase": 
"https://api.example.com/v1", + "model": "kimi-k2.5", + }, + "oai_config_backup": { + "name": "glm-backup", + "apikey": "k2", + "apibase": "https://api.example.com/v1", + "model": "glm-5.1", + }, + "mixin_config": { + "name": "fallback-pair", + "llm_nos": ["kimi-primary", "glm-backup"], + "max_retries": 2, + "spring_back": 60, + }, + }, f, ensure_ascii=False) + + result = service.import_legacy_mykey(legacy_path) + + self.assertEqual(len(result["providers"]), 2) + self.assertEqual({route["kind"] for route in result["routes"]}, {"single", "failover"}) + failover = next(route for route in result["routes"] if route["kind"] == "failover") + self.assertEqual([member["name"] for member in failover["members"]], ["kimi-primary", "glm-backup"]) + self.assertTrue(service.use_structured_config()) + + +class TestGASwitchValidation(GASwitchTestCase): + def test_provider_update_round_trips(self): + service = self.make_service() + provider = self.make_oai_provider(service, name="editable") + + updated = service.upsert_provider({ + "id": provider["id"], + "name": "editable", + "backend_kind": "oai_text", + "apikey": "test-key-2", + "apibase": "https://api.example.com/v1", + "model": "updated-model", + "timeout": 9, + }) + + self.assertEqual(updated["model"], "updated-model") + self.assertEqual(updated["timeout"], 9) + + def test_failover_rejects_native_and_non_native_mix(self): + service = self.make_service() + native = self.make_oai_provider(service, name="native", backend_kind="native_oai") + text = self.make_oai_provider(service, name="text", backend_kind="oai_text") + + with self.assertRaisesRegex(ValueError, "cannot mix native and non-native"): + service.upsert_route({ + "name": "bad-route", + "kind": "failover", + "member_provider_ids": [native["id"], text["id"]], + }) + + +class TestGASwitchBuild(GASwitchTestCase): + def test_build_clients_from_store_maps_backend_classes(self): + service = self.make_service() + p1 = self.make_oai_provider(service, name="alpha", 
backend_kind="oai_text") + p2 = self.make_oai_provider(service, name="beta", backend_kind="oai_text") + service.upsert_route({"name": "alpha-route", "kind": "single", "provider_id": p1["id"], "is_default": True}) + service.upsert_route({"name": "fallback-route", "kind": "failover", "member_provider_ids": [p1["id"], p2["id"]]}) + + clients, meta = service.build_clients_from_store() + + self.assertEqual([type(client.backend).__name__ for client in clients], ["LLMSession", "MixinSession"]) + self.assertEqual(meta["active_index"], 0) + self.assertEqual(clients[1].ga_switch_route_kind, "failover") + + def test_get_ui_snapshot_includes_active_summary(self): + service = self.make_service() + p1 = self.make_oai_provider(service, name="alpha", backend_kind="oai_text", model="m1") + service.upsert_route({"name": "alpha-route", "kind": "single", "provider_id": p1["id"], "is_default": True}) + + with patch("agentmain.get_service", return_value=service): + from agentmain import GeneraticAgent + + agent = GeneraticAgent() + snapshot = service.get_ui_snapshot(agent) + + self.assertEqual(snapshot["active_route_summary"]["route_name"], "alpha-route") + self.assertEqual(snapshot["active_route_summary"]["provider_name"], "alpha") + self.assertEqual(snapshot["stats"]["route_count"], 1) + + +class TestDiagnostics(unittest.TestCase): + def test_classify_error_common_cases(self): + from ga_switch.diagnostics import classify_error + + self.assertEqual(classify_error(status_code=401, message="Unauthorized"), "auth") + self.assertEqual(classify_error(status_code=429, body="insufficient_quota"), "quota") + self.assertEqual(classify_error(status_code=429, body="rate limit exceeded"), "rate_limit") + self.assertEqual(classify_error(status_code=404, body="model not found"), "model_not_found") + self.assertEqual(classify_error(message="unsupported parameter reasoning_effort"), "unsupported_param") + self.assertEqual(classify_error(exc_type="Timeout", message="timed out"), "timeout") + 
self.assertEqual(classify_error(exc_type="ConnectionError", message="connection refused"), "network") + + +class TestAgentReload(GASwitchTestCase): + def test_reload_llm_config_preserves_history_and_blocks_running(self): + service = self.make_service() + p1 = self.make_oai_provider(service, name="route-a", backend_kind="oai_text", model="m1") + service.upsert_route({"name": "route-a", "kind": "single", "provider_id": p1["id"], "is_default": True}) + + with patch("agentmain.get_service", return_value=service): + from agentmain import GeneraticAgent + + agent = GeneraticAgent() + agent.llmclient.backend.history = [{"role": "user", "content": [{"type": "text", "text": "keep me"}]}] + + p2 = self.make_oai_provider(service, name="route-b", backend_kind="oai_text", model="m2") + service.upsert_route({"name": "route-b", "kind": "single", "provider_id": p2["id"]}) + described = agent.reload_llm_config() + + self.assertEqual(agent.llmclient.backend.history[0]["content"][0]["text"], "keep me") + self.assertEqual(agent.llmclient.ga_switch_route_name, "route-a") + self.assertEqual(len(described), 2) + + agent.is_running = True + with self.assertRaisesRegex(RuntimeError, "Cannot reload"): + agent.reload_llm_config() + + +class TestUIViewModel(unittest.TestCase): + def test_build_ui_viewmodel_maps_summary_and_lists(self): + from ga_switch.viewmodel import build_ui_viewmodel + + snapshot = { + "use_structured_config": True, + "active_route_id": 10, + "active_runtime": { + "route_id": 10, + "name": "primary-route", + "provider_name": "kimi", + "model": "kimi-k2.5", + "backend_class": "LLMSession", + "route_kind": "single", + "active": True, + }, + "active_route_summary": { + "route_id": 10, + "route_name": "primary-route", + "route_kind": "single", + "provider_name": "kimi", + "model": "kimi-k2.5", + "backend_class": "LLMSession", + "backend_kind": "oai_text", + "api_mode": "chat_completions", + "native_tools": False, + "member_names": [], + "active_member_name": "kimi", + 
"last_error_kind": "quota", + "last_error_message": "429 insufficient_quota", + "last_status_code": 429, + "last_switch_reason": "fallback_success:kimi", + }, + "providers": [{ + "id": 1, + "name": "kimi", + "backend_kind": "oai_text", + "model": "kimi-k2.5", + "api_mode": "chat_completions", + "health": {"status": "healthy", "latency_ms": 250, "ttfb_ms": 80, "last_error": ""}, + }], + "routes": [{ + "id": 10, + "name": "primary-route", + "kind": "single", + "is_enabled": True, + "is_default": True, + "provider": {"id": 1, "name": "kimi"}, + "members": [], + "config": {"max_retries": 3, "base_delay": 1.5, "spring_back": 300}, + }], + "runtime": [{ + "idx": 0, + "route_id": 10, + "name": "primary-route", + "display_name": "primary-route [LLMSession/kimi]", + "route_kind": "single", + "backend_class": "LLMSession", + "backend_kind": "oai_text", + "provider_name": "kimi", + "model": "kimi-k2.5", + "api_mode": "chat_completions", + "active": True, + "native_tools": False, + "member_names": [], + "active_member_name": "kimi", + "last_error_kind": "quota", + "last_error_message": "429 insufficient_quota", + }], + "runtime_by_route_id": { + 10: { + "idx": 0, + "route_id": 10, + "name": "primary-route", + "backend_class": "LLMSession", + "backend_kind": "oai_text", + "provider_name": "kimi", + "model": "kimi-k2.5", + "api_mode": "chat_completions", + "native_tools": False, + "active_member_name": "kimi", + "last_error_kind": "quota", + "last_error_message": "429 insufficient_quota", + "last_switch_reason": "fallback_success:kimi", + }, + }, + "events": [{ + "id": 99, + "route_id": 10, + "provider_id": 1, + "backend_name": "kimi", + "ok": False, + "error_kind": "quota", + "message": "429 insufficient_quota", + "status_code": 429, + "created_at": "2026-04-19T00:00:00Z", + }], + } + + vm = build_ui_viewmodel(snapshot) + + self.assertEqual( + [section["label"] for section in vm["sections"]], + ["总览", "全部路由", "模型服务", "诊断记录"], + ) + self.assertEqual(vm["summary"]["headline"], 
"primary-route") + self.assertEqual(vm["summary"]["route_kind_label"], "单路由") + self.assertEqual(vm["summary"]["provider_name"], "kimi") + self.assertEqual(vm["overview"]["current_route_card"]["title"], "当前路由") + self.assertEqual(vm["routes"][0]["last_error_kind"], "quota") + self.assertEqual(vm["routes"][0]["edit_groups"][1]["label"], "高级设置") + self.assertIn("member_order", vm["routes"][0]["edit_groups"][1]["fields"]) + self.assertEqual(vm["providers"][0]["health_status"], "healthy") + self.assertEqual(vm["providers"][0]["edit_groups"][1]["label"], "高级设置") + self.assertIn("apikey", vm["providers"][0]["edit_groups"][1]["fields"]) + self.assertIn("extra", vm["providers"][0]["edit_groups"][1]["fields"]) + self.assertEqual(vm["events"][0]["tone"], "error") + self.assertEqual(vm["events"][0]["raw_label"], "查看原始详情") + self.assertTrue(vm["runtime"][0]["active"]) + + def test_build_ui_viewmodel_empty_state_uses_chinese_actions(self): + from ga_switch.viewmodel import build_ui_viewmodel + + vm = build_ui_viewmodel({ + "use_structured_config": False, + "providers": [], + "routes": [], + "runtime": [], + "runtime_by_route_id": {}, + "events": [], + "active_route_summary": {}, + "active_runtime": {}, + }) + + self.assertIsNotNone(vm["empty_state"]) + self.assertEqual(vm["empty_state"]["title"], "还没有结构化路由") + self.assertEqual( + [action["label"] for action in vm["empty_state"]["actions"]], + ["导入 mykey", "新建模型服务", "继续使用当前模型"], + ) + self.assertEqual(vm["summary"]["headline"], "当前模型") + self.assertEqual(vm["summary"]["meta"], "继续使用当前模型") + + def test_build_provider_payload_parses_json_extra(self): + from ga_switch.viewmodel import build_provider_payload + + payload = build_provider_payload({ + "name": "glm", + "backend_kind": "oai_text", + "apikey": "test-key", + "apibase": "https://api.example.com/v1", + "model": "glm-5.1", + "api_mode": "responses", + "temperature": 0.7, + "max_tokens": 4096, + "timeout": 8, + "read_timeout": 45, + "proxy": "", + "extra": 
"{\"reasoning_effort\": \"low\"}", + }, provider_id=7) + + self.assertEqual(payload["id"], 7) + self.assertEqual(payload["api_mode"], "responses") + self.assertIsNone(payload["proxy"]) + self.assertEqual(payload["extra"]["reasoning_effort"], "low") + + def test_build_route_payload_preserves_failover_member_order(self): + from ga_switch.viewmodel import build_route_payload + + payload = build_route_payload({ + "name": "fallback", + "kind": "failover", + "is_default": True, + "is_enabled": True, + "member_provider_ids": [3, 1, 2], + "max_retries": 2, + "base_delay": 2.5, + "spring_back": 120, + }, route_id=15) + + self.assertEqual(payload["id"], 15) + self.assertEqual(payload["kind"], "failover") + self.assertEqual(payload["member_provider_ids"], [3, 1, 2]) + self.assertIsNone(payload["provider_id"]) + self.assertEqual(payload["config"]["spring_back"], 120) + + +class TestModelTester(GASwitchTestCase): + def test_model_test_uses_temporary_session_and_keeps_real_history(self): + service = self.make_service() + provider = self.make_oai_provider(service, name="tester", backend_kind="oai_text") + real_client = service.build_client_from_provider(provider, route_id=123, route_name="real-route") + real_client.backend.history = [{"role": "user", "content": [{"type": "text", "text": "real history"}]}] + + def fake_post(url, headers=None, json=None, stream=None, timeout=None, proxies=None): + resp = MagicMock() + resp.status_code = 200 + resp.iter_lines.return_value = iter([b"data: [DONE]"]) + resp.__enter__ = lambda s: s + resp.__exit__ = MagicMock(return_value=False) + return resp + + with patch("llmcore.requests.post", side_effect=fake_post): + result = service.run_model_test(provider["id"]) + + self.assertEqual(result["status"], "healthy") + self.assertEqual(real_client.backend.history[0]["content"][0]["text"], "real history") + + +if __name__ == "__main__": + unittest.main()