diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index a9d631da6..ef0950c0f 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -433,6 +433,14 @@ async def _construct_message_from_response(self, response: Any, request: Message if not pieces: raise EmptyResponseException(message="Failed to extract any response content.") + # Capture token usage from the API response and store in the first piece's metadata + if hasattr(response, "usage") and response.usage and pieces: + pieces[0].prompt_metadata["token_usage_model_name"] = getattr(response, "model", "unknown") + pieces[0].prompt_metadata["token_usage_prompt_tokens"] = getattr(response.usage, "prompt_tokens", 0) + pieces[0].prompt_metadata["token_usage_completion_tokens"] = getattr(response.usage, "completion_tokens", 0) + pieces[0].prompt_metadata["token_usage_total_tokens"] = getattr(response.usage, "total_tokens", 0) + pieces[0].prompt_metadata["token_usage_cached_tokens"] = getattr(response.usage, "cached_tokens", 0)  # NOTE(review): the OpenAI SDK's CompletionUsage has no top-level cached_tokens — cached tokens are reported at usage.prompt_tokens_details.cached_tokens, so this getattr will always fall back to 0 against real API responses; confirm and read the nested field + return Message(message_pieces=pieces) async def _save_audio_response_async(self, *, audio_data_base64: str) -> str: diff --git a/tests/integration/targets/test_openai_chat_target_integration.py b/tests/integration/targets/test_openai_chat_target_integration.py index bb862cdb8..ae7c98919 100644 --- a/tests/integration/targets/test_openai_chat_target_integration.py +++ b/tests/integration/targets/test_openai_chat_target_integration.py @@ -199,3 +199,53 @@ async def test_openai_chat_target_tool_calling_multiple_tools(sqlite_instance, p tool_call_data = json.loads(tool_call_pieces[0].converted_value) assert tool_call_data["function"]["name"] == "get_stock_price" assert "msft" in tool_call_data["function"]["arguments"].lower() + + +# ============================================================================ +# Token Usage Metadata Tests +# 
============================================================================ + + +@pytest.mark.asyncio +async def test_openai_chat_target_token_usage_in_metadata(sqlite_instance, platform_openai_chat_args): + """ + Test that token usage metadata is captured from a real API response. + + This test verifies that: + 1. Token usage keys are present in prompt_metadata of the response + 2. Token counts are non-negative integers + 3. Model name is a non-empty string + """ + target = OpenAIChatTarget(**platform_openai_chat_args) + + conv_id = str(uuid.uuid4()) + + user_piece = MessagePiece( + role="user", + original_value="Say hello in one word.", + original_value_data_type="text", + conversation_id=conv_id, + ) + + result = await target.send_prompt_async(message=user_piece.to_message()) + assert result is not None + assert len(result) >= 1 + + first_piece = result[0].message_pieces[0] + metadata = first_piece.prompt_metadata + + # Verify token usage keys are present + assert "token_usage_model_name" in metadata, "Response should contain token_usage_model_name in metadata" + assert "token_usage_prompt_tokens" in metadata, "Response should contain token_usage_prompt_tokens in metadata" + assert "token_usage_completion_tokens" in metadata + assert "token_usage_total_tokens" in metadata + + # Verify values are reasonable + assert isinstance(metadata["token_usage_model_name"], str) + assert len(metadata["token_usage_model_name"]) > 0 + assert metadata["token_usage_prompt_tokens"] > 0 + assert metadata["token_usage_completion_tokens"] > 0 + assert metadata["token_usage_total_tokens"] > 0 + assert metadata["token_usage_total_tokens"] == ( + metadata["token_usage_prompt_tokens"] + metadata["token_usage_completion_tokens"] + ) diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 846efe353..bae75e4b4 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -1928,3 
+1928,71 @@ def test_construct_request_body_with_tools(patch_central_database): assert target._extra_body_parameters.get("tools") == tools assert target._extra_body_parameters.get("tool_choice") == "auto" + + +# ============================================================================ +# Token Usage Metadata Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_construct_message_from_response_captures_token_usage( + target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece +): + """Test that token usage from the API response is stored in prompt_metadata.""" + mock_response = create_mock_completion(content="Hello") + mock_response.model = "gpt-4o-2024-05-13" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.usage.total_tokens = 30 + mock_response.usage.cached_tokens = 5  # NOTE(review): a real CompletionUsage has no top-level cached_tokens (it is nested under prompt_tokens_details) — this unstrict mock satisfies whatever attribute the implementation reads, so it cannot catch a wrong-attribute bug; consider mocking prompt_tokens_details.cached_tokens instead + + result = await target._construct_message_from_response(mock_response, dummy_text_message_piece) + + piece = result.message_pieces[0] + assert piece.prompt_metadata["token_usage_model_name"] == "gpt-4o-2024-05-13" + assert piece.prompt_metadata["token_usage_prompt_tokens"] == 10 + assert piece.prompt_metadata["token_usage_completion_tokens"] == 20 + assert piece.prompt_metadata["token_usage_total_tokens"] == 30 + assert piece.prompt_metadata["token_usage_cached_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_construct_message_from_response_no_usage_no_metadata( + target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece +): + """Test that no token usage metadata is written when response.usage is None.""" + mock_response = create_mock_completion(content="Hello") + mock_response.usage = None + + result = await target._construct_message_from_response(mock_response, dummy_text_message_piece) + + piece = result.message_pieces[0] + assert "token_usage_model_name" not in piece.prompt_metadata + assert 
"token_usage_prompt_tokens" not in piece.prompt_metadata + + +@pytest.mark.asyncio +async def test_construct_message_from_response_token_usage_defaults_on_missing_attrs( + target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece +): + """Test that missing usage attributes default to 0 and missing model defaults to 'unknown'.""" + mock_response = create_mock_completion(content="Hello") + # Create a usage object without cached_tokens + mock_usage = MagicMock(spec=[]) + mock_usage.prompt_tokens = 5 + mock_usage.completion_tokens = 10 + mock_usage.total_tokens = 15 + mock_response.usage = mock_usage + # Remove model attribute to test default + del mock_response.model + + result = await target._construct_message_from_response(mock_response, dummy_text_message_piece) + + piece = result.message_pieces[0] + assert piece.prompt_metadata["token_usage_model_name"] == "unknown" + assert piece.prompt_metadata["token_usage_prompt_tokens"] == 5 + assert piece.prompt_metadata["token_usage_completion_tokens"] == 10 + assert piece.prompt_metadata["token_usage_total_tokens"] == 15 + assert piece.prompt_metadata["token_usage_cached_tokens"] == 0