Skip to content

Commit

Permalink
chat.completions support get index item
Browse files Browse the repository at this point in the history
  • Loading branch information
169 committed Nov 30, 2023
1 parent 25b531b commit 8c63fc8
Showing 1 changed file with 45 additions and 18 deletions.
63 changes: 45 additions & 18 deletions libs/langchain/langchain/adapters/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
Sequence,
Union,
overload,
Optional,
)

from typing_extensions import Literal
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import (
AIMessage,
Expand All @@ -28,6 +25,8 @@
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from typing_extensions import Literal


async def aenumerate(
Expand All @@ -40,6 +39,29 @@ async def aenumerate(
i += 1


class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""

def __getitem__(self, item: str) -> Any:
return getattr(self, item)


class Choice(IndexableBaseModel):
message: dict


class ChatCompletions(IndexableBaseModel):
choices: List[Choice]


class ChoiceChunk(IndexableBaseModel):
delta: dict


class ChatCompletionChunk(IndexableBaseModel):
choices: List[ChoiceChunk]


def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Expand Down Expand Up @@ -131,7 +153,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess
return [convert_dict_to_message(m) for m in messages]


def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
_dict: Dict[str, Any] = {}
if isinstance(chunk, AIMessageChunk):
if i == 0:
Expand All @@ -150,6 +172,11 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str
# This only happens at the end of streams, and OpenAI returns as empty dict
if _dict == {"content": ""}:
_dict = {}
return _dict


def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
_dict = _convert_message_chunk(chunk, i)
return {"choices": [{"delta": _dict}]}


Expand Down Expand Up @@ -266,15 +293,6 @@ def convert_messages_for_finetuning(
]



class Choice(BaseModel):
message: dict


class ChatCompletions(BaseModel):
choices: List[Choice]


class Completions:
"""Completion."""

Expand Down Expand Up @@ -314,10 +332,14 @@ def create(
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return ChatCompletions(choices=[Choice(message=convert_message_to_dict(result))])
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
_convert_message_chunk_to_delta(c, i)
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
for i, c in enumerate(model_config.stream(converted_messages))
)

Expand Down Expand Up @@ -357,16 +379,21 @@ async def acreate(
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return ChatCompletions(choices=[Choice(message=convert_message_to_dict(result))])
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
_convert_message_chunk_to_delta(c, i)
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
async for i, c in aenumerate(model_config.astream(converted_messages))
)


class Chat:
def __init__(self) -> None:
self.completions = Completions()


chat = Chat()
chat = Chat()

0 comments on commit 8c63fc8

Please sign in to comment.