23 changes: 23 additions & 0 deletions tests/unit/vertexai/genai/replays/test_batch_evaluate.py
@@ -14,6 +14,8 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import pytest

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types

@@ -43,3 +45,24 @@ def test_batch_eval(client):
globals_for_file=globals(),
test_method="evals.batch_evaluate",
)

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_batch_eval_async(client):
eval_dataset = types.EvaluationDataset(
gcs_source=types.GcsSource(
uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
)
)

response = await client.aio.evals.batch_evaluate(
dataset=eval_dataset,
metrics=[
types.PrebuiltMetric.TEXT_QUALITY,
],
dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
)
assert "operations" in response.name
assert "EvaluateDatasetOperationMetadata" in response.metadata.get("@type")
55 changes: 47 additions & 8 deletions tests/unit/vertexai/genai/replays/test_evaluate_instances.py
@@ -14,11 +14,12 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import json

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
import pandas as pd
import json
import pytest


def test_bleu_metric(client):
@@ -31,7 +32,11 @@ def test_bleu_metric(client):
],
metric_spec=types.BleuSpec(),
)
response = client.evals._evaluate_instances(bleu_input=test_bleu_input)
response = client.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
bleu_input=test_bleu_input
)
)
assert len(response.bleu_results.bleu_metric_values) == 1


@@ -46,8 +51,10 @@ def test_exact_match_metric(client):
],
metric_spec=types.ExactMatchSpec(),
)
response = client.evals._evaluate_instances(
exact_match_input=test_exact_match_input
response = client.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
exact_match_input=test_exact_match_input
)
)
assert len(response.exact_match_results.exact_match_metric_values) == 1

@@ -63,7 +70,11 @@ def test_rouge_metric(client):
],
metric_spec=types.RougeSpec(rouge_type="rougeL"),
)
response = client.evals._evaluate_instances(rouge_input=test_rouge_input)
response = client.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
rouge_input=test_rouge_input
)
)
assert len(response.rouge_results.rouge_metric_values) == 1


@@ -78,7 +89,11 @@ def test_pointwise_metric(client):
metric_prompt_template="Evaluate if the response '{response}' correctly answers the prompt '{prompt}'."
),
)
response = client.evals._evaluate_instances(pointwise_metric_input=test_input)
response = client.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
pointwise_metric_input=test_input
)
)
assert response.pointwise_metric_result is not None
assert response.pointwise_metric_result.score is not None

@@ -100,8 +115,10 @@ def test_pairwise_metric_with_autorater(client):
)
autorater_config = types.AutoraterConfig(sampling_count=2)

response = client.evals._evaluate_instances(
pairwise_metric_input=test_input, autorater_config=autorater_config
response = client.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
pairwise_metric_input=test_input, autorater_config=autorater_config
)
)
assert response.pairwise_metric_result is not None
assert response.pairwise_metric_result.pairwise_choice is not None
@@ -147,3 +164,25 @@ def test_inference_with_prompt_template(client):
globals_for_file=globals(),
test_method="evals.evaluate",
)


pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_bleu_metric_async(client):
test_bleu_input = types.BleuInput(
instances=[
types.BleuInstance(
reference="The quick brown fox jumps over the lazy dog.",
prediction="A fast brown fox leaps over a lazy dog.",
)
],
metric_spec=types.BleuSpec(),
)
response = await client.aio.evals.evaluate_instances(
metric_config=types._EvaluateInstancesRequestParameters(
bleu_input=test_bleu_input
)
)
assert len(response.bleu_results.bleu_metric_values) == 1
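The changes above all funnel through the same calling convention: build one metric-specific input, wrap it in `types._EvaluateInstancesRequestParameters`, and pass it as `metric_config`. A hypothetical helper that captures this pattern (the helper itself is not part of the SDK; the types and field names are taken from the tests above):

```python
from vertexai._genai import types


def run_metric(client, **metric_input):
    """Hypothetical convenience wrapper: forward a single metric input
    (bleu_input=..., exact_match_input=..., etc.) through the public
    evaluate_instances entry point used in the updated tests."""
    return client.evals.evaluate_instances(
        metric_config=types._EvaluateInstancesRequestParameters(**metric_input)
    )


# Usage, mirroring the BLEU test:
# bleu_input = types.BleuInput(
#     instances=[
#         types.BleuInstance(
#             reference="The quick brown fox jumps over the lazy dog.",
#             prediction="A fast brown fox leaps over a lazy dog.",
#         )
#     ],
#     metric_spec=types.BleuSpec(),
# )
# result = run_metric(client, bleu_input=bleu_input)
```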
18 changes: 18 additions & 0 deletions vertexai/_genai/evals.py
@@ -1522,3 +1522,21 @@ async def batch_evaluate(
self._api_client._verify_response(return_value)

return return_value

async def evaluate_instances(
self,
*,
metric_config: types._EvaluateInstancesRequestParameters,
) -> types.EvaluateInstancesResponse:
"""Evaluates an instance of a model."""

if isinstance(metric_config, types._EvaluateInstancesRequestParameters):
metric_config = metric_config.model_dump()
else:
metric_config = dict(metric_config)

result = await self._evaluate_instances(
**metric_config,
)

return result
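As a side note, the wrapper's shape is "normalize, then forward as keyword arguments": accept either the typed request container or a plain mapping, dump it to a dict, and splat it into the existing private method. Below is a self-contained sketch of that pattern with stand-in names rather than SDK types; treating a plain mapping as acceptable input is an assumption read off the `else` branch above.

```python
import asyncio
from typing import Any, Mapping, Optional, Union

from pydantic import BaseModel


class RequestParams(BaseModel):
    """Stand-in for types._EvaluateInstancesRequestParameters."""

    bleu_input: Optional[Any] = None
    exact_match_input: Optional[Any] = None


async def _evaluate_instances(**kwargs: Any) -> dict:
    """Stand-in for the existing private per-field method."""
    return {k: v for k, v in kwargs.items() if v is not None}


async def evaluate_instances(
    metric_config: Union[RequestParams, Mapping[str, Any]]
) -> dict:
    # Normalize the typed container (or mapping) into a dict, then forward
    # its fields as keyword arguments -- the same flow as the new wrapper.
    if isinstance(metric_config, RequestParams):
        config_dict = metric_config.model_dump()
    else:
        config_dict = dict(metric_config)
    return await _evaluate_instances(**config_dict)


# asyncio.run(evaluate_instances(RequestParams(bleu_input={"instances": []})))
```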
22 changes: 22 additions & 0 deletions vertexai/_genai/mypy.ini
@@ -0,0 +1,22 @@
[mypy]
# TODO(b/422425982): Fix arg-type errors
disable_error_code = import-not-found, import-untyped, arg-type

# We only want to run mypy on _genai dir, ignore dependent modules
[mypy-vertexai.agent_engines.*]
ignore_errors = True

[mypy-vertexai.preview.*]
ignore_errors = True

[mypy-vertexai.generative_models.*]
ignore_errors = True

[mypy-vertexai.prompts.*]
ignore_errors = True

[mypy-vertexai.tuning.*]
ignore_errors = True

[mypy-vertexai.caching.*]
ignore_errors = True
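To reproduce the check locally, one option is mypy's programmatic API; a minimal sketch, assuming mypy is installed and the command is run from the repository root (the config path comes from this PR):

```python
# Run mypy against the _genai package with the config added in this PR.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "vertexai/_genai/mypy.ini", "vertexai/_genai"]
)
print(stdout, end="")
if exit_status != 0:
    raise SystemExit(exit_status)
```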