Skip to content

Commit

Permalink
feat: LLM - Released the LLM SDK to GA
Browse files Browse the repository at this point in the history
This commit also refactors the machinery around the mapping between schemas and interface classes.
The central table is no longer needed: each interface class now knows its own instance schema.
This architecture makes it much easier to experiment with different interface classes since `MyClass.from_pretrained` now returns an instance of `MyClass` (if the model's instance schema matches).
This change makes it trivial to have multiple versions of an interface class (e.g., GA, preview, etc.). That was much harder with a centralized table, which could only hold a single interface class for each instance schema.

PiperOrigin-RevId: 538100676
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 6, 2023
1 parent ce5dee4 commit 76465e2
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 72 deletions.
118 changes: 103 additions & 15 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@
model as gca_model,
)

from vertexai.preview import language_models
from vertexai.preview import (
language_models as preview_language_models,
)
from vertexai import language_models
from google.cloud.aiplatform_v1 import Execution as GapicExecution
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -456,7 +459,7 @@ def get_endpoint_mock():
@pytest.fixture
def mock_get_tuned_model(get_endpoint_mock):
with mock.patch.object(
language_models.TextGenerationModel, "get_tuned_model"
preview_language_models.TextGenerationModel, "get_tuned_model"
) as mock_text_generation_model:
mock_text_generation_model._model_id = (
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
Expand Down Expand Up @@ -519,6 +522,50 @@ def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_text_generation(self):
    """Checks loading and prediction for the preview text generation model.

    Both the Model Garden client and the prediction client are patched, so
    the test verifies the publisher model lookup arguments, the resource
    name stored on the model, and the text of the returned response.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Publisher model returned by the patched Model Garden lookup.
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ) as get_publisher_model_patch:
        model = preview_language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    get_publisher_model_patch.assert_called_once_with(
        name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY
    )

    expected_resource_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001"
    assert model._model_resource_name == expected_resource_name

    # Canned prediction returned by the patched prediction service.
    fake_predict_response = gca_prediction_service.PredictResponse()
    fake_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)

    with mock.patch.object(
        target=prediction_service_client.PredictionServiceClient,
        attribute="predict",
        return_value=fake_predict_response,
    ):
        response = model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )

    expected_text = _TEST_TEXT_GENERATION_PREDICTION["content"]
    assert response.text == expected_text

def test_text_generation_ga(self):
"""Tests the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
Expand Down Expand Up @@ -596,7 +643,7 @@ def test_tune_model(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextGenerationModel.from_pretrained(
model = preview_language_models.TextGenerationModel.from_pretrained(
"text-bison@001"
)

Expand Down Expand Up @@ -631,7 +678,7 @@ def test_get_tuned_model(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
tuned_model = language_models.TextGenerationModel.get_tuned_model(
tuned_model = preview_language_models.TextGenerationModel.get_tuned_model(
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
)

Expand All @@ -651,7 +698,7 @@ def get_tuned_model_raises_if_not_called_with_mg_model(self):
)

with pytest.raises(ValueError):
language_models.TextGenerationModel.get_tuned_model(
preview_language_models.TextGenerationModel.get_tuned_model(
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
)

Expand All @@ -668,7 +715,7 @@ def test_chat(self):
_CHAT_BISON_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = language_models.ChatModel.from_pretrained("chat-bison@001")
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
Expand All @@ -681,11 +728,11 @@ def test_chat(self):
My favorite movies are Lord of the Rings and Hobbit.
""",
examples=[
language_models.InputOutputTextPair(
preview_language_models.InputOutputTextPair(
input_text="Who do you work for?",
output_text="I work for Ned.",
),
language_models.InputOutputTextPair(
preview_language_models.InputOutputTextPair(
input_text="What do I like?",
output_text="Ned likes watching movies.",
),
Expand Down Expand Up @@ -786,7 +833,7 @@ def test_code_chat(self):
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = language_models.CodeChatModel.from_pretrained(
model = preview_language_models.CodeChatModel.from_pretrained(
"google/codechat-bison@001"
)

Expand Down Expand Up @@ -882,7 +929,7 @@ def test_code_generation(self):
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = language_models.CodeGenerationModel.from_pretrained(
model = preview_language_models.CodeGenerationModel.from_pretrained(
"google/code-bison@001"
)

Expand All @@ -909,9 +956,11 @@ def test_code_generation(self):
# Validating the parameters
predict_temperature = 0.1
predict_max_output_tokens = 100
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
default_temperature = (
preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
)
default_max_output_tokens = (
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
)

with mock.patch.object(
Expand Down Expand Up @@ -948,7 +997,7 @@ def test_code_completion(self):
_CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = language_models.CodeGenerationModel.from_pretrained(
model = preview_language_models.CodeGenerationModel.from_pretrained(
"google/code-gecko@001"
)

Expand All @@ -975,9 +1024,11 @@ def test_code_completion(self):
# Validating the parameters
predict_temperature = 0.1
predict_max_output_tokens = 100
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
default_temperature = (
preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
)
default_max_output_tokens = (
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
)

with mock.patch.object(
Expand All @@ -1002,6 +1053,43 @@ def test_code_completion(self):
assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens

def test_text_embedding(self):
    """Checks loading and embedding retrieval for the preview embedding model.

    Both the Model Garden client and the prediction client are patched, so
    the test verifies the publisher model lookup arguments and that every
    returned embedding carries the expected vector.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Publisher model returned by the patched Model Garden lookup.
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ) as get_publisher_model_patch:
        model = preview_language_models.TextEmbeddingModel.from_pretrained(
            "textembedding-gecko@001"
        )

    get_publisher_model_patch.assert_called_once_with(
        name="publishers/google/models/textembedding-gecko@001",
        retry=base._DEFAULT_RETRY,
    )

    # Canned prediction returned by the patched prediction service.
    fake_predict_response = gca_prediction_service.PredictResponse()
    fake_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)

    with mock.patch.object(
        target=prediction_service_client.PredictionServiceClient,
        attribute="predict",
        return_value=fake_predict_response,
    ):
        embeddings = model.get_embeddings(["What is life?"])
        assert embeddings
        for item in embeddings:
            values = item.values
            assert len(values) == _TEXT_EMBEDDING_VECTOR_LENGTH
            assert values == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]

def test_text_embedding_ga(self):
"""Tests the text embedding model."""
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
10 changes: 1 addition & 9 deletions tests/unit/aiplatform/test_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import pytest
from importlib import reload
from unittest import mock
from typing import Dict, Type

from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand Down Expand Up @@ -53,14 +52,7 @@ class TestModelGardenModels:
"""Unit tests for the _ModelGardenModel base class."""

class FakeModelGardenModel(_model_garden_models._ModelGardenModel):
@staticmethod
def _get_public_preview_class_map() -> Dict[
str, Type[_model_garden_models._ModelGardenModel]
]:
test_map = {
"gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TestModelGardenModels.FakeModelGardenModel
}
return test_map
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"

def setup_method(self):
reload(initializer)
Expand Down
25 changes: 11 additions & 14 deletions vertexai/_model_garden/_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def _get_model_info(
)

if not interface_class:
raise ValueError(f"Unknown model {publisher_model_res.name}")
raise ValueError(
f"Unknown model {publisher_model_res.name}; {schema_to_class_map}"
)

return _ModelInfo(
endpoint_name=endpoint_name,
Expand All @@ -120,18 +122,8 @@ def _get_model_info(
class _ModelGardenModel:
"""Base class for shared methods and properties across Model Garden models."""

@staticmethod
@abc.abstractmethod
def _get_public_preview_class_map() -> Dict[str, Type["_ModelGardenModel"]]:
"""Returns a Dict mapping schema URI to model class.
Subclasses should implement this method. Example mapping:
{
"gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TextGenerationModel
}
"""
pass
# Subclasses override this attribute to specify their instance schema
_INSTANCE_SCHEMA_URI: Optional[str] = None

def __init__(self, model_id: str, endpoint_name: Optional[str] = None):
"""Creates a _ModelGardenModel.
Expand Down Expand Up @@ -168,8 +160,13 @@ def from_pretrained(cls, model_name: str) -> "_ModelGardenModel":
ValueError: If model does not support this class.
"""

if not cls._INSTANCE_SCHEMA_URI:
raise ValueError(
f"Class {cls} is not a correct model interface class since it does not have an instance schema URI."
)

model_info = _get_model_info(
model_id=model_name, schema_to_class_map=cls._get_public_preview_class_map()
model_id=model_name, schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls}
)

if not issubclass(model_info.interface_class, cls):
Expand Down
17 changes: 16 additions & 1 deletion vertexai/language_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with language models."""

from vertexai.language_models import _language_models
from vertexai.language_models._language_models import (
InputOutputTextPair,
TextEmbedding,
TextEmbeddingModel,
TextGenerationModel,
TextGenerationResponse,
)

__all__ = [
"InputOutputTextPair",
"TextEmbedding",
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
]
Loading

0 comments on commit 76465e2

Please sign in to comment.