Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions pyrit/score/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,13 @@ async def score_prompts_batch_async(
list[Score]: A flattened list of Score objects from all scored prompts.

Raises:
ValueError: If objectives is empty or if the number of objectives doesn't match
ValueError: If objectives is not None and the number of objectives doesn't match
the number of messages.
"""
if not objectives:
if objectives is None:
objectives = [""] * len(messages)

elif len(objectives) != len(messages):
raise ValueError("The number of tasks must match the number of messages.")
raise ValueError("The number of objectives must match the number of messages.")

if len(messages) == 0:
return []
Expand Down Expand Up @@ -456,7 +455,7 @@ async def score_image_batch_async(
Raises:
ValueError: If the number of objectives does not match the number of image_paths.
"""
if objectives and len(objectives) != len(image_paths):
if objectives is not None and len(objectives) != len(image_paths):
raise ValueError("The number of objectives must match the number of image_paths.")

if len(image_paths) == 0:
Expand All @@ -465,10 +464,10 @@ async def score_image_batch_async(
prompt_target = getattr(self, "_prompt_target", None)
results = await batch_task_async(
task_func=self.score_image_async,
task_arguments=["image_path", "objective"] if objectives else ["image_path"],
task_arguments=["image_path", "objective"] if objectives is not None else ["image_path"],
prompt_target=prompt_target,
batch_size=batch_size,
items_to_batch=[image_paths, objectives] if objectives else [image_paths],
items_to_batch=[image_paths, objectives] if objectives is not None else [image_paths],
)

return [score for sublist in results for score in sublist]
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/score/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,55 @@ async def test_scorer_score_responses_batch_async(patch_central_database):
assert len(fake_scores) == 2


@pytest.mark.asyncio
async def test_score_prompts_batch_async_rejects_explicit_empty_objectives():
"""Test explicit empty objectives are rejected for non-empty message batches."""
scorer = MockScorer()
message = MessagePiece(role="user", original_value="Hello user", sequence=1).to_message()

with pytest.raises(ValueError, match="objectives"):
await scorer.score_prompts_batch_async(messages=[message], objectives=[])


@pytest.mark.asyncio
async def test_score_image_batch_async_rejects_explicit_empty_objectives():
"""Test explicit empty objectives are rejected for non-empty image batches."""
scorer = MockScorer()

with pytest.raises(ValueError, match="objectives"):
await scorer.score_image_batch_async(image_paths=["test_image.png"], objectives=[])


@pytest.mark.asyncio
async def test_score_prompts_batch_async_defaults_objectives_when_none(patch_central_database):
"""Test that objectives=None defaults to empty-string objectives matching message count."""
scorer = MockScorer()

with patch.object(scorer, "score_async", new_callable=AsyncMock) as mock_score_async:
mock_score_async.return_value = [MagicMock()]
message = MessagePiece(role="user", original_value="Hello user", sequence=1).to_message()

await scorer.score_prompts_batch_async(messages=[message])

_, call_kwargs = mock_score_async.call_args
assert call_kwargs["objective"] == ""


@pytest.mark.asyncio
async def test_score_image_batch_async_works_when_objectives_none(patch_central_database):
"""Test that objectives=None omits objectives from the batch call."""
scorer = MockScorer()

with patch.object(scorer, "score_image_async", new_callable=AsyncMock) as mock_score_image:
mock_score_image.return_value = [MagicMock()]

await scorer.score_image_batch_async(image_paths=["test.png"])

mock_score_image.assert_called_once()
_, call_kwargs = mock_score_image.call_args
assert "objective" not in call_kwargs


@pytest.mark.asyncio
async def test_score_response_async_empty_scorers():
"""Test that score_response_async returns empty list when no scorers provided."""
Expand Down
Loading