Commit eec8215

refactor: refactored provider-specific code into new files llm_openai.py and llm_anthropic.py

1 parent 580cc36 commit eec8215

4 files changed (+212 additions, -158 deletions)

gptme/constants.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,13 @@
+"""
+Constants
+"""
+
+# Optimized for code
+# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
+# TODO: make these configurable
+TEMPERATURE = 0
+TOP_P = 0.1
+
 # prefix for commands, e.g. /help
 CMDFIX = "/"

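For context, the comment dropped from llm.py below ("top_p controls diversity, temperature controls randomness") still explains these values: TEMPERATURE = 0 with a narrow TOP_P = 0.1 biases sampling toward deterministic, code-friendly output. A minimal sketch of how a provider module is expected to consume the new constants; the sampling_params helper is hypothetical, only the import path comes from this commit:

# Hypothetical consumer of gptme/constants.py; illustrates threading the
# sampling constants into a completion request the way the provider modules do.
from gptme.constants import TEMPERATURE, TOP_P

def sampling_params() -> dict:
    # temperature controls randomness; top_p controls diversity.
    return {"temperature": TEMPERATURE, "top_p": TOP_P}
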
gptme/llm.py

Lines changed: 38 additions & 158 deletions
@@ -1,67 +1,45 @@
 import logging
 import shutil
 import sys
-from collections.abc import Generator, Iterator
+from collections.abc import Iterator
+from typing import Literal
 
-from anthropic import Anthropic
-from openai import AzureOpenAI, OpenAI
 from rich import print
 
+from .llm_anthropic import chat as chat_anthropic
+from .llm_anthropic import get_client as get_anthropic_client
+from .llm_anthropic import init as init_anthropic
+from .llm_anthropic import stream as stream_anthropic
+from .llm_openai import chat as chat_openai
+from .llm_openai import get_client as get_openai_client
+from .llm_openai import init as init_openai
+from .llm_openai import stream as stream_openai
 from .config import get_config
 from .constants import PROMPT_ASSISTANT
-from .message import Message, len_tokens, msgs2dicts
+from .message import Message, len_tokens
 from .models import MODELS, get_summary_model
 from .util import extract_codeblocks
 
-# Optimized for code
-# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
-# TODO: make these configurable
-temperature = 0
-top_p = 0.1
-
 logger = logging.getLogger(__name__)
 
-oai_client: OpenAI | None = None
-anthropic_client: Anthropic | None = None
 
+Provider = Literal["openai", "anthropic", "azure", "openrouter", "local"]
 
-def init_llm(llm: str):
-    global oai_client, anthropic_client
 
+def init_llm(llm: str):
     # set up API_KEY (if openai) and API_BASE (if local)
     config = get_config()
 
-    if llm == "openai":
-        api_key = config.get_env_required("OPENAI_API_KEY")
-        oai_client = OpenAI(api_key=api_key)
-    elif llm == "azure":
-        api_key = config.get_env_required("AZURE_OPENAI_API_KEY")
-        azure_endpoint = config.get_env_required("AZURE_OPENAI_ENDPOINT")
-        oai_client = AzureOpenAI(
-            api_key=api_key,
-            api_version="2023-07-01-preview",
-            azure_endpoint=azure_endpoint,
-        )
+    if llm in ["openai", "azure", "openrouter", "local"]:
+        init_openai(llm, config)
+        assert get_openai_client()
     elif llm == "anthropic":
-        api_key = config.get_env_required("ANTHROPIC_API_KEY")
-        anthropic_client = Anthropic(
-            api_key=api_key,
-            max_retries=5,
-        )
-    elif llm == "openrouter":
-        api_key = config.get_env_required("OPENROUTER_API_KEY")
-        oai_client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
-    elif llm == "local":
-        api_base = config.get_env_required("OPENAI_API_BASE")
-        api_key = config.get_env("OPENAI_API_KEY") or "ollama"
-        oai_client = OpenAI(api_key=api_key, base_url=api_base)
+        init_anthropic(config)
+        assert get_anthropic_client()
     else:
         print(f"Error: Unknown LLM: {llm}")
         sys.exit(1)
 
-    # ensure we have initialized the client
-    assert oai_client or anthropic_client
-
 
 def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
     if stream:
@@ -74,128 +52,26 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
         return Message("assistant", response)
 
 
-def _chat_complete_openai(messages: list[Message], model: str) -> str:
-    # This will generate code and such, so we need appropriate temperature and top_p params
-    # top_p controls diversity, temperature controls randomness
-    assert oai_client, "LLM not initialized"
-    response = oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-    )
-    content = response.choices[0].message.content
-    assert content
-    return content
-
-
-def _chat_complete_anthropic(messages: list[Message], model: str) -> str:
-    assert anthropic_client, "LLM not initialized"
-    messages, system_message = _transform_system_messages_anthropic(messages)
-    response = anthropic_client.messages.create(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_message,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    )
-    content = response.content
-    assert content
-    assert len(content) == 1
-    return content[0].text  # type: ignore
-
-
 def _chat_complete(messages: list[Message], model: str) -> str:
-    if oai_client:
-        return _chat_complete_openai(messages, model)
-    elif anthropic_client:
-        return _chat_complete_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return chat_openai(messages, model)
+    elif provider == "anthropic":
+        return chat_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")
 
 
-def _transform_system_messages_anthropic(
-    messages: list[Message],
-) -> tuple[list[Message], str]:
-    # transform system messages into system kwarg for anthropic
-    # for first system message, transform it into a system kwarg
-    assert messages[0].role == "system"
-    system_prompt = messages[0].content
-    messages.pop(0)
-
-    # for any subsequent system messages, transform them into a <system> message
-    for i, message in enumerate(messages):
-        if message.role == "system":
-            messages[i] = Message(
-                "user",
-                content=f"<system>{message.content}</system>",
-            )
-
-    # find consecutive user role messages and merge them into a single <system> message
-    messages_new: list[Message] = []
-    while messages:
-        message = messages.pop(0)
-        if messages_new and messages_new[-1].role == "user":
-            messages_new[-1] = Message(
-                "user",
-                content=f"{messages_new[-1].content}\n{message.content}",
-            )
-        else:
-            messages_new.append(message)
-    messages = messages_new
-
-    return messages, system_prompt
-
-
 def _stream(messages: list[Message], model: str) -> Iterator[str]:
-    if oai_client:
-        return _stream_openai(messages, model)
-    elif anthropic_client:
-        return _stream_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return stream_openai(messages, model)
+    elif provider == "anthropic":
+        return stream_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")
 
 
-def _stream_openai(messages: list[Message], model: str) -> Generator[str, None, None]:
-    assert oai_client, "LLM not initialized"
-    stop_reason = None
-    for chunk in oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-        stream=True,
-        # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
-        # TODO: make this better
-        max_tokens=1000 if not model.startswith("gpt-") else 4096,
-    ):
-        if not chunk.choices:  # type: ignore
-            # Got a chunk with no choices, Azure always sends one of these at the start
-            continue
-        stop_reason = chunk.choices[0].finish_reason  # type: ignore
-        content = chunk.choices[0].delta.content  # type: ignore
-        if content:
-            yield content
-    logger.debug(f"Stop reason: {stop_reason}")
-
-
-def _stream_anthropic(
-    messages: list[Message], model: str
-) -> Generator[str, None, None]:
-    messages, system_prompt = _transform_system_messages_anthropic(messages)
-    assert anthropic_client, "LLM not initialized"
-    with anthropic_client.messages.stream(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_prompt,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    ) as stream:
-        yield from stream.text_stream
-
-
 def _reply_stream(messages: list[Message], model: str) -> Message:
     print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")
 
@@ -236,11 +112,14 @@ def print_clear():
     return Message("assistant", output)
 
 
-def _client_to_provider() -> str:
-    if oai_client:
-        if "openai" in oai_client.base_url.host:
+def _client_to_provider() -> Provider:
+    openai_client = get_openai_client()
+    anthropic_client = get_anthropic_client()
+    assert openai_client or anthropic_client, "No client initialized"
+    if openai_client:
+        if "openai" in openai_client.base_url.host:
             return "openai"
-        elif "openrouter" in oai_client.base_url.host:
+        elif "openrouter" in openai_client.base_url.host:
             return "openrouter"
         else:
             return "azure"
@@ -265,8 +144,9 @@ def summarize(content: str) -> str:
         Message("user", content=f"Summarize this:\n{content}"),
     ]
 
-    model = get_summary_model(_client_to_provider())
-    context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"]
+    provider = _client_to_provider()
+    model = get_summary_model(provider)
+    context_limit = MODELS[provider][model]["context"]
     if len_tokens(messages) > context_limit:
        raise ValueError(
            f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"

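After this refactor, llm.py is a thin dispatch layer: init_llm picks the provider once, and _chat_complete/_stream route through _client_to_provider to the new provider modules. A minimal sketch of the unchanged public surface, assuming the required API key is set in the gptme config; the model name is illustrative:

# Callers are unaffected by the refactor: init once, then reply().
from gptme.llm import init_llm, reply
from gptme.message import Message

init_llm("openai")  # "azure", "openrouter", "local" also take the OpenAI-compatible path
msgs = [
    Message("system", "You are a helpful assistant."),
    Message("user", "Write a docstring for init_llm."),
]
answer = reply(msgs, model="gpt-4", stream=False)  # returns Message("assistant", ...)
print(answer.content)
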
gptme/llm_anthropic.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+from collections.abc import Generator
+
+from anthropic import Anthropic
+
+from .constants import TEMPERATURE, TOP_P
+from .message import Message, msgs2dicts
+
+anthropic: Anthropic | None = None
+
+
+def init(config):
+    global anthropic
+    api_key = config.get_env_required("ANTHROPIC_API_KEY")
+    anthropic = Anthropic(
+        api_key=api_key,
+        max_retries=5,
+    )
+
+
+def get_client() -> Anthropic | None:
+    return anthropic
+
+
+def chat(messages: list[Message], model: str) -> str:
+    assert anthropic, "LLM not initialized"
+    messages, system_messages = _transform_system_messages(messages)
+    response = anthropic.messages.create(
+        model=model,
+        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        system=system_messages,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        max_tokens=4096,
+    )
+    content = response.content
+    assert content
+    assert len(content) == 1
+    return content[0].text  # type: ignore
+
+
+def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
+    messages, system_messages = _transform_system_messages(messages)
+    assert anthropic, "LLM not initialized"
+    with anthropic.messages.stream(
+        model=model,
+        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
+        system=system_messages,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        max_tokens=4096,
+    ) as stream:
+        yield from stream.text_stream
+
+
+def _transform_system_messages(
+    messages: list[Message],
+) -> tuple[list[Message], str]:
+    # transform system messages into system kwarg for anthropic
+    # for first system message, transform it into a system kwarg
+    assert messages[0].role == "system"
+    system_prompt = messages[0].content
+    messages.pop(0)
+
+    # for any subsequent system messages, transform them into a <system> message
+    for i, message in enumerate(messages):
+        if message.role == "system":
+            messages[i] = Message(
+                "user",
+                content=f"<system>{message.content}</system>",
+            )
+
+    # find consecutive user role messages and merge them into a single <system> message
+    messages_new: list[Message] = []
+    while messages:
+        message = messages.pop(0)
+        if messages_new and messages_new[-1].role == "user":
+            messages_new[-1] = Message(
+                "user",
+                content=f"{messages_new[-1].content}\n{message.content}",
+            )
+        else:
+            messages_new.append(message)
+    messages = messages_new
+
+    return messages, system_prompt

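The _transform_system_messages helper is easiest to follow on a concrete input: the leading system message becomes Anthropic's system kwarg, later system messages are wrapped in <system> tags, and consecutive user messages are merged. A minimal sketch of the before/after, assuming Message is the (role, content) container used throughout the diff:

from gptme.llm_anthropic import _transform_system_messages
from gptme.message import Message

msgs = [
    Message("system", "You are gptme."),  # becomes the returned system prompt
    Message("user", "Run the tests."),
    Message("system", "Tests passed."),  # wrapped in <system> tags, then merged
    Message("user", "Great, now commit."),
]
msgs, system_prompt = _transform_system_messages(msgs)
# system_prompt: "You are gptme."
# msgs: a single user message with content
#   "Run the tests.\n<system>Tests passed.</system>\nGreat, now commit."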