Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions backend/alembic/versions/add_llm_concurrency_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""add llm concurrency_group

Revision ID: add_llm_concurrency_group
Revises: a1b2c3d4e5f6
Create Date: 2026-04-07 00:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# Alembic revision identifiers: `revision` names this migration and
# `down_revision` links it into the version graph after a1b2c3d4e5f6.
revision = "add_llm_concurrency_group"
down_revision = "a1b2c3d4e5f6"
branch_labels = None
depends_on = None

def upgrade() -> None:
    """Add the nullable llm_models.concurrency_group column."""
    # Raw SQL with IF NOT EXISTS keeps the migration safe to re-run if the
    # column was already created out-of-band.
    ddl = "ALTER TABLE llm_models ADD COLUMN IF NOT EXISTS concurrency_group VARCHAR(100)"
    op.execute(ddl)

def downgrade() -> None:
    """Drop the llm_models.concurrency_group column added by upgrade()."""
    # IF EXISTS mirrors the idempotent upgrade so downgrade can also re-run.
    ddl = "ALTER TABLE llm_models DROP COLUMN IF EXISTS concurrency_group"
    op.execute(ddl)
83 changes: 2 additions & 81 deletions backend/app/api/feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
# The per-model request_timeout field takes precedence — see _get_llm_timeout().
_LLM_TIMEOUT_SECONDS_DEFAULT = 180.0

# Number of tool status lines to keep visible in the Feishu card.
# Shows the last N non-running lines plus any active "running" entry.
_TOOL_STATUS_KEEP_LINES = 20


def _get_llm_timeout(model) -> float:
"""Get effective LLM timeout for the Feishu channel.
Expand Down Expand Up @@ -728,8 +724,6 @@ async def _feishu_file_sender(file_path, msg: str = ""):
_FLUSH_INTERVAL_CARDKIT = 0.5
_FLUSH_INTERVAL_PATCH = 1.0
_agent_name = agent_obj.name if agent_obj else "AI 回复"
_tool_status_running: dict[str, str] = {}
_tool_status_done: list[str] = []
_patch_queue = _SerialPatchQueue()
_heartbeat_task: asyncio.Task | None = None
_llm_done = False
Expand All @@ -741,49 +735,12 @@ def _build_card(
answer_text: str,
thinking_text: str = "",
streaming: bool = False,
tool_status_lines: list[str] | None = None,
agent_name: str | None = None,
) -> dict:
"""Build a Feishu interactive card for streaming replies.

Args:
answer_text: Main reply text (may be partial during streaming).
thinking_text: Reasoning/thinking content shown in a collapsed section.
streaming: If True, appends a cursor glyph to indicate in-progress output.
tool_status_lines: Override list for image streaming (which maintains its
own done-list; pass None to use the default text-streaming state).
agent_name: Override the default _agent_name (for image streaming context).
"""
_name = agent_name if agent_name is not None else _agent_name

elements = []

# Tool status section.
# For the primary text-streaming path we use the split running/done dicts;
# callers may pass an explicit list (image streaming) as override.
if tool_status_lines is not None:
# Caller-supplied override (image path): plain list, no split needed.
if tool_status_lines:
elements.append({
"tag": "markdown",
"content": "\n".join(tool_status_lines[-_TOOL_STATUS_KEEP_LINES:]),
})
elements.append({"tag": "hr"})
else:
# Primary text-streaming path: show done history + any still-running tools.
# _tool_status_running entries are removed when the tool completes,
# so only genuinely in-flight tools appear here.
done_visible = _tool_status_done[-_TOOL_STATUS_KEEP_LINES:]
running_visible = list(_tool_status_running.values())
all_visible = done_visible + running_visible
if all_visible:
elements.append({
"tag": "markdown",
"content": "\n".join(all_visible),
})
elements.append({"tag": "hr"})

# Thinking section: collapsed grey block
if thinking_text:
think_preview = thinking_text[:200].replace("\n", " ")
elements.append({
Expand Down Expand Up @@ -857,18 +814,7 @@ async def _flush_stream(reason: str, force: bool = False):
return
accumulated = "".join(_stream_buffer)
if cardkit_card_id:
# Build composite content: tool status lines + answer text.
# This mirrors the IM Patch path where _build_card() includes the
# tool status section, so CardKit users also see which tools are
# running or completed during the LLM turn.
done_visible = _tool_status_done[-_TOOL_STATUS_KEEP_LINES:]
running_visible = list(_tool_status_running.values())
all_tool_lines = done_visible + running_visible
if all_tool_lines:
tool_section = "\n".join(all_tool_lines)
cardkit_text = f"{tool_section}\n---\n{accumulated}" if accumulated else tool_section
else:
cardkit_text = accumulated
cardkit_text = accumulated
if cardkit_text != _last_flushed_text:
cardkit_sequence += 1
try:
Expand All @@ -887,7 +833,7 @@ async def _flush_stream(reason: str, force: bool = False):
logger.warning(f"[Feishu] CardKit stream failed: {e}")
elif msg_id_for_patch:
card = _build_card(accumulated, "".join(_thinking_buffer), streaming=True)
current_hash = hash(accumulated + "".join(_thinking_buffer) + str(_tool_status_done) + str(list(_tool_status_running.values())))
current_hash = hash(accumulated + "".join(_thinking_buffer))
if reason == "heartbeat" and current_hash == _last_flushed_hash:
return
_last_flushed_hash = current_hash
Expand All @@ -906,30 +852,6 @@ async def _ws_on_thinking(text: str):
_thinking_buffer.append(text)
await _flush_stream("thinking")

async def _ws_on_tool_call(evt: dict):
"""Receive tool call status events and update the card's progress section.

Uses the tool's call_id as the dict key so each tool shows only its
latest state. When a tool completes the "running" entry is removed from
_tool_status_running and a "done" line is appended to _tool_status_done,
ensuring finished tools never linger as ⏳ in the card.
"""
tool_name = evt.get("name") or "unknown_tool"
# Use call_id when available (unique per invocation); fall back to name.
call_id = evt.get("call_id") or tool_name
status = (evt.get("status") or "").lower()
if status == "running":
# Register as in-flight; will be removed when "done" arrives.
_tool_status_running[call_id] = f"⏳ Tool running: `{tool_name}`"
elif status == "done":
# Remove from running dict so the ⏳ icon disappears immediately.
_tool_status_running.pop(call_id, None)
_tool_status_done.append(f"✅ Tool done: `{tool_name}`")
else:
_tool_status_running.pop(call_id, None)
_tool_status_done.append(f"ℹ️ Tool update: `{tool_name}` ({status or 'unknown'})")
await _flush_stream("tool")

async def _heartbeat():
while not _llm_done:
await asyncio.sleep(_FLUSH_INTERVAL_CARDKIT if cardkit_card_id else _FLUSH_INTERVAL_PATCH)
Expand All @@ -948,7 +870,6 @@ async def _heartbeat():
user_id=platform_user_id,
on_chunk=_ws_on_chunk,
on_thinking=_ws_on_thinking,
on_tool_call=_ws_on_tool_call,
)
finally:
_llm_done = True
Expand Down
212 changes: 212 additions & 0 deletions backend/app/services/llm_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Concurrency limiter for LLM API calls.

Uses asyncio.Semaphore instances keyed by group name to limit concurrent LLM
requests per group. Groups can be custom-defined (by model / provider) or
fall back to per-provider defaults.
"""

from __future__ import annotations

import asyncio
import hashlib
from contextlib import asynccontextmanager
from typing import AsyncIterator

from loguru import logger


# Default per-provider concurrency caps, used when no custom group from
# configure_groups() matches a request.  The "custom" entry doubles as the
# last-resort cap for providers not listed here (see
# ConcurrencyManager._get_or_create_semaphore).
_DEFAULT_CONCURRENCY: dict[str, int] = {
    "openai": 5,
    "anthropic": 5,
    "deepseek": 3,
    "qwen": 5,
    "gemini": 5,
    "azure": 5,
    "ollama": 2,
    "vllm": 2,
    "openrouter": 5,
    "minimax": 3,
    "zhipu": 3,
    "custom": 5,
}


class LLMConcurrencyError(Exception):
    """Raised when acquiring a concurrency slot times out.

    Attributes are kept on the instance so callers can build structured
    error payloads instead of parsing the message.
    """

    def __init__(self, group: str, timeout: float) -> None:
        # group: the resolved group key whose limit was hit.
        # timeout: seconds waited before giving up.
        self.group = group
        self.timeout = timeout
        message = (
            f"LLM concurrency limit reached for group '{group}': "
            f"could not acquire a slot within {timeout}s. "
            f"Consider increasing max_concurrency for this group."
        )
        super().__init__(message)


class ConcurrencyManager:
    """Singleton-style manager that owns all concurrency semaphores.

    Group keys have the form "<name>:<api_key_hash>", so the same custom
    group or provider used with different API keys gets independent limits.
    Limits come from custom configs (configure_groups) with a fallback to
    the per-provider defaults in _DEFAULT_CONCURRENCY.
    """

    _instance: ConcurrencyManager | None = None

    def __new__(cls) -> ConcurrencyManager:
        # Module-wide singleton: every ConcurrencyManager() call returns the
        # same object; _initialized lets __init__ run its setup only once.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        # Full group key -> semaphore enforcing that group's limit.
        self._semaphores: dict[str, asyncio.Semaphore] = {}
        # Custom group name -> {"max_concurrency", "models", "providers"}.
        self._config: dict[str, dict] = {}
        # Copied so later mutation cannot corrupt the module-level table.
        self._default_config: dict[str, int] = dict(_DEFAULT_CONCURRENCY)
        # Full group key -> number of slots currently held (for get_status).
        self._active_counts: dict[str, int] = {}
        # Full group key -> lock serializing _active_counts updates.
        self._locks: dict[str, asyncio.Lock] = {}
        self._initialized = True

    def _api_key_hash(self, api_key: str) -> str:
        """Return a short, stable fingerprint of the API key."""
        return hashlib.sha256(api_key.encode()).hexdigest()[:16]

    def resolve_group(self, provider: str, model: str, api_key: str) -> str:
        """Map (provider, model, api_key) to a concurrency group key.

        Precedence: a custom group listing this exact model wins, then a
        custom group listing the provider (only when that group defines no
        model list), then the provider's own default group.
        """
        key_hash = self._api_key_hash(api_key)
        for group_name, cfg in self._config.items():
            models: list[str] = cfg.get("models", [])
            providers: list[str] = cfg.get("providers", [])
            if model in models:
                return f"{group_name}:{key_hash}"
            if provider in providers and not models:
                return f"{group_name}:{key_hash}"
        return f"{provider}:{key_hash}"

    def _get_or_create_semaphore(self, group: str) -> asyncio.Semaphore:
        """Return the semaphore for `group`, creating it lazily.

        The limit comes from the custom config whose name matches the
        group's prefix, else the provider defaults, else the "custom"
        fallback of 5.
        """
        if group in self._semaphores:
            return self._semaphores[group]

        group_prefix = group.split(":")[0]
        custom_cfg = self._config.get(group_prefix)
        if custom_cfg:
            max_c = custom_cfg.get("max_concurrency", _DEFAULT_CONCURRENCY.get("custom", 5))
        else:
            max_c = self._default_config.get(group_prefix, _DEFAULT_CONCURRENCY.get("custom", 5))

        sem = asyncio.Semaphore(max_c)
        self._semaphores[group] = sem
        self._active_counts.setdefault(group, 0)
        self._locks.setdefault(group, asyncio.Lock())
        logger.debug("Created semaphore for group '{}' with max_concurrency={}", group, max_c)
        return sem

    async def acquire(
        self,
        provider: str,
        model: str,
        api_key: str,
        timeout: float = 60.0,
    ) -> str:
        """Acquire one slot for the resolved group and return its key.

        Callers must pass the returned key back to release().

        Raises:
            LLMConcurrencyError: no slot became free within `timeout` seconds.
        """
        group = self.resolve_group(provider, model, api_key)
        sem = self._get_or_create_semaphore(group)

        try:
            await asyncio.wait_for(sem.acquire(), timeout=timeout)
        except asyncio.TimeoutError:
            # Suppress the TimeoutError chain: callers only need the
            # domain-level error, not the asyncio internals.
            raise LLMConcurrencyError(group, timeout) from None

        try:
            lock = self._locks.setdefault(group, asyncio.Lock())
            async with lock:
                self._active_counts[group] = self._active_counts.get(group, 0) + 1
        except BaseException:
            # `async with lock` is an await point and can be interrupted
            # (e.g. task cancellation); without this guard the just-acquired
            # semaphore slot would leak forever.
            sem.release()
            raise

        return group

    async def release(self, group: str) -> None:
        """Release a slot previously acquired for `group`.

        Unknown groups (e.g. after configure_groups() dropped the group's
        semaphore) are logged and ignored rather than raising.
        """
        sem = self._semaphores.get(group)
        if sem is None:
            logger.warning("release() called for unknown group '{}'", group)
            return

        lock = self._locks.setdefault(group, asyncio.Lock())
        async with lock:
            count = self._active_counts.get(group, 0)
            if count > 0:
                self._active_counts[group] = count - 1

        sem.release()

    def configure_groups(self, configs: list[dict]) -> None:
        """Install or overwrite custom group configs.

        Each config dict must carry "name" and may carry "max_concurrency",
        "models" and "providers".  Existing semaphores whose prefix matches
        a configured group are dropped so new limits take effect on the next
        acquire; in-flight holders keep their old semaphore object.
        """
        for cfg in configs:
            name: str = cfg["name"]
            self._config[name] = {
                "max_concurrency": cfg.get("max_concurrency", 5),
                "models": list(cfg.get("models", [])),
                "providers": list(cfg.get("providers", [])),
            }

        keys_to_reset = [
            k for k in list(self._semaphores)
            if k.split(":")[0] in self._config
        ]
        for k in keys_to_reset:
            del self._semaphores[k]
            self._active_counts.pop(k, None)
            self._locks.pop(k, None)

        logger.info(
            "Updated concurrency groups: {}",
            list(self._config.keys()),
        )

    def get_status(self) -> dict:
        """Return per-group status for monitoring.

        Maps group name -> {max_concurrency, active_count, models,
        providers}.  Active counts aggregate all api-key variants of each
        group; default provider groups appear only once a semaphore exists.
        """
        groups: dict[str, dict] = {}

        for group_name, cfg in self._config.items():
            active_entries = {
                k: self._active_counts.get(k, 0)
                for k in self._semaphores
                if k.split(":")[0] == group_name
            }
            total_active = sum(active_entries.values())
            max_c = cfg.get("max_concurrency", 5)
            groups[group_name] = {
                "max_concurrency": max_c,
                "active_count": total_active,
                "models": cfg.get("models", []),
                "providers": cfg.get("providers", []),
            }

        for sem_key in self._semaphores:
            prefix = sem_key.split(":")[0]
            if prefix not in self._config:
                if prefix not in groups:
                    max_c = self._default_config.get(
                        prefix, _DEFAULT_CONCURRENCY.get("custom", 5)
                    )
                    groups[prefix] = {
                        "max_concurrency": max_c,
                        "active_count": 0,
                        "models": [],
                        "providers": [],
                    }
                groups[prefix]["active_count"] += self._active_counts.get(sem_key, 0)

        return groups

    @classmethod
    def reset(cls) -> None:
        """Drop the singleton (test helper); next construction starts fresh."""
        cls._instance = None


concurrency_manager = ConcurrencyManager()


@asynccontextmanager
async def concurrency_limit(
    provider: str,
    model: str,
    api_key: str,
    timeout: float = 60.0,
) -> AsyncIterator[str]:
    """Hold one concurrency slot for the duration of the `async with` body.

    Yields the resolved group key; the slot is always released on exit,
    whether the body completes or raises.
    """
    group_key = await concurrency_manager.acquire(
        provider, model, api_key, timeout=timeout
    )
    try:
        yield group_key
    finally:
        await concurrency_manager.release(group_key)