Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions src/cohere/manually_maintained/cohere_aws/chat.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .response import CohereObject
from .error import CohereError
from .mode import Mode
from typing import List, Optional, Generator, Dict, Any, Union
from enum import Enum
import json
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Union

from .mode import Mode
from .response import CohereObject

# Tools


class ToolParameterDefinitionsValue(CohereObject, dict):
def __init__(
self,
Expand Down Expand Up @@ -47,29 +48,41 @@ def __init__(
generation_id: str,
**kwargs,
) -> None:
super().__init__(**kwargs)
# Bypass super().__init__ if kwargs is empty (small perf win)
if kwargs:
super().__init__(**kwargs)
self.__dict__ = self
self.name = name
self.parameters = parameters
self.generation_id = generation_id

@classmethod
def from_dict(cls, tool_call_res: Dict[str, Any]) -> "ToolCall":
# Use local variable lookups for keys (faster than key lookups in tight loop)
get = tool_call_res.get
return cls(
name=tool_call_res.get("name"),
parameters=tool_call_res.get("parameters"),
generation_id=tool_call_res.get("generation_id"),
name=get("name"),
parameters=get("parameters"),
generation_id=get("generation_id"),
)

@classmethod
def from_list(cls, tool_calls_res: Optional[List[Dict[str, Any]]]) -> Optional[List["ToolCall"]]:
if tool_calls_res is None or not isinstance(tool_calls_res, list):
return None

return [ToolCall.from_dict(tc) for tc in tool_calls_res]
# Use localize variables for improved lookup during iteration
from_dict = cls.from_dict
# Preallocate list for performance if input is not enormous
result = [None] * len(tool_calls_res)
for idx, tc in enumerate(tool_calls_res):
result[idx] = from_dict(tc)
return result


# Chat


class Chat(CohereObject):
def __init__(
self,
Expand Down Expand Up @@ -119,10 +132,12 @@ def from_dict(cls, response: Dict[str, Any]) -> "Chat":
tool_calls=ToolCall.from_list(response.get("tool_calls")), # optional
)


# ---------------|
# Steaming event |
# ---------------|


class StreamEvent(str, Enum):
STREAM_START = "stream-start"
SEARCH_QUERIES_GENERATION = "search-queries-generation"
Expand All @@ -132,6 +147,7 @@ class StreamEvent(str, Enum):
CITATION_GENERATION = "citation-generation"
STREAM_END = "stream-end"


class StreamResponse(CohereObject):
def __init__(
self,
Expand Down Expand Up @@ -219,6 +235,7 @@ def __init__(
super().__init__(**kwargs)
self.tool_calls = tool_calls


class StreamingChat(CohereObject):
def __init__(self, stream_response, mode):
self.stream_response = stream_response
Expand Down