Skip to content

Commit

Permalink
fix: sagemaker config and chat methods (#1142)
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloogc committed Oct 30, 2023
1 parent b0e2582 commit a517a58
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
31 changes: 29 additions & 2 deletions private_gpt/components/llm/custom/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import io
import json
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import boto3 # type: ignore
from llama_index.bridge.pydantic import Field
Expand All @@ -13,7 +13,14 @@
CustomLLM,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.llms.base import (
llm_chat_callback,
llm_completion_callback,
)
from llama_index.llms.generic_utils import (
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.llms.llama_utils import (
completion_to_prompt as generic_completion_to_prompt,
)
Expand All @@ -22,8 +29,14 @@
)

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any

from llama_index.callbacks import CallbackManager
from llama_index.llms import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponseGen,
)

Expand Down Expand Up @@ -247,3 +260,17 @@ def get_stream():
yield CompletionResponse(delta=delta, text=text, raw=data)

return get_stream()

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = self.messages_to_prompt(messages)
completion_response = self.complete(prompt, formatted=True, **kwargs)
return completion_response_to_chat_response(completion_response)

@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = self.messages_to_prompt(messages)
completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
return stream_completion_response_to_chat_response(completion_response)
2 changes: 0 additions & 2 deletions private_gpt/components/llm/llm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def __init__(self) -> None:

self.llm = SagemakerLLM(
endpoint_name=settings.sagemaker.endpoint_name,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
)
case "openai":
from llama_index.llms import OpenAI
Expand Down

0 comments on commit a517a58

Please sign in to comment.