Skip to content

Commit

Permalink
feat: GenAI - Release the GenerativeModel to GA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605495995
  • Loading branch information
Ark-kun authored and Copybara-Service committed Feb 9, 2024
1 parent ecc6454 commit c7e3f07
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from google import auth
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base
from vertexai.preview import generative_models
from vertexai import generative_models


class TestGenerativeModels(e2e_base.TestEndToEnd):
Expand Down
27 changes: 22 additions & 5 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

import vertexai
from google.cloud.aiplatform import initializer
from vertexai.preview import generative_models
from vertexai import generative_models
from vertexai.preview import generative_models as preview_generative_models
from vertexai.generative_models._generative_models import (
prediction_service,
gapic_prediction_service_types,
Expand Down Expand Up @@ -231,7 +232,11 @@ def teardown_method(self):
attribute="generate_content",
new=mock_generate_content,
)
def test_generate_content(self):
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_generate_content(self, generative_models: generative_models):
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Why is sky blue?")
assert response.text
Expand All @@ -254,7 +259,11 @@ def test_generate_content(self):
attribute="stream_generate_content",
new=mock_stream_generate_content,
)
def test_generate_content_streaming(self):
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_generate_content_streaming(self, generative_models: generative_models):
model = generative_models.GenerativeModel("gemini-pro")
stream = model.generate_content("Why is sky blue?", stream=True)
for chunk in stream:
Expand All @@ -265,7 +274,11 @@ def test_generate_content_streaming(self):
attribute="generate_content",
new=mock_generate_content,
)
def test_chat_send_message(self):
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_chat_send_message(self, generative_models: generative_models):
model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
response1 = chat.send_message("Why is sky blue?")
Expand All @@ -278,7 +291,11 @@ def test_chat_send_message(self):
attribute="generate_content",
new=mock_generate_content,
)
def test_chat_function_calling(self):
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_chat_function_calling(self, generative_models: generative_models):
get_current_weather_func = generative_models.FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
Expand Down
51 changes: 51 additions & 0 deletions vertexai/generative_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with the Gemini models."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.generative_models._generative_models import (
GenerativeModel,
GenerationConfig,
GenerationResponse,
Candidate,
ChatSession,
Content,
FinishReason,
FunctionDeclaration,
HarmCategory,
HarmBlockThreshold,
Image,
Part,
ResponseBlockedError,
Tool,
)

__all__ = [
"GenerationConfig",
"GenerativeModel",
"GenerationResponse",
"Candidate",
"ChatSession",
"Content",
"FinishReason",
"FunctionDeclaration",
"HarmCategory",
"HarmBlockThreshold",
"Image",
"Part",
"ResponseBlockedError",
"Tool",
]
4 changes: 4 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,10 @@ def _tune_model(
return job


class GenerativeModel(_GenerativeModel):
__module__ = "vertexai.generative_models"


class _PreviewGenerativeModel(_GenerativeModel, _TunableGenerativeModelMixin):
__name__ = "GenerativeModel"
__module__ = "vertexai.preview.generative_models"

0 comments on commit c7e3f07

Please sign in to comment.