diff --git a/libs/langchain/langchain/adapters/openai.py b/libs/langchain/langchain/adapters/openai.py index 5d8ffb7f2dba32b..0af759ebf5b0871 100644 --- a/libs/langchain/langchain/adapters/openai.py +++ b/libs/langchain/langchain/adapters/openai.py @@ -11,11 +11,8 @@ Sequence, Union, overload, - Optional, ) -from typing_extensions import Literal -from langchain_core.pydantic_v1 import BaseModel from langchain_core.chat_sessions import ChatSession from langchain_core.messages import ( AIMessage, @@ -28,6 +25,8 @@ SystemMessage, ToolMessage, ) +from langchain_core.pydantic_v1 import BaseModel +from typing_extensions import Literal async def aenumerate( @@ -40,6 +39,29 @@ async def aenumerate( i += 1 +class IndexableBaseModel(BaseModel): + """Allows a BaseModel to return its fields by string variable indexing""" + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + +class Choice(IndexableBaseModel): + message: dict + + +class ChatCompletions(IndexableBaseModel): + choices: List[Choice] + + +class ChoiceChunk(IndexableBaseModel): + delta: dict + + +class ChatCompletionChunk(IndexableBaseModel): + choices: List[ChoiceChunk] + + def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: """Convert a dictionary to a LangChain message. @@ -131,7 +153,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess return [convert_dict_to_message(m) for m in messages] -def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]: +def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict: _dict: Dict[str, Any] = {} if isinstance(chunk, AIMessageChunk): if i == 0: @@ -150,6 +172,11 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str # This only happens at the end of streams, and OpenAI returns as empty dict if _dict == {"content": ""}: _dict = {} + return _dict + + +def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]: + _dict = _convert_message_chunk(chunk, i) return {"choices": [{"delta": _dict}]} @@ -266,15 +293,6 @@ def convert_messages_for_finetuning( ] - -class Choice(BaseModel): - message: dict - - -class ChatCompletions(BaseModel): - choices: List[Choice] - - class Completions: """Completion.""" @@ -314,10 +332,14 @@ def create( converted_messages = convert_openai_messages(messages) if not stream: result = model_config.invoke(converted_messages) - return ChatCompletions(choices=[Choice(message=convert_message_to_dict(result))]) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) else: return ( - _convert_message_chunk_to_delta(c, i) + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) for i, c in enumerate(model_config.stream(converted_messages)) ) @@ -357,16 +379,21 @@ async def acreate( converted_messages = convert_openai_messages(messages) if not stream: result = await model_config.ainvoke(converted_messages) - return ChatCompletions(choices=[Choice(message=convert_message_to_dict(result))]) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) else: return ( - _convert_message_chunk_to_delta(c, i) + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) async for i, c in aenumerate(model_config.astream(converted_messages)) ) + class Chat: def __init__(self) -> None: self.completions = Completions() -chat = Chat() \ No newline at end of file +chat = Chat()