forked from zhiqix/NL2GQL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Bigger_LLM.py
91 lines (68 loc) · 2.9 KB
/
Bigger_LLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import asyncio
from typing import NamedTuple, Union
import openai
from openai.error import APIConnectionError
from Config import *
import time
"""
Call the OpenAI's API
"""
class OPENAI:
    """Thin wrapper around the (pre-1.0) ``openai`` SDK's ChatCompletion API.

    Reads ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` / ``OPENAI_API_MODEL``
    from ``Config`` (star-imported at module level) when constructed.
    """

    def __init__(self):
        # Configure the global openai module first, then keep a handle to it.
        self.__init_openai()
        self.llm = openai
        self.model = OPENAI_API_MODEL
        self.auto_max_tokens = False

    def __init_openai(self):
        """Install the API key (and optional custom base URL) on the SDK."""
        openai.api_key = OPENAI_API_KEY
        if OPENAI_API_BASE:
            openai.api_base = OPENAI_API_BASE

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Stream a chat completion, echoing each token to stdout in place.

        :param messages: OpenAI chat messages (``role``/``content`` dicts).
        :return: the full reply text, concatenated from the streamed deltas.
        """
        response = await openai.ChatCompletion.acreate(
            **self._cons_kwargs(messages), stream=True
        )
        collected_chunks = []    # raw stream events (kept for debugging)
        collected_messages = []  # per-chunk "delta" dicts
        async for chunk in response:
            collected_chunks.append(chunk)
            choices = chunk["choices"]
            if choices:  # some keep-alive chunks carry an empty choices list
                delta = choices[0].get("delta", {})
                collected_messages.append(delta)
                if "content" in delta:
                    print(delta["content"], end="")
        print()  # newline after the streamed output
        full_reply_content = "".join(m.get("content", "") for m in collected_messages)
        # NOTE(review): _calc_usage / _update_costs are not defined in this
        # file — presumably cost-tracking helpers provided elsewhere; confirm
        # they exist, otherwise this raises AttributeError after streaming.
        usage = self._calc_usage(messages, full_reply_content)
        self._update_costs(usage)
        return full_reply_content

    def _cons_kwargs(self, messages: list[dict]) -> dict:
        """Build the keyword arguments shared by every ChatCompletion call."""
        return {
            "model": self.model,
            "messages": messages,
            # "max_tokens": MAX_TOKENS,
            "n": 1,
            "stop": None,
            "temperature": 0.3,
            # NOTE(review): 3 s is a very tight timeout for a chat
            # completion — confirm this is intentional.
            "timeout": 3,
        }

    def _chat_completion(self, messages: list[dict]) -> str:
        """Synchronous, non-streaming completion; returns the reply text."""
        rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
        # self._update_costs(rsp)
        return rsp.choices[0].message.content

    def completion(self, messages: list[dict]) -> str:
        """Public synchronous entry point; delegates to ``_chat_completion``."""
        return self._chat_completion(messages)

    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
        """Async completion; when *stream* is true, print each token in place."""
        if stream:
            return await self._achat_completion_stream(messages)
        # NOTE(review): _achat_completion / get_choice_text are not defined in
        # this file — the non-stream path will raise AttributeError unless
        # they are provided elsewhere. Confirm before relying on stream=False.
        rsp = await self._achat_completion(messages)
        return self.get_choice_text(rsp)