Skip to content

Commit

Permalink
fix: LLM - CodeGenerationModel now supports safety attributes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560967317
  • Loading branch information
Ark-kun authored and Copybara-Service committed Aug 29, 2023
1 parent 2a08535 commit c2c8a5e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
17 changes: 14 additions & 3 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@

_TEST_CODE_GENERATION_PREDICTION = {
"safetyAttributes": {
"categories": [],
"blocked": False,
"scores": [],
"blocked": True,
"categories": ["Finance"],
"scores": [0.1],
},
"content": """
```python
Expand Down Expand Up @@ -2188,6 +2188,17 @@ def test_code_generation(self):
temperature=0.2,
)
assert response.text == _TEST_CODE_GENERATION_PREDICTION["content"]
expected_safety_attributes_raw = _TEST_CODE_GENERATION_PREDICTION[
"safetyAttributes"
]
expected_safety_attributes = dict(
zip(
expected_safety_attributes_raw["categories"],
expected_safety_attributes_raw["scores"],
)
)
assert response.safety_attributes == expected_safety_attributes
assert response.is_blocked == expected_safety_attributes_raw["blocked"]

# Validating the parameters
predict_temperature = 0.1
Expand Down
31 changes: 22 additions & 9 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,26 @@ def predict_streaming(
)


def _parse_text_generation_model_response(
    prediction_response: aiplatform.models.Prediction,
    prediction_idx: int = 0,
) -> TextGenerationResponse:
    """Builds a `TextGenerationResponse` from one raw model prediction.

    Args:
        prediction_response: The full prediction response returned by the
            endpoint; its ``predictions`` list holds the raw prediction dicts.
        prediction_idx: Index of the prediction to parse. Defaults to 0.

    Returns:
        A `TextGenerationResponse` carrying the generated text, the blocked
        flag, and a category->score mapping of safety attributes.
    """
    raw_prediction = prediction_response.predictions[prediction_idx]
    safety_info = raw_prediction.get("safetyAttributes", {})
    # The service reports categories and scores as parallel lists; either may
    # be missing or None, so normalize both to empty lists before pairing.
    categories = safety_info.get("categories") or []
    scores = safety_info.get("scores") or []
    return TextGenerationResponse(
        text=raw_prediction["content"],
        _prediction_response=prediction_response,
        is_blocked=safety_info.get("blocked", False),
        safety_attributes=dict(zip(categories, scores)),
    )


class _ModelWithBatchPredict(_LanguageModel):
"""Model that supports batch prediction."""

Expand Down Expand Up @@ -1754,11 +1774,7 @@ def predict(
instances=[prediction_request.instance],
parameters=prediction_request.parameters,
)

return TextGenerationResponse(
text=prediction_response.predictions[0]["content"],
_prediction_response=prediction_response,
)
return _parse_text_generation_model_response(prediction_response)

def predict_streaming(
self,
Expand Down Expand Up @@ -1800,10 +1816,7 @@ def predict_streaming(
predictions=[prediction_dict],
deployed_model_id="",
)
yield TextGenerationResponse(
text=prediction_dict["content"],
_prediction_response=prediction_obj,
)
yield _parse_text_generation_model_response(prediction_obj)


class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
Expand Down

0 comments on commit c2c8a5e

Please sign in to comment.