Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyrit/prompt_target/openai/openai_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ async def _construct_message_from_response(self, response: Any, request: Message
if not pieces:
raise EmptyResponseException(message="Failed to extract any response content.")

# Capture token usage from the API response and store in the first piece's metadata
if hasattr(response, "usage") and response.usage and pieces:
pieces[0].prompt_metadata["token_usage_model_name"] = getattr(response, "model", "unknown")
pieces[0].prompt_metadata["token_usage_prompt_tokens"] = getattr(response.usage, "prompt_tokens", 0)
pieces[0].prompt_metadata["token_usage_completion_tokens"] = getattr(response.usage, "completion_tokens", 0)
pieces[0].prompt_metadata["token_usage_total_tokens"] = getattr(response.usage, "total_tokens", 0)
pieces[0].prompt_metadata["token_usage_cached_tokens"] = getattr(response.usage, "cached_tokens", 0)

return Message(message_pieces=pieces)

async def _save_audio_response_async(self, *, audio_data_base64: str) -> str:
Expand Down
50 changes: 50 additions & 0 deletions tests/integration/targets/test_openai_chat_target_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,53 @@ async def test_openai_chat_target_tool_calling_multiple_tools(sqlite_instance, p
tool_call_data = json.loads(tool_call_pieces[0].converted_value)
assert tool_call_data["function"]["name"] == "get_stock_price"
assert "msft" in tool_call_data["function"]["arguments"].lower()


# ============================================================================
# Token Usage Metadata Tests
# ============================================================================


@pytest.mark.asyncio
async def test_openai_chat_target_token_usage_in_metadata(sqlite_instance, platform_openai_chat_args):
    """
    Test that token usage metadata is captured from a real API response.

    This test verifies that:
    1. Token usage keys are present in prompt_metadata of the response
    2. Token counts are positive integers that sum consistently
    3. Model name is a non-empty string
    """
    target = OpenAIChatTarget(**platform_openai_chat_args)

    conv_id = str(uuid.uuid4())

    user_piece = MessagePiece(
        role="user",
        original_value="Say hello in one word.",
        original_value_data_type="text",
        conversation_id=conv_id,
    )

    result = await target.send_prompt_async(message=user_piece.to_message())
    assert result is not None
    assert len(result) >= 1

    first_piece = result[0].message_pieces[0]
    metadata = first_piece.prompt_metadata

    # Verify token usage keys are present
    for key in (
        "token_usage_model_name",
        "token_usage_prompt_tokens",
        "token_usage_completion_tokens",
        "token_usage_total_tokens",
    ):
        assert key in metadata, f"Response should contain {key} in metadata"

    # Verify the model name is a non-empty string
    assert isinstance(metadata["token_usage_model_name"], str)
    assert len(metadata["token_usage_model_name"]) > 0

    # Verify counts are actual ints (not stringified values), and positive —
    # the docstring promises integers, so check the type explicitly.
    for key in ("token_usage_prompt_tokens", "token_usage_completion_tokens", "token_usage_total_tokens"):
        assert isinstance(metadata[key], int), f"{key} should be an int, got {type(metadata[key]).__name__}"
        assert metadata[key] > 0, f"{key} should be positive"

    # The total must be consistent with its parts.
    assert metadata["token_usage_total_tokens"] == (
        metadata["token_usage_prompt_tokens"] + metadata["token_usage_completion_tokens"]
    )
68 changes: 68 additions & 0 deletions tests/unit/target/test_openai_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,3 +1928,71 @@ def test_construct_request_body_with_tools(patch_central_database):

assert target._extra_body_parameters.get("tools") == tools
assert target._extra_body_parameters.get("tool_choice") == "auto"


# ============================================================================
# Token Usage Metadata Tests
# ============================================================================


@pytest.mark.asyncio
async def test_construct_message_from_response_captures_token_usage(
    target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
):
    """Test that token usage from the API response is stored in prompt_metadata."""
    response = create_mock_completion(content="Hello")
    response.model = "gpt-4o-2024-05-13"
    usage = MagicMock()
    usage.prompt_tokens = 10
    usage.completion_tokens = 20
    usage.total_tokens = 30
    usage.cached_tokens = 5
    response.usage = usage

    message = await target._construct_message_from_response(response, dummy_text_message_piece)

    # Each usage figure (plus the model name) should land in the first piece's metadata.
    metadata = message.message_pieces[0].prompt_metadata
    expected = {
        "token_usage_model_name": "gpt-4o-2024-05-13",
        "token_usage_prompt_tokens": 10,
        "token_usage_completion_tokens": 20,
        "token_usage_total_tokens": 30,
        "token_usage_cached_tokens": 5,
    }
    for key, value in expected.items():
        assert metadata[key] == value, f"{key} should be {value!r}"


@pytest.mark.asyncio
async def test_construct_message_from_response_no_usage_no_metadata(
    target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
):
    """Test that no token usage metadata is written when response.usage is None."""
    mock_response = create_mock_completion(content="Hello")
    mock_response.usage = None

    result = await target._construct_message_from_response(mock_response, dummy_text_message_piece)

    piece = result.message_pieces[0]
    # The implementation writes five token_usage_* keys when usage is present;
    # checking only a couple of them could let a partial write slip through,
    # so assert that no token_usage_* key leaks into the metadata at all.
    token_usage_keys = [key for key in piece.prompt_metadata if key.startswith("token_usage_")]
    assert token_usage_keys == [], f"Unexpected token usage metadata: {token_usage_keys}"


@pytest.mark.asyncio
async def test_construct_message_from_response_token_usage_defaults_on_missing_attrs(
    target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
):
    """Test that missing usage attributes default to 0 and missing model defaults to 'unknown'."""
    response = create_mock_completion(content="Hello")
    # spec=[] makes any attribute that was never explicitly set (here:
    # cached_tokens) raise AttributeError, so the target's getattr(..., 0)
    # fallback is actually exercised.
    usage = MagicMock(spec=[])
    usage.prompt_tokens = 5
    usage.completion_tokens = 10
    usage.total_tokens = 15
    response.usage = usage
    # Drop the model attribute so the "unknown" default is exercised too.
    del response.model

    message = await target._construct_message_from_response(response, dummy_text_message_piece)

    metadata = message.message_pieces[0].prompt_metadata
    assert metadata["token_usage_model_name"] == "unknown"
    assert metadata["token_usage_prompt_tokens"] == 5
    assert metadata["token_usage_completion_tokens"] == 10
    assert metadata["token_usage_total_tokens"] == 15
    assert metadata["token_usage_cached_tokens"] == 0
Loading