Skip to content

Commit

Permalink
Add support for tuned gemini models (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliasecchig committed Apr 12, 2024
1 parent 7eebc05 commit 42c12de
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 17 deletions.
21 changes: 17 additions & 4 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
tuned_model_name: Optional[str] = None
"""The name of a tuned model. If tuned_model_name is passed
model_name will be used to determine the model family
"""
convert_system_message_to_human: bool = False
"""[Deprecated] Since new Gemini models support setting a System Message,
setting this parameter to True is discouraged.
Expand All @@ -407,17 +411,26 @@ def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
safety_settings = values["safety_settings"]
tuned_model_name = values.get("tuned_model_name")

if safety_settings and not is_gemini:
raise ValueError("Safety settings are only supported for Gemini models")

cls._init_vertexai(values)

if tuned_model_name:
generative_model_name = values["tuned_model_name"]
else:
generative_model_name = values["model_name"]

if is_gemini:
values["client"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
model_name=generative_model_name,
safety_settings=safety_settings,
)
values["client_preview"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
model_name=generative_model_name,
safety_settings=safety_settings,
)
else:
if is_codey_model(values["model_name"]):
Expand All @@ -426,9 +439,9 @@ def validate_environment(cls, values: Dict) -> Dict:
else:
model_cls = ChatModel
model_cls_preview = PreviewChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
values["client"] = model_cls.from_pretrained(generative_model_name)
values["client_preview"] = model_cls_preview.from_pretrained(
values["model_name"]
generative_model_name
)
return values

Expand Down
34 changes: 21 additions & 13 deletions libs/vertexai/langchain_google_vertexai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
"""The name of a tuned model. If tuned_model_name is passed
model_name will be used to determine the model family
"""

@classmethod
def is_lc_serializable(self) -> bool:
Expand Down Expand Up @@ -147,22 +149,28 @@ def validate_environment(cls, values: Dict) -> Dict:
preview_model_cls = PreviewTextGenerationModel

if tuned_model_name:
values["client"] = model_cls.get_tuned_model(tuned_model_name)
values["client_preview"] = preview_model_cls.get_tuned_model(
tuned_model_name
generative_model_name = values["tuned_model_name"]
else:
generative_model_name = values["model_name"]

if is_gemini:
values["client"] = model_cls(
model_name=generative_model_name, safety_settings=safety_settings
)
values["client_preview"] = preview_model_cls(
model_name=generative_model_name, safety_settings=safety_settings
)
else:
if is_gemini:
values["client"] = model_cls(
model_name=model_name, safety_settings=safety_settings
)
values["client_preview"] = preview_model_cls(
model_name=model_name, safety_settings=safety_settings
if tuned_model_name:
values["client"] = model_cls.get_tuned_model(generative_model_name)
values["client_preview"] = preview_model_cls.get_tuned_model(
generative_model_name
)
else:
values["client"] = model_cls.from_pretrained(model_name)
values["client_preview"] = preview_model_cls.from_pretrained(model_name)

values["client"] = model_cls.from_pretrained(generative_model_name)
values["client_preview"] = preview_model_cls.from_pretrained(
generative_model_name
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Only one candidate can be generated with streaming!")
return values
Expand Down
11 changes: 11 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ def test_model_name() -> None:
assert llm.model_name == "gemini-pro"


def test_tuned_model_name() -> None:
    """A tuned endpoint should drive the underlying client while the model
    attributes retain both the family name and the tuned endpoint."""
    base = "gemini-pro"
    tuned = "projects/123/locations/europe-west4/endpoints/456"
    chat = ChatVertexAI(
        model_name=base,
        project="test-project",
        tuned_model_name=tuned,
    )
    # model_name is kept for family detection; the client targets the tuned endpoint.
    assert chat.model_name == base
    assert chat.tuned_model_name == tuned
    assert chat.client._model_name == tuned


def test_parse_examples_correct() -> None:
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
Expand Down
11 changes: 11 additions & 0 deletions libs/vertexai/tests/unit_tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ def test_model_name() -> None:
assert llm.model_name == "gemini-pro"


def test_tuned_model_name() -> None:
    """A tuned endpoint should drive the underlying client while the model
    attributes retain both the family name and the tuned endpoint."""
    base = "gemini-pro"
    tuned = "projects/123/locations/europe-west4/endpoints/456"
    llm = VertexAI(
        model_name=base,
        project="test-project",
        tuned_model_name=tuned,
    )
    # model_name is kept for family detection; the client targets the tuned endpoint.
    assert llm.model_name == base
    assert llm.tuned_model_name == tuned
    assert llm.client._model_name == tuned


def test_vertexai_args_passed() -> None:
response_text = "Goodbye"
user_prompt = "Hello"
Expand Down

0 comments on commit 42c12de

Please sign in to comment.