Skip to content

Commit

Permalink
Support Fireworks batching (langchain-ai#8) (langchain-ai#12052)
Browse files Browse the repository at this point in the history
Description

* Add _generate and _agenerate to support Fireworks batching.
* Add stop words test cases
* Allow opting out of the retry mechanism via a new `use_retry` flag

Issue - Not applicable
Dependencies - None
Tag maintainer - @baskaryan
  • Loading branch information
ZixinYang authored and jakubno committed Nov 1, 2023
1 parent deb4dc9 commit 856d3a6
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 45 deletions.
35 changes: 28 additions & 7 deletions libs/langchain/langchain/chat_models/fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ChatFireworks(BaseChatModel):
)
fireworks_api_key: Optional[str] = None
max_retries: int = 20
use_retry: bool = True

@property
def lc_secrets(self) -> Dict[str, str]:
Expand Down Expand Up @@ -134,7 +135,11 @@ def _generate(
**self.model_kwargs,
}
response = completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self,
self.use_retry,
run_manager=run_manager,
stop=stop,
**params,
)
return self._create_chat_result(response)

Expand All @@ -152,7 +157,7 @@ async def _agenerate(
**self.model_kwargs,
}
response = await acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)

Expand Down Expand Up @@ -195,7 +200,7 @@ def _stream(
**self.model_kwargs,
}
for chunk in completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
Expand Down Expand Up @@ -224,7 +229,7 @@ async def _astream(
**self.model_kwargs,
}
async for chunk in await acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
Expand All @@ -238,8 +243,20 @@ async def _astream(
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)


def conditional_decorator(
    condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Apply ``decorator`` only when ``condition`` is true.

    Lets a call site keep the ``@decorator`` syntax while deciding at call
    time whether the wrapping takes effect.

    Args:
        condition: When true, the returned decorator delegates to
            ``decorator``; otherwise it hands the function back unchanged.
        decorator: The decorator to apply conditionally.

    Returns:
        A decorator that yields either ``decorator(func)`` or ``func`` itself.
    """

    def actual_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # ``Callable[..., Any]``: the wrapped callables are not restricted to
        # a single positional argument (they may be keyword-only).
        if condition:
            return decorator(func)
        # Opted out: return the very same object, with no wrapper layer.
        return func

    return actual_decorator


def completion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
Expand All @@ -249,7 +266,7 @@ def completion_with_retry(

retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.create(
**kwargs,
Expand All @@ -260,6 +277,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:

async def acompletion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
Expand All @@ -269,7 +287,7 @@ async def acompletion_with_retry(

retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.ChatCompletion.acreate(
**kwargs,
Expand All @@ -280,6 +298,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any:

async def acompletion_with_retry_streaming(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
Expand All @@ -289,7 +308,7 @@ async def acompletion_with_retry_streaming(

retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.acreate(
**kwargs,
Expand All @@ -309,6 +328,8 @@ def _create_retry_decorator(

errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.InternalServerError,
fireworks.client.error.BadGatewayError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
Expand Down

0 comments on commit 856d3a6

Please sign in to comment.