 import logging
 import shutil
 import sys
-from collections.abc import Generator, Iterator
+from collections.abc import Iterator
+from typing import Literal

-from anthropic import Anthropic
-from openai import AzureOpenAI, OpenAI
 from rich import print

+from .llm_anthropic import chat as chat_anthropic
+from .llm_anthropic import get_client as get_anthropic_client
+from .llm_anthropic import init as init_anthropic
+from .llm_anthropic import stream as stream_anthropic
+from .llm_openai import chat as chat_openai
+from .llm_openai import get_client as get_openai_client
+from .llm_openai import init as init_openai
+from .llm_openai import stream as stream_openai
 from .config import get_config
 from .constants import PROMPT_ASSISTANT
-from .message import Message, len_tokens, msgs2dicts
+from .message import Message, len_tokens
 from .models import MODELS, get_summary_model
 from .util import extract_codeblocks

-# Optimized for code
-# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
-# TODO: make these configurable
-temperature = 0
-top_p = 0.1
-
 logger = logging.getLogger(__name__)

-oai_client: OpenAI | None = None
-anthropic_client: Anthropic | None = None

+Provider = Literal["openai", "anthropic", "azure", "openrouter", "local"]

-def init_llm(llm: str):
-    global oai_client, anthropic_client

+def init_llm(llm: str):
     # set up API_KEY (if openai) and API_BASE (if local)
     config = get_config()

-    if llm == "openai":
-        api_key = config.get_env_required("OPENAI_API_KEY")
-        oai_client = OpenAI(api_key=api_key)
-    elif llm == "azure":
-        api_key = config.get_env_required("AZURE_OPENAI_API_KEY")
-        azure_endpoint = config.get_env_required("AZURE_OPENAI_ENDPOINT")
-        oai_client = AzureOpenAI(
-            api_key=api_key,
-            api_version="2023-07-01-preview",
-            azure_endpoint=azure_endpoint,
-        )
+    if llm in ["openai", "azure", "openrouter", "local"]:
+        init_openai(llm, config)
+        assert get_openai_client()
     elif llm == "anthropic":
-        api_key = config.get_env_required("ANTHROPIC_API_KEY")
-        anthropic_client = Anthropic(
-            api_key=api_key,
-            max_retries=5,
-        )
-    elif llm == "openrouter":
-        api_key = config.get_env_required("OPENROUTER_API_KEY")
-        oai_client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
-    elif llm == "local":
-        api_base = config.get_env_required("OPENAI_API_BASE")
-        api_key = config.get_env("OPENAI_API_KEY") or "ollama"
-        oai_client = OpenAI(api_key=api_key, base_url=api_base)
+        init_anthropic(config)
+        assert get_anthropic_client()
     else:
         print(f"Error: Unknown LLM: {llm}")
        sys.exit(1)

-    # ensure we have initialized the client
-    assert oai_client or anthropic_client
-

 def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
     if stream:
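The new `llm_openai.py` and `llm_anthropic.py` modules are not part of this diff; only their `init`, `get_client`, `chat`, and `stream` entry points are visible through the imports above. A minimal sketch of the interface `init_llm()` now relies on might look like the following. This is an assumption adapted from the code removed elsewhere in this diff, showing only the plain-OpenAI branch, and the exact signatures may differ in the real modules:

```python
# Hypothetical sketch of a per-provider module (e.g. llm_openai.py); not in this diff.
from collections.abc import Iterator

from openai import OpenAI

from .message import Message, msgs2dicts

_client: OpenAI | None = None


def init(llm: str, config) -> None:
    """Create the module-level client from config (only the plain-OpenAI case shown)."""
    global _client
    _client = OpenAI(api_key=config.get_env_required("OPENAI_API_KEY"))


def get_client() -> OpenAI | None:
    return _client


def chat(messages: list[Message], model: str) -> str:
    """Blocking completion, mirroring the removed _chat_complete_openai()."""
    assert _client, "LLM not initialized"
    response = _client.chat.completions.create(
        model=model,
        messages=msgs2dicts(messages, openai=True),  # type: ignore
    )
    content = response.choices[0].message.content
    assert content
    return content


def stream(messages: list[Message], model: str) -> Iterator[str]:
    """Streaming variant, yielding text deltas as they arrive."""
    assert _client, "LLM not initialized"
    for chunk in _client.chat.completions.create(
        model=model,
        messages=msgs2dicts(messages, openai=True),  # type: ignore
        stream=True,
    ):
        if chunk.choices and (delta := chunk.choices[0].delta.content):
            yield delta
```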
@@ -74,128 +52,26 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
         return Message("assistant", response)


-def _chat_complete_openai(messages: list[Message], model: str) -> str:
-    # This will generate code and such, so we need appropriate temperature and top_p params
-    # top_p controls diversity, temperature controls randomness
-    assert oai_client, "LLM not initialized"
-    response = oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-    )
-    content = response.choices[0].message.content
-    assert content
-    return content
-
-
-def _chat_complete_anthropic(messages: list[Message], model: str) -> str:
-    assert anthropic_client, "LLM not initialized"
-    messages, system_message = _transform_system_messages_anthropic(messages)
-    response = anthropic_client.messages.create(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_message,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    )
-    content = response.content
-    assert content
-    assert len(content) == 1
-    return content[0].text  # type: ignore
-
-
 def _chat_complete(messages: list[Message], model: str) -> str:
-    if oai_client:
-        return _chat_complete_openai(messages, model)
-    elif anthropic_client:
-        return _chat_complete_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return chat_openai(messages, model)
+    elif provider == "anthropic":
+        return chat_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")


-def _transform_system_messages_anthropic(
-    messages: list[Message],
-) -> tuple[list[Message], str]:
-    # transform system messages into system kwarg for anthropic
-    # for first system message, transform it into a system kwarg
-    assert messages[0].role == "system"
-    system_prompt = messages[0].content
-    messages.pop(0)
-
-    # for any subsequent system messages, transform them into a <system> message
-    for i, message in enumerate(messages):
-        if message.role == "system":
-            messages[i] = Message(
-                "user",
-                content=f"<system>{message.content}</system>",
-            )
-
-    # find consecutive user role messages and merge them into a single <system> message
-    messages_new: list[Message] = []
-    while messages:
-        message = messages.pop(0)
-        if messages_new and messages_new[-1].role == "user":
-            messages_new[-1] = Message(
-                "user",
-                content=f"{messages_new[-1].content}\n{message.content}",
-            )
-        else:
-            messages_new.append(message)
-    messages = messages_new
-
-    return messages, system_prompt
-
-
 def _stream(messages: list[Message], model: str) -> Iterator[str]:
-    if oai_client:
-        return _stream_openai(messages, model)
-    elif anthropic_client:
-        return _stream_anthropic(messages, model)
+    provider = _client_to_provider()
+    if provider == "openai":
+        return stream_openai(messages, model)
+    elif provider == "anthropic":
+        return stream_anthropic(messages, model)
     else:
         raise ValueError("LLM not initialized")


-def _stream_openai(messages: list[Message], model: str) -> Generator[str, None, None]:
-    assert oai_client, "LLM not initialized"
-    stop_reason = None
-    for chunk in oai_client.chat.completions.create(
-        model=model,
-        messages=msgs2dicts(messages, openai=True),  # type: ignore
-        temperature=temperature,
-        top_p=top_p,
-        stream=True,
-        # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
-        # TODO: make this better
-        max_tokens=1000 if not model.startswith("gpt-") else 4096,
-    ):
-        if not chunk.choices:  # type: ignore
-            # Got a chunk with no choices, Azure always sends one of these at the start
-            continue
-        stop_reason = chunk.choices[0].finish_reason  # type: ignore
-        content = chunk.choices[0].delta.content  # type: ignore
-        if content:
-            yield content
-    logger.debug(f"Stop reason: {stop_reason}")
-
-
-def _stream_anthropic(
-    messages: list[Message], model: str
-) -> Generator[str, None, None]:
-    messages, system_prompt = _transform_system_messages_anthropic(messages)
-    assert anthropic_client, "LLM not initialized"
-    with anthropic_client.messages.stream(
-        model=model,
-        messages=msgs2dicts(messages, anthropic=True),  # type: ignore
-        system=system_prompt,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=4096,
-    ) as stream:
-        yield from stream.text_stream
-
-
 def _reply_stream(messages: list[Message], model: str) -> Message:
     print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")

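The Anthropic-specific helpers removed above (`_chat_complete_anthropic`, `_stream_anthropic`, `_transform_system_messages_anthropic`) are not deleted outright; per the new imports, their behavior presumably moves into `llm_anthropic.py`, which this diff does not show. A rough sketch of how its `chat` entry point could wrap the moved system-message handling, reusing the removed code and simplifying for brevity (names and structure are assumptions):

```python
# Hypothetical sketch of llm_anthropic.chat(); the real module is not in this diff.
from anthropic import Anthropic

from .message import Message, msgs2dicts

_client: Anthropic | None = None


def chat(messages: list[Message], model: str) -> str:
    assert _client, "LLM not initialized"
    # Anthropic takes the system prompt as a kwarg rather than a message, so split it
    # off first, as the removed _transform_system_messages_anthropic() did.
    # (The consecutive-user-message merging step is omitted here for brevity.)
    assert messages[0].role == "system"
    system_prompt = messages[0].content
    rest = [
        m if m.role != "system" else Message("user", f"<system>{m.content}</system>")
        for m in messages[1:]
    ]
    response = _client.messages.create(
        model=model,
        messages=msgs2dicts(rest, anthropic=True),  # type: ignore
        system=system_prompt,
        max_tokens=4096,
    )
    content = response.content
    assert content and len(content) == 1
    return content[0].text  # type: ignore
```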
@@ -236,11 +112,14 @@ def print_clear():
     return Message("assistant", output)


-def _client_to_provider() -> str:
-    if oai_client:
-        if "openai" in oai_client.base_url.host:
+def _client_to_provider() -> Provider:
+    openai_client = get_openai_client()
+    anthropic_client = get_anthropic_client()
+    assert openai_client or anthropic_client, "No client initialized"
+    if openai_client:
+        if "openai" in openai_client.base_url.host:
             return "openai"
-        elif "openrouter" in oai_client.base_url.host:
+        elif "openrouter" in openai_client.base_url.host:
             return "openrouter"
         else:
             return "azure"
@@ -265,8 +144,9 @@ def summarize(content: str) -> str:
         Message("user", content=f"Summarize this:\n{content}"),
     ]

-    model = get_summary_model(_client_to_provider())
-    context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"]
+    provider = _client_to_provider()
+    model = get_summary_model(provider)
+    context_limit = MODELS[provider][model]["context"]
     if len_tokens(messages) > context_limit:
         raise ValueError(
             f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"