Skip to content

Commit 9c05ac2

Browse files
feat: implement get_model TODO and fix critical telemetry bug (#647)
* feat: implement get_model TODO and fix critical telemetry bug

- Enhanced get_model() to use _get_models_for_provider for dynamic model discovery
- Integrated with existing dynamic fetching infrastructure
- Added proper fallback handling for unknown providers/models
- Improved parsing logic to handle provider/model formats correctly
- Fixed critical telemetry bug in OpenAI LLM where stripped model names caused warnings
- Changed _record_usage calls to pass full model name instead of base_model
- Resolves "Unknown model x-ai/grok-4-fast:free" warnings during evals
- Makes OpenAI implementation consistent with Anthropic
- Improved reasoning model detection using metadata instead of hardcoded checks
- Replaced _is_reasoner() function with model_meta.supports_reasoning
- Updated extra_body() to use ModelMeta parameter for better reasoning support
- Enhanced message preparation logic for reasoning models
- Added comprehensive test suite with 10 focused tests
- Tests static/dynamic model lookup, provider-only requests
- Validates error handling and fallback scenarios
- Covers new dynamic model fetching integration

All changes maintain backwards compatibility while significantly enhancing model discovery capabilities and fixing evaluation warnings.

* Update gptme/llm/models.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
1 parent 62534a9 commit 9c05ac2

File tree

3 files changed

+217
-50
lines changed

3 files changed

+217
-50
lines changed

gptme/llm/llm_openai.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,6 @@ def _prep_deepseek_reasoner(msgs: list[Message]) -> Generator[Message, None, Non
171171
yield from _merge_consecutive(_prep_o1(msgs[1:]))
172172

173173

174-
def _is_reasoner(base_model: str) -> bool:
175-
is_o1 = any(base_model.startswith(om) for om in ["o1", "o3", "o4"])
176-
is_deepseek_reasoner = base_model == "deepseek-reasoner"
177-
is_gpt5 = base_model.startswith("gpt-5")
178-
return is_o1 or is_deepseek_reasoner or is_gpt5
179-
180-
181174
@lru_cache(maxsize=2)
182175
def _is_proxy(client: "OpenAI") -> bool:
183176
proxy_url = get_config().get_env("LLM_PROXY_URL")
@@ -193,13 +186,15 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
193186
# top_p controls diversity, temperature controls randomness
194187

195188
from . import _get_base_model, get_provider_from_model # fmt: skip
189+
from .models import get_model # fmt: skip
196190

197191
provider = get_provider_from_model(model)
198192
client = get_client(provider)
199193
is_proxy = _is_proxy(client)
200194

201195
base_model = _get_base_model(model)
202-
is_reasoner = _is_reasoner(base_model)
196+
model_meta = get_model(model)
197+
is_reasoner = model_meta.supports_reasoning
203198

204199
# make the model name prefix with the provider if using LLM_PROXY, to make proxy aware of the provider
205200
api_model = model if is_proxy else base_model
@@ -216,9 +211,9 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
216211
top_p=TOP_P if not is_reasoner else NOT_GIVEN,
217212
tools=tools_dict if tools_dict else NOT_GIVEN,
218213
extra_headers=extra_headers(provider),
219-
extra_body=extra_body(provider, base_model),
214+
extra_body=extra_body(provider, model_meta),
220215
)
221-
_record_usage(response.usage, base_model)
216+
_record_usage(response.usage, model)
222217
choice = response.choices[0]
223218
result = []
224219
if choice.finish_reason == "tool_calls":
@@ -252,12 +247,14 @@ def extra_headers(provider: Provider) -> dict[str, str]:
252247
return headers
253248

254249

255-
def extra_body(provider: Provider, base_model: str) -> dict[str, Any]:
250+
def extra_body(provider: Provider, model_meta: ModelMeta) -> dict[str, Any]:
256251
"""Return extra body for the OpenAI API based on the model."""
257252
body: dict[str, Any] = {}
258253
if provider == "openrouter":
259-
if ":" in base_model:
260-
provider_override = base_model.split(":")[1]
254+
if model_meta.supports_reasoning:
255+
body["reasoning"] = {"enabled": True, "max_tokens": 20000}
256+
if "@" in model_meta.model:
257+
provider_override = model_meta.model.split("@")[1]
261258
body["provider"] = {
262259
"order": [provider_override],
263260
"allow_fallbacks": False,
@@ -269,13 +266,15 @@ def stream(
269266
messages: list[Message], model: str, tools: list[ToolSpec] | None
270267
) -> Generator[str, None, None]:
271268
from . import _get_base_model, get_provider_from_model # fmt: skip
269+
from .models import get_model # fmt: skip
272270

273271
provider = get_provider_from_model(model)
274272
client = get_client(provider)
275273
is_proxy = _is_proxy(client)
276274

277275
base_model = _get_base_model(model)
278-
is_reasoner = _is_reasoner(base_model)
276+
model_meta = get_model(model)
277+
is_reasoner = model_meta.supports_reasoning
279278

280279
# make the model name prefix with the provider if using LLM_PROXY, to make proxy aware of the provider
281280
api_model = model if is_proxy else base_model
@@ -294,7 +293,7 @@ def stream(
294293
stream=True,
295294
tools=tools_dict if tools_dict else NOT_GIVEN,
296295
extra_headers=extra_headers(provider),
297-
extra_body=extra_body(provider, base_model),
296+
extra_body=extra_body(provider, model_meta),
298297
stream_options={"include_usage": True},
299298
):
300299
from openai.types.chat import ChatCompletionChunk # fmt: skip
@@ -308,7 +307,7 @@ def stream(
308307

309308
# Record usage if available (typically in final chunk)
310309
if hasattr(chunk, "usage") and chunk.usage:
311-
_record_usage(chunk.usage, base_model)
310+
_record_usage(chunk.usage, model)
312311

313312
if not chunk.choices:
314313
continue
@@ -574,47 +573,51 @@ def get_available_models(provider: Provider) -> list[ModelMeta]:
574573
def openrouter_model_to_modelmeta(model_data: dict) -> ModelMeta:
575574
"""Convert OpenRouter model data to ModelMeta object."""
576575
pricing = model_data.get("pricing", {})
576+
price_input = float(pricing.get("prompt", 0)) * 1_000_000
577+
price_output = float(pricing.get("completion", 0)) * 1_000_000
578+
vision = "vision" in model_data.get("architecture", {}).get("modality", "")
579+
reasoning = "reasoning" in model_data.get("supported_parameters", [])
580+
include_reasoning = "include_reasoning" in model_data.get(
581+
"supported_parameters", []
582+
)
577583

578584
return ModelMeta(
579585
provider="openrouter",
580586
model=model_data.get("id", ""),
581587
context=model_data.get("context_length", 128_000),
582588
max_output=model_data.get("max_completion_tokens"),
583589
supports_streaming=True, # Most OpenRouter models support streaming
584-
supports_vision="vision"
585-
in model_data.get("architecture", {}).get("modality", ""),
586-
supports_reasoning=False, # Would need to check model-specific capabilities
587-
price_input=float(pricing.get("prompt", 0))
588-
* 1_000_000, # Convert to per-1M tokens
589-
price_output=float(pricing.get("completion", 0))
590-
* 1_000_000, # Convert to per-1M tokens
590+
supports_vision=vision,
591+
supports_reasoning=reasoning and include_reasoning,
592+
price_input=price_input,
593+
price_output=price_output,
591594
)
592595

593596

594597
def _prepare_messages_for_api(
595598
messages: list[Message], model: str, tools: list[ToolSpec] | None
596599
) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
597-
from . import _get_base_model # fmt: skip
598600
from .models import get_model # fmt: skip
599601

600602
model_meta = get_model(model)
601603

602-
is_o1 = _get_base_model(model).startswith("o1")
603-
if is_o1:
604-
messages = list(_prep_o1(messages))
605-
606-
# without this, deepseek-chat and reasoner can start outputting gibberish after tool calls
607-
# similarly, kimi-k2-instruct doesn't acknowledge tool responses/system messages without it, same with magistral
608-
# it probably applies to more models/providers, we should figure out which and perhaps make it default behavior
609-
# TODO: it seems to apply to a lot of reasoning models, should maybe be default behavior for all reasoning models?
604+
# o1 models need _prep_o1 applied to ALL messages (including first), but no merging
610605
if any(
611-
m in model_meta.model
612-
for m in [
613-
"deepseek-reasoner",
614-
"deepseek-chat",
615-
"kimi-k2-instruct",
616-
"magistral-medium-2506",
617-
]
606+
model_meta.model.startswith(om) for om in ["o1", "o3", "o4"]
607+
) or model_meta.model.startswith("gpt-5"):
608+
messages = list(_prep_o1(messages))
609+
# other reasoning models use deepseek reasoner prep (first message unchanged, then _prep_o1 on rest)
610+
elif (
611+
any(
612+
m in model_meta.model
613+
for m in [
614+
"deepseek-reasoner",
615+
"deepseek-chat",
616+
"kimi-k2",
617+
"magistral",
618+
]
619+
)
620+
or model_meta.supports_reasoning
618621
):
619622
messages = list(_prep_deepseek_reasoner(messages))
620623

gptme/llm/models.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,20 @@ class _ModelDictMeta(TypedDict):
306306
"price_output": 0.3,
307307
"supports_vision": True,
308308
},
309+
"moonshotai/kimi-k2": {
310+
"context": 262_144,
311+
"max_output": 262_144,
312+
"price_input": 0.38,
313+
"price_output": 1.52,
314+
"supports_vision": True,
315+
},
316+
"moonshotai/kimi-k2-0905": {
317+
"context": 262_144,
318+
"max_output": 262_144,
319+
"price_input": 0.38,
320+
"price_output": 1.52,
321+
"supports_vision": True,
322+
},
309323
},
310324
"nvidia": {},
311325
"azure": {},
@@ -351,24 +365,58 @@ def get_model(model: str) -> ModelMeta:
351365
model = get_recommended_model(provider)
352366
return get_model(f"{provider}/{model}")
353367

354-
if any(f"{provider}/" in model for provider in PROVIDERS):
355-
provider, model = cast(tuple[Provider, str], model.split("/", 1))
356-
if provider not in MODELS or model not in MODELS[provider]:
368+
# Check if model has provider/model format
369+
if any(model.startswith(f"{provider}/") for provider in PROVIDERS):
370+
provider_str, model_name = model.split("/", 1)
371+
372+
# Check if provider is known
373+
if provider_str in PROVIDERS:
374+
provider = cast(Provider, provider_str)
375+
376+
# First try static MODELS dict for performance
377+
if provider in MODELS and model_name in MODELS[provider]:
378+
return ModelMeta(provider, model_name, **MODELS[provider][model_name])
379+
380+
# For providers that support dynamic fetching, use _get_models_for_provider
381+
if provider == "openrouter":
382+
try:
383+
models = _get_models_for_provider(provider, dynamic_fetch=True)
384+
for model_meta in models:
385+
if model_meta.model == model_name:
386+
return model_meta
387+
except Exception:
388+
# Fall back to unknown model metadata
389+
pass
390+
391+
# Unknown model, use fallback metadata
357392
if provider not in ["openrouter", "local"]:
358393
log_warn_once(
359-
f"Unknown model: using fallback metadata for {provider}/{model}"
394+
f"Unknown model: using fallback metadata for {provider}/{model_name}"
360395
)
361-
return ModelMeta(provider, model, context=128_000)
362-
else:
363-
# try to find model in all providers
364-
for provider in MODELS:
365-
if model in MODELS[provider]:
366-
break
396+
return ModelMeta(provider, model_name, context=128_000)
367397
else:
398+
# Unknown provider
368399
logger.warning(f"Unknown model {model}, using fallback metadata")
369400
return ModelMeta(provider="unknown", model=model, context=128_000)
401+
else:
402+
# try to find model in all providers, starting with static models
403+
for provider in cast(list[Provider], MODELS.keys()):
404+
if model in MODELS[provider]:
405+
return ModelMeta(provider, model, **MODELS[provider][model])
406+
407+
# For model name without provider, also try dynamic fetching for openrouter
408+
try:
409+
openrouter_models = _get_models_for_provider(
410+
"openrouter", dynamic_fetch=True
411+
)
412+
for model_meta in openrouter_models:
413+
if model_meta.model == model:
414+
return model_meta
415+
except Exception:
416+
pass
370417

371-
return ModelMeta(provider, model, **MODELS[provider][model])
418+
logger.warning(f"Unknown model {model}, using fallback metadata")
419+
return ModelMeta(provider="unknown", model=model, context=128_000)
372420

373421

374422
def get_recommended_model(provider: Provider) -> str: # pragma: no cover

tests/test_llm_models.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from unittest.mock import patch
2+
3+
from gptme.llm.models import (
4+
get_model,
5+
_get_models_for_provider,
6+
ModelMeta,
7+
)
8+
9+
10+
def test_get_static_model():
11+
"""Test getting a model that exists in static MODELS dict."""
12+
model = get_model("openai/gpt-4o")
13+
assert model.provider == "openai"
14+
assert model.model == "gpt-4o"
15+
assert model.context > 0
16+
17+
18+
def test_get_model_provider_only():
19+
"""Test getting recommended model when only provider is given."""
20+
model = get_model("openai")
21+
assert model.provider == "openai"
22+
assert model.model == "gpt-5" # current recommended model
23+
24+
25+
def test_get_model_unknown_provider_model():
26+
"""Test fallback for unknown provider/model combination."""
27+
model = get_model("unknown-provider/unknown-model")
28+
assert model.provider == "unknown"
29+
assert model.model == "unknown-provider/unknown-model"
30+
assert model.context == 128_000 # fallback context
31+
32+
33+
def test_get_model_by_name_only():
34+
"""Test getting model by name only (searches all providers)."""
35+
model = get_model("gpt-4o")
36+
assert model.provider == "openai"
37+
assert model.model == "gpt-4o"
38+
39+
40+
def test_get_model_unknown_name_only():
41+
"""Test fallback for unknown model name without provider."""
42+
model = get_model("completely-unknown-model")
43+
assert model.provider == "unknown"
44+
assert model.model == "completely-unknown-model"
45+
assert model.context == 128_000
46+
47+
48+
@patch("gptme.llm.models._get_models_for_provider")
49+
def test_get_model_dynamic_fetch_success(mock_get_models):
50+
"""Test successful dynamic model fetching for OpenRouter."""
51+
# Mock a dynamic model
52+
dynamic_model = ModelMeta(
53+
provider="openrouter",
54+
model="test-dynamic-model",
55+
context=100_000,
56+
price_input=1.0,
57+
price_output=2.0,
58+
)
59+
mock_get_models.return_value = [dynamic_model]
60+
61+
model = get_model("openrouter/test-dynamic-model")
62+
assert model.provider == "openrouter"
63+
assert model.model == "test-dynamic-model"
64+
assert model.context == 100_000
65+
assert model.price_input == 1.0
66+
67+
mock_get_models.assert_called_once_with("openrouter", dynamic_fetch=True)
68+
69+
70+
@patch("gptme.llm.models._get_models_for_provider")
71+
def test_get_model_dynamic_fetch_failure(mock_get_models):
72+
"""Test fallback when dynamic model fetching fails."""
73+
mock_get_models.side_effect = Exception("API error")
74+
75+
model = get_model("openrouter/test-dynamic-model")
76+
assert model.provider == "openrouter"
77+
assert model.model == "test-dynamic-model"
78+
assert model.context == 128_000 # fallback
79+
80+
81+
@patch("gptme.llm.models._get_models_for_provider")
82+
def test_get_model_dynamic_fetch_model_not_found(mock_get_models):
83+
"""Test fallback when dynamic model is not found in results."""
84+
other_model = ModelMeta(provider="openrouter", model="other-model", context=100_000)
85+
mock_get_models.return_value = [other_model]
86+
87+
model = get_model("openrouter/test-dynamic-model")
88+
assert model.provider == "openrouter"
89+
assert model.model == "test-dynamic-model"
90+
assert model.context == 128_000 # fallback
91+
92+
93+
def test_get_models_for_provider():
94+
"""Test getting models for a specific provider."""
95+
# Test with static models only
96+
openai_models = _get_models_for_provider("openai", dynamic_fetch=False)
97+
assert len(openai_models) > 0
98+
assert all(m.provider == "openai" for m in openai_models)
99+
100+
101+
@patch("gptme.llm.models._get_models_for_provider")
102+
def test_get_model_name_only_with_dynamic_fetch(mock_get_models):
103+
"""Test model lookup by name only with dynamic fetching from OpenRouter."""
104+
# Mock OpenRouter dynamic model
105+
dynamic_model = ModelMeta(
106+
provider="openrouter", model="test-model", context=100_000
107+
)
108+
mock_get_models.return_value = [dynamic_model]
109+
110+
model = get_model("test-model")
111+
assert model.provider == "openrouter"
112+
assert model.model == "test-model"
113+
assert model.context == 100_000
114+
115+
# Should have tried OpenRouter dynamic fetch
116+
mock_get_models.assert_called_with("openrouter", dynamic_fetch=True)

0 commit comments

Comments (0)