From ac393675f0b3d870df8c3ed11b769e45f3b438f3 Mon Sep 17 00:00:00 2001 From: James Zhou Date: Mon, 7 Jul 2025 18:22:04 -0700 Subject: [PATCH] feat: add user_id to convai --- .../conversational_ai/conversation.py | 54 +++++++++---------- tests/test_convai.py | 9 ++-- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/elevenlabs/conversational_ai/conversation.py b/src/elevenlabs/conversational_ai/conversation.py index 818efea2..c99f8468 100644 --- a/src/elevenlabs/conversational_ai/conversation.py +++ b/src/elevenlabs/conversational_ai/conversation.py @@ -15,6 +15,7 @@ class ClientToOrchestratorEvent(str, Enum): """Event types that can be sent from client to orchestrator.""" + # Response to a ping request. PONG = "pong" CLIENT_TOOL_RESULT = "client_tool_result" @@ -29,42 +30,34 @@ class ClientToOrchestratorEvent(str, Enum): class UserMessageClientToOrchestratorEvent: """Event for sending user text messages.""" - + def __init__(self, text: Optional[str] = None): self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE self.text = text - + def to_dict(self) -> dict: - return { - "type": self.type, - "text": self.text - } + return {"type": self.type, "text": self.text} class UserActivityClientToOrchestratorEvent: """Event for registering user activity (ping to prevent timeout).""" - + def __init__(self) -> None: self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY - + def to_dict(self) -> dict: - return { - "type": self.type - } + return {"type": self.type} class ContextualUpdateClientToOrchestratorEvent: """Event for sending non-interrupting contextual updates to the conversation state.""" - + def __init__(self, text: str): self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE self.text = text - + def to_dict(self) -> dict: - return { - "type": self.type, - "content": self.text - } + return {"type": self.type, "content": self.text} class AudioInterface(ABC): @@ -196,7 +189,7 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic """ if not self._running.is_set(): raise RuntimeError("ClientTools event loop is not running") - + if self._loop is None: raise RuntimeError("Event loop is not available") @@ -257,6 +250,7 @@ def __init__( self, client: BaseElevenLabs, agent_id: str, + user_id: Optional[str] = None, *, requires_auth: bool, audio_interface: AudioInterface, @@ -274,6 +268,7 @@ def __init__( Args: client: The ElevenLabs client to use for the conversation. agent_id: The ID of the agent to converse with. + user_id: The ID of the user conversing with the agent. requires_auth: Whether the agent requires authentication. audio_interface: The audio interface to use for input and output. client_tools: The client tools to use for the conversation. @@ -287,6 +282,7 @@ def __init__( self.client = client self.agent_id = agent_id + self.user_id = user_id self.requires_auth = requires_auth self.audio_interface = audio_interface self.callback_agent_response = callback_agent_response @@ -334,16 +330,16 @@ def wait_for_session_end(self) -> Optional[str]: def send_user_message(self, text: str): """Send a text message from the user to the agent. - + Args: text: The text message to send to the agent. - + Raises: RuntimeError: If the session is not active or websocket is not connected. """ if not self._ws: raise RuntimeError("Session not started or websocket not connected.") - + event = UserMessageClientToOrchestratorEvent(text=text) try: self._ws.send(json.dumps(event.to_dict())) @@ -353,15 +349,15 @@ def send_user_message(self, text: str): def register_user_activity(self): """Register user activity to prevent session timeout. - + This sends a ping to the orchestrator to reset the timeout timer. - + Raises: RuntimeError: If the session is not active or websocket is not connected. """ if not self._ws: raise RuntimeError("Session not started or websocket not connected.") - + event = UserActivityClientToOrchestratorEvent() try: self._ws.send(json.dumps(event.to_dict())) @@ -371,19 +367,19 @@ def register_user_activity(self): def send_contextual_update(self, text: str): """Send a contextual update to the conversation. - + Contextual updates are non-interrupting content that is sent to the server to update the conversation state without directly prompting the agent. - + Args: content: The contextual information to send to the conversation. - + Raises: RuntimeError: If the session is not active or websocket is not connected. """ if not self._ws: raise RuntimeError("Session not started or websocket not connected.") - + event = ContextualUpdateClientToOrchestratorEvent(text=text) try: self._ws.send(json.dumps(event.to_dict())) @@ -435,7 +431,7 @@ def input_callback(audio): except Exception as e: print(f"Error receiving message: {e}") self.end_session() - + self._ws = None def _handle_message(self, message, ws): diff --git a/tests/test_convai.py b/tests/test_convai.py index 65a12cae..bee36928 100644 --- a/tests/test_convai.py +++ b/tests/test_convai.py @@ -52,11 +52,13 @@ def test_conversation_basic_flow(): mock_ws = create_mock_websocket() mock_client = MagicMock() agent_response_callback = MagicMock() + test_user_id = "test_user_123" # Setup the conversation conversation = Conversation( client=mock_client, agent_id=TEST_AGENT_ID, + user_id=test_user_id, requires_auth=False, audio_interface=MockAudioInterface(), callback_agent_response=agent_response_callback, @@ -86,6 +88,7 @@ def test_conversation_basic_flow(): mock_ws.send.assert_any_call(json.dumps(expected_init_message)) agent_response_callback.assert_called_once_with("Hello there!") assert conversation._conversation_id == TEST_CONVERSATION_ID + assert conversation.user_id == test_user_id def test_conversation_with_auth(): @@ -118,6 +121,7 @@ def test_conversation_with_auth(): # Assertions mock_client.conversational_ai.conversations.get_signed_url.assert_called_once_with(agent_id=TEST_AGENT_ID) + def test_conversation_with_dynamic_variables(): # Mock setup mock_ws = create_mock_websocket() @@ -156,14 +160,13 @@ def test_conversation_with_dynamic_variables(): "type": "conversation_initiation_client_data", "custom_llm_extra_body": {}, "conversation_config_override": {}, - "dynamic_variables": { - "name": "angelo" - }, + "dynamic_variables": {"name": "angelo"}, } mock_ws.send.assert_any_call(json.dumps(expected_init_message)) agent_response_callback.assert_called_once_with("Hello there!") assert conversation._conversation_id == TEST_CONVERSATION_ID + def test_conversation_with_contextual_update(): # Mock setup mock_ws = create_mock_websocket([])