"""Add the ``concurrency_group`` column to ``llm_models``.

Revision ID: add_llm_concurrency_group
Revises: a1b2c3d4e5f6
Create Date: 2026-04-07 00:00:00.000000
"""

# Alembic revision identifiers.
revision = "add_llm_concurrency_group"
down_revision = "a1b2c3d4e5f6"
branch_labels = None
depends_on = None

# Raw DDL is used (instead of op.add_column / op.drop_column) so that the
# IF [NOT] EXISTS guards make both directions safe to re-run.
_ADD_COLUMN_SQL = "ALTER TABLE llm_models ADD COLUMN IF NOT EXISTS concurrency_group VARCHAR(100)"
_DROP_COLUMN_SQL = "ALTER TABLE llm_models DROP COLUMN IF EXISTS concurrency_group"


def upgrade() -> None:
    """Add ``llm_models.concurrency_group`` if it is not already present."""
    op.execute(_ADD_COLUMN_SQL)


def downgrade() -> None:
    """Drop ``llm_models.concurrency_group`` if it exists."""
    op.execute(_DROP_COLUMN_SQL)
- - Args: - answer_text: Main reply text (may be partial during streaming). - thinking_text: Reasoning/thinking content shown in a collapsed section. - streaming: If True, appends a cursor glyph to indicate in-progress output. - tool_status_lines: Override list for image streaming (which maintains its - own done-list; pass None to use the default text-streaming state). - agent_name: Override the default _agent_name (for image streaming context). - """ _name = agent_name if agent_name is not None else _agent_name elements = [] - # Tool status section. - # For the primary text-streaming path we use the split running/done dicts; - # callers may pass an explicit list (image streaming) as override. - if tool_status_lines is not None: - # Caller-supplied override (image path): plain list, no split needed. - if tool_status_lines: - elements.append({ - "tag": "markdown", - "content": "\n".join(tool_status_lines[-_TOOL_STATUS_KEEP_LINES:]), - }) - elements.append({"tag": "hr"}) - else: - # Primary text-streaming path: show done history + any still-running tools. - # _tool_status_running entries are removed when the tool completes, - # so only genuinely in-flight tools appear here. - done_visible = _tool_status_done[-_TOOL_STATUS_KEEP_LINES:] - running_visible = list(_tool_status_running.values()) - all_visible = done_visible + running_visible - if all_visible: - elements.append({ - "tag": "markdown", - "content": "\n".join(all_visible), - }) - elements.append({"tag": "hr"}) - - # Thinking section: collapsed grey block if thinking_text: think_preview = thinking_text[:200].replace("\n", " ") elements.append({ @@ -857,18 +814,7 @@ async def _flush_stream(reason: str, force: bool = False): return accumulated = "".join(_stream_buffer) if cardkit_card_id: - # Build composite content: tool status lines + answer text. 
- # This mirrors the IM Patch path where _build_card() includes the - # tool status section, so CardKit users also see which tools are - # running or completed during the LLM turn. - done_visible = _tool_status_done[-_TOOL_STATUS_KEEP_LINES:] - running_visible = list(_tool_status_running.values()) - all_tool_lines = done_visible + running_visible - if all_tool_lines: - tool_section = "\n".join(all_tool_lines) - cardkit_text = f"{tool_section}\n---\n{accumulated}" if accumulated else tool_section - else: - cardkit_text = accumulated + cardkit_text = accumulated if cardkit_text != _last_flushed_text: cardkit_sequence += 1 try: @@ -887,7 +833,7 @@ async def _flush_stream(reason: str, force: bool = False): logger.warning(f"[Feishu] CardKit stream failed: {e}") elif msg_id_for_patch: card = _build_card(accumulated, "".join(_thinking_buffer), streaming=True) - current_hash = hash(accumulated + "".join(_thinking_buffer) + str(_tool_status_done) + str(list(_tool_status_running.values()))) + current_hash = hash(accumulated + "".join(_thinking_buffer)) if reason == "heartbeat" and current_hash == _last_flushed_hash: return _last_flushed_hash = current_hash @@ -906,30 +852,6 @@ async def _ws_on_thinking(text: str): _thinking_buffer.append(text) await _flush_stream("thinking") - async def _ws_on_tool_call(evt: dict): - """Receive tool call status events and update the card's progress section. - - Uses the tool's call_id as the dict key so each tool shows only its - latest state. When a tool completes the "running" entry is removed from - _tool_status_running and a "done" line is appended to _tool_status_done, - ensuring finished tools never linger as ⏳ in the card. - """ - tool_name = evt.get("name") or "unknown_tool" - # Use call_id when available (unique per invocation); fall back to name. - call_id = evt.get("call_id") or tool_name - status = (evt.get("status") or "").lower() - if status == "running": - # Register as in-flight; will be removed when "done" arrives. 
# Per-provider concurrency limits used when no custom group matches. The
# "custom" entry doubles as the fallback for providers that are not listed.
_DEFAULT_CONCURRENCY: dict[str, int] = {
    "openai": 5,
    "anthropic": 5,
    "deepseek": 3,
    "qwen": 5,
    "gemini": 5,
    "azure": 5,
    "ollama": 2,
    "vllm": 2,
    "openrouter": 5,
    "minimax": 3,
    "zhipu": 3,
    "custom": 5,
}


def _group_prefix(key: str) -> str:
    """Return the group/provider part of a ``"<name>:<key-hash>"`` semaphore key.

    The hash suffix is hexadecimal and never contains ``":"``, so splitting
    from the right keeps names that themselves contain a colon intact.
    """
    return key.rsplit(":", 1)[0]


class LLMConcurrencyError(Exception):
    """Raised when acquiring a concurrency slot times out.

    Attributes:
        group: The semaphore key (``"<name>:<key-hash>"``) that was saturated.
        timeout: The number of seconds that were waited before giving up.
    """

    def __init__(self, group: str, timeout: float) -> None:
        self.group = group
        self.timeout = timeout
        super().__init__(
            f"LLM concurrency limit reached for group '{group}': "
            f"could not acquire a slot within {timeout}s. "
            f"Consider increasing max_concurrency for this group."
        )


class ConcurrencyManager:
    """Singleton-style manager that owns all concurrency semaphores.

    One ``asyncio.Semaphore`` is kept per ``"<group>:<api-key-hash>"`` key, so
    each API key gets its own budget inside a group. ``<group>`` is the name
    of a custom group configured via :meth:`configure_groups`, or the provider
    name when no custom group matches.
    """

    _instance: "ConcurrencyManager | None" = None

    def __new__(cls) -> "ConcurrencyManager":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every ConcurrencyManager() call; this flag lets
            # it initialise the shared instance only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._semaphores: dict[str, asyncio.Semaphore] = {}
        self._config: dict[str, dict] = {}
        self._default_config: dict[str, int] = dict(_DEFAULT_CONCURRENCY)
        self._active_counts: dict[str, int] = {}
        self._initialized = True

    def _api_key_hash(self, api_key: str) -> str:
        """Return the first 16 hex chars of the SHA-256 of *api_key*.

        A missing key (``None`` or ``""``) is hashed as the empty string so
        that callers without an API key do not crash here.
        """
        return hashlib.sha256((api_key or "").encode()).hexdigest()[:16]

    def resolve_group(self, provider: str, model: str, api_key: str) -> str:
        """Return the semaphore key for a request.

        Custom groups are checked in configuration order. A group matches when
        it lists *model*, or when it lists *provider* and has no model list at
        all. If nothing matches, the key falls back to the provider name.

        Returns:
            ``"<group-or-provider>:<api-key-hash>"``.
        """
        key_hash = self._api_key_hash(api_key)
        for group_name, cfg in self._config.items():
            models: list[str] = cfg.get("models", [])
            providers: list[str] = cfg.get("providers", [])
            if model in models:
                return f"{group_name}:{key_hash}"
            if provider in providers and not models:
                return f"{group_name}:{key_hash}"
        return f"{provider}:{key_hash}"

    def _get_or_create_semaphore(self, group: str) -> asyncio.Semaphore:
        """Return the semaphore for key *group*, creating it on first use.

        The limit comes from the custom group config when one exists for the
        key's prefix, otherwise from the per-provider defaults, otherwise from
        the ``"custom"`` default.
        """
        existing = self._semaphores.get(group)
        if existing is not None:
            return existing

        fallback = _DEFAULT_CONCURRENCY.get("custom", 5)
        prefix = _group_prefix(group)
        custom_cfg = self._config.get(prefix)
        if custom_cfg:
            max_c = custom_cfg.get("max_concurrency", fallback)
        else:
            max_c = self._default_config.get(prefix, fallback)

        sem = asyncio.Semaphore(max_c)
        self._semaphores[group] = sem
        # A brand-new semaphore has no holders, whatever an old counter says.
        self._active_counts[group] = 0
        logger.debug("Created semaphore for group '{}' with max_concurrency={}", group, max_c)
        return sem

    async def acquire(
        self,
        provider: str,
        model: str,
        api_key: str,
        timeout: float = 60.0,
    ) -> str:
        """Wait for a concurrency slot and return the key it was taken from.

        Args:
            provider: Provider name of the model being called.
            model: Model name being called.
            api_key: API key used for the call; only its hash is stored.
            timeout: Maximum number of seconds to wait for a free slot.

        Returns:
            The semaphore key, which must be passed back to :meth:`release`.

        Raises:
            LLMConcurrencyError: If no slot became free within *timeout*.
        """
        group = self.resolve_group(provider, model, api_key)
        sem = self._get_or_create_semaphore(group)

        try:
            await asyncio.wait_for(sem.acquire(), timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise LLMConcurrencyError(group, timeout) from exc

        # There is no await between obtaining the permit and recording it, so
        # the bookkeeping cannot be interleaved with other tasks or cancelled.
        self._active_counts[group] = self._active_counts.get(group, 0) + 1
        return group

    async def release(self, group: str) -> None:
        """Give back a slot previously obtained from :meth:`acquire`.

        Releases that cannot be matched to an active slot are ignored with a
        warning. This covers unknown keys, double releases, and holders whose
        semaphore was discarded by :meth:`configure_groups`. Releasing anyway
        would raise the semaphore above its configured limit.
        """
        sem = self._semaphores.get(group)
        if sem is None:
            logger.warning("release() called for unknown group '{}'", group)
            return

        active = self._active_counts.get(group, 0)
        if active <= 0:
            logger.warning(
                "release() called for group '{}' with no active slots; ignoring",
                group,
            )
            return

        self._active_counts[group] = active - 1
        sem.release()

    def configure_groups(self, configs: list[dict]) -> None:
        """Add or replace custom concurrency groups.

        Each entry needs a ``"name"`` and may carry ``"max_concurrency"``
        (default 5), ``"models"`` and ``"providers"``. All entries are
        validated before any state changes, so a bad entry leaves the current
        configuration untouched.

        Cached semaphores belonging to any configured group are discarded so
        that the new limit applies to the next request.

        NOTE(review): requests still waiting on a discarded semaphore are not
        migrated to its replacement and may run into their acquire timeout.

        Raises:
            KeyError: If an entry has no ``"name"``.
            ValueError: If ``max_concurrency`` is not a positive integer.
        """
        parsed: dict[str, dict] = {}
        for cfg in configs:
            name: str = cfg["name"]
            limit = cfg.get("max_concurrency", 5)
            if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
                raise ValueError(
                    f"max_concurrency for group '{name}' must be a positive "
                    f"integer, got {limit!r}"
                )
            parsed[name] = {
                "max_concurrency": limit,
                "models": list(cfg.get("models", [])),
                "providers": list(cfg.get("providers", [])),
            }
        self._config.update(parsed)

        stale_keys = [k for k in self._semaphores if _group_prefix(k) in self._config]
        for key in stale_keys:
            del self._semaphores[key]
            self._active_counts.pop(key, None)

        logger.info(
            "Updated concurrency groups: {}",
            list(self._config.keys()),
        )

    def get_status(self) -> dict:
        """Return a per-group snapshot of limits and active slot counts.

        Every configured group is listed, followed by any provider-default
        group that currently has a semaphore. ``active_count`` is summed over
        all API keys of the group.
        """
        groups: dict[str, dict] = {
            name: {
                "max_concurrency": cfg.get("max_concurrency", 5),
                "active_count": 0,
                # Copies, so callers cannot mutate the live configuration.
                "models": list(cfg.get("models", [])),
                "providers": list(cfg.get("providers", [])),
            }
            for name, cfg in self._config.items()
        }

        fallback = _DEFAULT_CONCURRENCY.get("custom", 5)
        for key in self._semaphores:
            prefix = _group_prefix(key)
            entry = groups.get(prefix)
            if entry is None:
                entry = groups[prefix] = {
                    "max_concurrency": self._default_config.get(prefix, fallback),
                    "active_count": 0,
                    "models": [],
                    "providers": [],
                }
            entry["active_count"] += self._active_counts.get(key, 0)

        return groups

    @classmethod
    def reset(cls) -> None:
        """Forget the shared instance so the next construction starts fresh.

        Names already bound to the old instance, such as the module-level
        ``concurrency_manager``, keep pointing at it.
        """
        cls._instance = None


concurrency_manager = ConcurrencyManager()


@asynccontextmanager
async def concurrency_limit(
    provider: str,
    model: str,
    api_key: str,
    timeout: float = 60.0,
) -> AsyncIterator[str]:
    """Hold one concurrency slot for the duration of the ``async with`` body.

    Yields the semaphore key the slot was taken from. The slot is released
    even if the body raises.

    Raises:
        LLMConcurrencyError: If no slot became free within *timeout*.
    """
    group = await concurrency_manager.acquire(provider, model, api_key, timeout=timeout)
    try:
        yield group
    finally:
        await concurrency_manager.release(group)