Commit

Support Fireworks batching (#8)
* Support Fireworks batching

* Support ChatFireworks batching
ZixinYang committed Oct 26, 2023
1 parent 869a49a commit ed823d6
Showing 4 changed files with 391 additions and 27 deletions.
208 changes: 206 additions & 2 deletions libs/langchain/langchain/chat_models/fireworks.py
@@ -1,3 +1,5 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
@@ -11,12 +13,16 @@
)

from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
AIMessage,
@@ -32,7 +38,13 @@
SystemMessage,
SystemMessageChunk,
)
-from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+    RunInfo,
+)
from langchain.utils.env import get_from_dict_or_env


@@ -55,7 +67,7 @@ def _convert_delta_to_message_chunk(
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
-        return default_class(content=content)
+        return default_class(type="", content=content)


def convert_dict_to_message(_dict: Any) -> BaseMessage:
@@ -89,6 +101,7 @@ class ChatFireworks(BaseChatModel):
)
fireworks_api_key: Optional[str] = None
max_retries: int = 20
batch_size: int = 20

@property
def lc_secrets(self) -> Dict[str, str]:
@@ -119,6 +132,197 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"

def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}

callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)

def _completion_with_retry_batching(
message: List[List[BaseMessage]],
) -> List[Any]:
args_list = [
(m, stop, run_managers[i] if run_managers else None, kwargs)
for i, m in enumerate(message)
]
with ThreadPoolExecutor() as executor:
results = list(executor.map(self._process_message, args_list))

return results

sub_messages = self.get_batch_messages(params, messages, stop)

results = []
for message in sub_messages:
results.extend(_completion_with_retry_batching(message))

flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
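
The synchronous generate override above chunks the incoming conversations into groups of batch_size and fans each group out over a ThreadPoolExecutor, running one _generate_with_cache call per conversation. A minimal usage sketch, not part of the commit, assuming FIREWORKS_API_KEY is set in the environment and using illustrative prompts and batch size:

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage

# batch_size=2 is illustrative; the class default is 20.
chat = ChatFireworks(batch_size=2)
prompts = [
    [HumanMessage(content=f"Summarize topic #{i} in one sentence.")] for i in range(5)
]
result = chat.generate(prompts)  # 5 conversations -> batches of 2, 2, 1
for gens in result.generations:  # one list of generations per input conversation
    print(gens[0].text)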

async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}

callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)

run_managers = await callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)

async def _acompletion_with_retry_batching(
message: List[List[BaseMessage]],
) -> List[Any]:
args_list = [
(m, stop, run_managers[i] if run_managers else None, kwargs)
for i, m in enumerate(message)
]
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
results = await asyncio.gather(
*[
loop.run_in_executor(executor, self._process_message, args)
for args in args_list
],
return_exceptions=True,
)

return results

sub_messages = self.get_batch_messages(params, messages, stop)

results = []
for message in sub_messages:
results.extend(await _acompletion_with_retry_batching(message))

exceptions = []
for i, res in enumerate(results):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
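
agenerate follows the same batching scheme but submits each conversation to the executor through loop.run_in_executor and collects the results with asyncio.gather, reporting any exception to the failed run's callback manager before re-raising it. An async usage sketch under the same assumptions (FIREWORKS_API_KEY exported, illustrative prompts and batch size):

import asyncio

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage


async def main() -> None:
    chat = ChatFireworks(batch_size=2)  # illustrative; default is 20
    prompts = [
        [HumanMessage(content=f"Name one fact about the number {i}.")] for i in range(5)
    ]
    result = await chat.agenerate(prompts)
    for gens in result.generations:
        print(gens[0].text)


asyncio.run(main())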

def _process_message(self, args: List[Any]) -> ChatResult:
m, stop, run_manager, kwargs = args
try:
return self._generate_with_cache(
messages=m, stop=stop, run_manager=run_manager, **kwargs
)
except BaseException as e:
if run_manager:
run_manager.on_llm_error(e)
raise e

def get_batch_messages(
self,
params: Dict[str, Any],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
) -> List[List[List[BaseMessage]]]:
"""Get the sub messages for llm call."""
if stop is not None:
params["stop"] = stop

sub_messages = [
messages[i : i + self.batch_size]
for i in range(0, len(messages), self.batch_size)
]
return sub_messages
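
get_batch_messages is plain list slicing: consecutive groups of at most batch_size conversations, with the final group holding the remainder. A standalone sketch of the same chunking arithmetic, using hypothetical names that are not part of the commit:

def chunk_conversations(conversations: list, batch_size: int = 20) -> list:
    """Split a list into consecutive sub-lists of at most batch_size elements."""
    return [
        conversations[i : i + batch_size]
        for i in range(0, len(conversations), batch_size)
    ]

# 45 conversations with the default batch_size of 20 -> groups of 20, 20 and 5.
assert [len(batch) for batch in chunk_conversations(list(range(45)))] == [20, 20, 5]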

def _generate(
self,
messages: List[BaseMessage],
