Skip to content

Commit 4299cd0

Browse files
committed
feat: added support for groq provider
1 parent 2d8b602 commit 4299cd0

File tree

4 files changed

+36
-5
lines changed

4 files changed

+36
-5
lines changed

docs/providers.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ To use OpenRouter, set your API key:
4141
4242
export OPENROUTER_API_KEY="your-api-key"
4343
44+
Groq
45+
----
46+
47+
To use Groq, set your API key:
48+
49+
.. code-block:: sh
50+
51+
export GROQ_API_KEY="your-api-key"
52+
53+
xAI
54+
---
55+
56+
To use xAI, set your API key:
57+
58+
.. code-block:: sh
59+
60+
export XAI_API_KEY="your-api-key"
61+
4462
Local
4563
-----
4664

gptme/llm.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from collections.abc import Iterator
55
from functools import lru_cache
6+
from typing import cast
67

78
from rich import print
89

@@ -17,7 +18,12 @@
1718
from .llm_openai import init as init_openai
1819
from .llm_openai import stream as stream_openai
1920
from .message import Message, format_msgs, len_tokens
20-
from .models import MODELS, Provider, get_summary_model
21+
from .models import (
22+
MODELS,
23+
PROVIDERS_OPENAI,
24+
Provider,
25+
get_summary_model,
26+
)
2127
from .tools import ToolUse
2228

2329
logger = logging.getLogger(__name__)
@@ -27,7 +33,8 @@ def init_llm(llm: str):
2733
# set up API_KEY (if openai) and API_BASE (if local)
2834
config = get_config()
2935

30-
if llm in ["openai", "azure", "openrouter", "local", "xai"]:
36+
llm = cast(Provider, llm)
37+
if llm in PROVIDERS_OPENAI:
3138
init_openai(llm, config)
3239
assert get_openai_client()
3340
elif llm == "anthropic":

gptme/llm_openai.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .config import Config
66
from .constants import TEMPERATURE, TOP_P
77
from .message import Message, msgs2dicts
8+
from .models import Provider
89

910
if TYPE_CHECKING:
1011
from openai import OpenAI
@@ -21,7 +22,7 @@
2122
}
2223

2324

24-
def init(provider: str, config: Config):
25+
def init(provider: Provider, config: Config):
2526
global openai
2627
from openai import AzureOpenAI, OpenAI # fmt: skip
2728

@@ -42,6 +43,9 @@ def init(provider: str, config: Config):
4243
elif provider == "xai":
4344
api_key = config.get_env_required("XAI_API_KEY")
4445
openai = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
46+
elif provider == "groq":
47+
api_key = config.get_env_required("GROQ_API_KEY")
48+
openai = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
4549
elif provider == "local":
4650
# OPENAI_API_BASE renamed to OPENAI_BASE_URL: https://github.com/openai/openai-python/issues/745
4751
api_base = config.get_env("OPENAI_API_BASE")

gptme/models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ class _ModelDictMeta(TypedDict):
3737

3838

3939
# available providers
40-
Provider = Literal["openai", "anthropic", "azure", "openrouter", "xai", "local"]
41-
PROVIDERS = get_args(Provider)
40+
Provider = Literal["openai", "anthropic", "azure", "openrouter", "groq", "xai", "local"]
41+
PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider))
42+
PROVIDERS_OPENAI: list[Provider]
43+
PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "local"]
4244

4345
# default model
4446
DEFAULT_MODEL: ModelMeta | None = None

0 commit comments

Comments
 (0)