-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
custom_llm.py
93 lines (77 loc) · 3.04 KB
/
custom_llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
from collections.abc import Iterator
import requests
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from requests import Timeout
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger
# Module-level logger created through the project's shared logging helper.
logger = setup_logger()
class CustomModelServer(LLM):
    """This class is to provide an example for how to use Danswer
    with any LLM, even servers with custom API definitions.

    To use with your own model server, simply implement the functions
    below to fit your model server expectation.

    The implementation below works against the custom FastAPI server from the blog:
    https://medium.com/@yuhongsun96/how-to-augment-llms-with-private-data-29349bd8ae9f
    """

    @property
    def requires_api_key(self) -> bool:
        # The example server is unauthenticated; a real deployment should
        # require a key (see the api_key parameter of __init__).
        return False

    def __init__(
        self,
        # Not used here but you probably want a model server that isn't completely open
        api_key: str | None,
        timeout: int,
        endpoint: str | None = GEN_AI_API_ENDPOINT,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
    ):
        """Configure the client for a custom model server.

        Args:
            api_key: Accepted for interface compatibility; ignored by this example.
            timeout: Per-request timeout in seconds, forwarded to ``requests.post``.
            endpoint: URL of the model server; must be non-empty.
            max_output_tokens: Cap sent as ``max_tokens`` in the request payload.

        Raises:
            ValueError: If no endpoint is configured.
        """
        if not endpoint:
            raise ValueError(
                "Cannot point Danswer to a custom LLM server without providing the "
                "endpoint for the model server."
            )

        self._endpoint = endpoint
        self._max_output_tokens = max_output_tokens
        self._timeout = timeout

    def _execute(self, lm_input: LanguageModelInput) -> AIMessage:
        """POST the prompt to the model server and wrap its reply in an AIMessage.

        Args:
            lm_input: Prompt in any LangChain input form; flattened to a plain
                string before sending. (Renamed from ``input`` to avoid
                shadowing the builtin.)

        Raises:
            Timeout: If the server does not respond within ``self._timeout`` seconds.
            requests.HTTPError: If the server returns a non-2xx status.
        """
        headers = {
            "Content-Type": "application/json",
        }
        data = {
            "inputs": convert_lm_input_to_basic_string(lm_input),
            "parameters": {
                # Deterministic output for the example server.
                "temperature": 0.0,
                "max_tokens": self._max_output_tokens,
            },
        }
        try:
            response = requests.post(
                self._endpoint, headers=headers, json=data, timeout=self._timeout
            )
        except Timeout as error:
            raise Timeout(f"Model inference to {self._endpoint} timed out") from error

        response.raise_for_status()
        # Use requests' own JSON decoding, which honors the response's declared
        # encoding, instead of json.loads(response.content). A missing
        # "generated_text" key falls back to an empty string.
        response_content = response.json().get("generated_text", "")
        return AIMessage(content=response_content)

    def log_model_configs(self) -> None:
        """Log where this client is pointed, for startup diagnostics."""
        logger.debug(f"Custom model at: {self._endpoint}")

    def invoke(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> BaseMessage:
        """Return the model's full response for ``prompt``.

        Note: ``tools`` and ``tool_choice`` are accepted for interface
        compatibility but ignored — the example server has no tool support.
        """
        return self._execute(prompt)

    def stream(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> Iterator[BaseMessage]:
        """Yield the response as a single message (no true token streaming).

        Note: ``tools`` and ``tool_choice`` are accepted for interface
        compatibility but ignored — the example server has no tool support.
        """
        yield self._execute(prompt)