Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions src/elevenlabs/conversational_ai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

class ClientToOrchestratorEvent(str, Enum):
"""Event types that can be sent from client to orchestrator."""

# Response to a ping request.
PONG = "pong"
CLIENT_TOOL_RESULT = "client_tool_result"
Expand All @@ -29,42 +30,34 @@ class ClientToOrchestratorEvent(str, Enum):

class UserMessageClientToOrchestratorEvent:
"""Event for sending user text messages."""

def __init__(self, text: Optional[str] = None):
self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE
self.text = text

def to_dict(self) -> dict:
return {
"type": self.type,
"text": self.text
}
return {"type": self.type, "text": self.text}


class UserActivityClientToOrchestratorEvent:
"""Event for registering user activity (ping to prevent timeout)."""

def __init__(self) -> None:
self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY

def to_dict(self) -> dict:
return {
"type": self.type
}
return {"type": self.type}


class ContextualUpdateClientToOrchestratorEvent:
"""Event for sending non-interrupting contextual updates to the conversation state."""

def __init__(self, text: str):
self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE
self.text = text

def to_dict(self) -> dict:
return {
"type": self.type,
"content": self.text
}
return {"type": self.type, "content": self.text}


class AudioInterface(ABC):
Expand Down Expand Up @@ -196,7 +189,7 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic
"""
if not self._running.is_set():
raise RuntimeError("ClientTools event loop is not running")

if self._loop is None:
raise RuntimeError("Event loop is not available")

Expand Down Expand Up @@ -257,6 +250,7 @@ def __init__(
self,
client: BaseElevenLabs,
agent_id: str,
user_id: Optional[str] = None,
*,
requires_auth: bool,
audio_interface: AudioInterface,
Expand All @@ -274,6 +268,7 @@ def __init__(
Args:
client: The ElevenLabs client to use for the conversation.
agent_id: The ID of the agent to converse with.
user_id: The ID of the user conversing with the agent.
requires_auth: Whether the agent requires authentication.
audio_interface: The audio interface to use for input and output.
client_tools: The client tools to use for the conversation.
Expand All @@ -287,6 +282,7 @@ def __init__(

self.client = client
self.agent_id = agent_id
self.user_id = user_id
self.requires_auth = requires_auth
self.audio_interface = audio_interface
self.callback_agent_response = callback_agent_response
Expand Down Expand Up @@ -334,16 +330,16 @@ def wait_for_session_end(self) -> Optional[str]:

def send_user_message(self, text: str):
"""Send a text message from the user to the agent.

Args:
text: The text message to send to the agent.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = UserMessageClientToOrchestratorEvent(text=text)
try:
self._ws.send(json.dumps(event.to_dict()))
Expand All @@ -353,15 +349,15 @@ def send_user_message(self, text: str):

def register_user_activity(self):
"""Register user activity to prevent session timeout.

This sends a ping to the orchestrator to reset the timeout timer.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = UserActivityClientToOrchestratorEvent()
try:
self._ws.send(json.dumps(event.to_dict()))
Expand All @@ -371,19 +367,19 @@ def register_user_activity(self):

def send_contextual_update(self, text: str):
"""Send a contextual update to the conversation.

Contextual updates are non-interrupting content that is sent to the server
to update the conversation state without directly prompting the agent.

Args:
content: The contextual information to send to the conversation.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = ContextualUpdateClientToOrchestratorEvent(text=text)
try:
self._ws.send(json.dumps(event.to_dict()))
Expand Down Expand Up @@ -435,7 +431,7 @@ def input_callback(audio):
except Exception as e:
print(f"Error receiving message: {e}")
self.end_session()

self._ws = None

def _handle_message(self, message, ws):
Expand Down
9 changes: 6 additions & 3 deletions tests/test_convai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def test_conversation_basic_flow():
mock_ws = create_mock_websocket()
mock_client = MagicMock()
agent_response_callback = MagicMock()
test_user_id = "test_user_123"

# Setup the conversation
conversation = Conversation(
client=mock_client,
agent_id=TEST_AGENT_ID,
user_id=test_user_id,
requires_auth=False,
audio_interface=MockAudioInterface(),
callback_agent_response=agent_response_callback,
Expand Down Expand Up @@ -86,6 +88,7 @@ def test_conversation_basic_flow():
mock_ws.send.assert_any_call(json.dumps(expected_init_message))
agent_response_callback.assert_called_once_with("Hello there!")
assert conversation._conversation_id == TEST_CONVERSATION_ID
assert conversation.user_id == test_user_id


def test_conversation_with_auth():
Expand Down Expand Up @@ -118,6 +121,7 @@ def test_conversation_with_auth():
# Assertions
mock_client.conversational_ai.conversations.get_signed_url.assert_called_once_with(agent_id=TEST_AGENT_ID)


def test_conversation_with_dynamic_variables():
# Mock setup
mock_ws = create_mock_websocket()
Expand Down Expand Up @@ -156,14 +160,13 @@ def test_conversation_with_dynamic_variables():
"type": "conversation_initiation_client_data",
"custom_llm_extra_body": {},
"conversation_config_override": {},
"dynamic_variables": {
"name": "angelo"
},
"dynamic_variables": {"name": "angelo"},
}
mock_ws.send.assert_any_call(json.dumps(expected_init_message))
agent_response_callback.assert_called_once_with("Hello there!")
assert conversation._conversation_id == TEST_CONVERSATION_ID


def test_conversation_with_contextual_update():
# Mock setup
mock_ws = create_mock_websocket([])
Expand Down