vertexai[patch]: standardize model params (#121)
* vertexai[patch]: standardize model params
baskaryan committed Apr 29, 2024
1 parent 3c9efa1 commit bfb6eb9
Showing 10 changed files with 110 additions and 23 deletions.
13 changes: 9 additions & 4 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -50,11 +50,16 @@ class _VertexAIBase(BaseModel):
     max_retries: int = 6
     """The maximum number of retries to make when generating."""
     task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
-    stop: Optional[List[str]] = None
+    stop: Optional[List[str]] = Field(default=None, alias="stop_sequences")
     "Optional list of stop words to use when generating."
-    model_name: Optional[str] = None
+    model_name: Optional[str] = Field(default=None, alias="model")
     "Underlying model name."

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @root_validator(pre=True)
     def validate_params(cls, values: dict) -> dict:
         if "model" in values and "model_name" not in values:
@@ -64,11 +69,11 @@ def validate_params(cls, values: dict) -> dict:

 class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
-    model_name: str
+    model_name: str = Field(default=None, alias="model")
     "Underlying model name."
     temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: Optional[int] = None
+    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: Optional[float] = None
     "Tokens are selected from most probable to least until the sum of their "
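
The heart of the change is pydantic v1 field aliasing: each parameter keeps its canonical field name (model_name, max_output_tokens, stop) and gains a standardized alias (model, max_tokens, stop_sequences), while allow_population_by_field_name = True lets callers use either spelling. A minimal self-contained sketch of the mechanism (the Spec class is illustrative, not part of the library):

    # Assumes the pydantic v1 API (under pydantic 2, import from pydantic.v1 instead).
    from pydantic import BaseModel, Field

    class Spec(BaseModel):
        # Canonical field name plus the standardized alias.
        model_name: str = Field(default="text-bison", alias="model")
        max_output_tokens: int = Field(default=128, alias="max_tokens")

        class Config:
            # Without this flag, pydantic v1 accepts only the alias at init time.
            allow_population_by_field_name = True

    # Both spellings populate the same fields.
    assert Spec(model="gemini-pro").model_name == "gemini-pro"
    assert Spec(model_name="gemini-pro", max_tokens=10).max_output_tokens == 10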
16 changes: 14 additions & 2 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -57,7 +57,7 @@
 )
 from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
 from langchain_core.runnables import Runnable, RunnablePassthrough
 from vertexai.generative_models import (  # type: ignore
     Candidate,
@@ -498,7 +498,7 @@ async def _completion_with_retry_inner(
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""

-    model_name: str = "chat-bison"
+    model_name: str = Field(default="chat-bison", alias="model")
     "Underlying model name."
     examples: Optional[List[BaseMessage]] = None
     tuned_model_name: Optional[str] = None
@@ -510,6 +510,18 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     setting this parameter to True is discouraged.
     """

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
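
With the alias and Config in place, ChatVertexAI accepts both the legacy and the standardized parameter names, exactly as pinned down by test_init in tests/unit_tests/test_chat_models.py below. A usage sketch (the project value is illustrative):

    from langchain_google_vertexai import ChatVertexAI

    # Legacy spellings.
    llm = ChatVertexAI(model_name="gemini-pro", project="my-project",
                       max_output_tokens=10, stop=["bar"])
    # Standardized spellings; these resolve to the same fields.
    llm = ChatVertexAI(model="gemini-pro", project="my-project",
                       max_tokens=10, stop_sequences=["bar"])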
5 changes: 3 additions & 2 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
@@ -78,7 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict:

     def __init__(
         self,
-        model_name: str,
+        model_name: Optional[str] = None,
         project: Optional[str] = None,
         location: str = "us-central1",
         request_parallelism: int = 5,
@@ -87,13 +87,14 @@ def __init__(
         **kwargs: Any,
     ):
         """Initialize the sentence_transformer."""
+        if model_name:
+            kwargs["model_name"] = model_name
         super().__init__(
             project=project,
             location=location,
             credentials=credentials,
             request_parallelism=request_parallelism,
             max_retries=max_retries,
-            model_name=model_name,
             **kwargs,
         )
         self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
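
VertexAIEmbeddings makes model_name optional in its signature and forwards it through kwargs, so the model alias reaches the pydantic layer as well. A sketch mirroring the integration test below (assumes GCP credentials and a project are configured):

    from langchain_google_vertexai import VertexAIEmbeddings

    emb = VertexAIEmbeddings(model_name="textembedding-gecko")
    emb = VertexAIEmbeddings(model="textembedding-gecko")  # same field via the alias
    assert emb.model_name == "textembedding-gecko"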
35 changes: 32 additions & 3 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -21,7 +21,7 @@
     Generation,
     LLMResult,
 )
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

 from langchain_google_vertexai._base import _BaseVertexAIModelGarden
 from langchain_google_vertexai._utils import enforce_stop_tokens
@@ -115,6 +115,17 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
     """Whether to post-process the chat response and clean repetitions """
     """or multi-turn statements."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @property
     def _llm_type(self) -> str:
         return "gemma_vertexai_model_garden"
@@ -178,9 +189,15 @@ class _GemmaLocalKaggleBase(_GemmaBase):

     client: Any = None  #: :meta private:
     keras_backend: str = "jax"
-    model_name: str = "gemma_2b_en"
+    model_name: str = Field(default="gemma_2b_en", alias="model")
     """Gemma model name."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -212,6 +229,12 @@ def _get_params(self, **kwargs) -> Dict[str, Any]:
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Only needed for typing."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     def _generate(
         self,
         prompts: List[str],
@@ -238,6 +261,12 @@ class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
     """Whether to post-process the chat response and clean repetitions """
     """or multi-turn statements."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -268,7 +297,7 @@ class _GemmaLocalHFBase(_GemmaBase):
     client: Any = None  #: :meta private:
     hf_access_token: str
     cache_dir: Optional[str] = None
-    model_name: str = "gemma_2b_en"
+    model_name: str = Field(default="gemma_2b_en", alias="model")
     """Gemma model name."""

     @root_validator()
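
The Gemma wrappers get the same alias on their model_name field, so local models can be addressed either way. A hedged sketch (assumes keras and the Kaggle Gemma weights are available locally; the model is loaded at validation time):

    from langchain_google_vertexai import GemmaLocalKaggle

    # Both spellings resolve to the model_name field (default "gemma_2b_en").
    llm = GemmaLocalKaggle(model_name="gemma_2b_en", keras_backend="jax")
    llm = GemmaLocalKaggle(model="gemma_2b_en", keras_backend="jax")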
15 changes: 13 additions & 2 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -9,7 +9,7 @@
 )
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import Field, root_validator
 from vertexai.generative_models import (  # type: ignore[import-untyped]
     Candidate,
     GenerativeModel,
@@ -110,13 +110,24 @@ async def _acompletion_with_retry_inner(
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""

-    model_name: str = "text-bison"
+    model_name: str = Field(default="text-bison", alias="model")
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
     """The name of a tuned model. If tuned_model_name is passed
     model_name will be used to determine the model family
     """

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
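
The repeated __init__ overrides exist because mypy does not see pydantic aliases: without an explicit keyword-only model_name parameter, VertexAI(model_name=...) only type-checks through **kwargs. The override forwards the value into kwargs so both the type checker and pydantic are satisfied. Usage, as pinned by test_model_name in tests/unit_tests/test_llm.py below:

    from langchain_google_vertexai import VertexAI

    llm = VertexAI(model_name="gemini-pro", max_output_tokens=10)
    llm = VertexAI(model="gemini-pro", max_tokens=10)
    assert llm.max_output_tokens == 10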
16 changes: 13 additions & 3 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -25,7 +25,7 @@
     Generation,
     LLMResult,
 )
-from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import Field, root_validator

 from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
@@ -34,6 +34,11 @@
 class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
     """Large language models served from Vertex AI Model Garden."""

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     def _generate(
         self,
         prompts: List[str],
@@ -92,9 +97,14 @@ async def _agenerate(

 class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
     async_client: Any = None  #: :meta private:
-    model_name: Optional[str] = None  # type: ignore[assignment]
+    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
     "Underlying model name."
-    max_output_tokens: int = 1024
+    max_output_tokens: int = Field(default=1024, alias="max_tokens")

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
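
ChatAnthropicVertex gets the same treatment, so Anthropic-style callers can pass max_tokens directly. A sketch with illustrative values (no test for this class appears in the diff; the client is created at validation time, so Anthropic-on-Vertex support and GCP credentials are assumed):

    from langchain_google_vertexai.model_garden import ChatAnthropicVertex

    chat = ChatAnthropicVertex(
        model="claude-3-sonnet@20240229",  # illustrative model id
        project="my-project",              # illustrative GCP project
        max_tokens=1024,                   # alias for max_output_tokens
    )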
2 changes: 1 addition & 1 deletion libs/vertexai/pyproject.toml
@@ -106,7 +106,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 # --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
+addopts = "--strict-markers --strict-config --durations=5"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
8 changes: 7 additions & 1 deletion libs/vertexai/tests/integration_tests/test_embeddings.py
@@ -16,7 +16,13 @@
 @pytest.mark.release
 def test_initialization() -> None:
     """Test embedding model initialization."""
-    VertexAIEmbeddings(model_name="textembedding-gecko@001")
+    for embeddings in [
+        VertexAIEmbeddings(
+            model_name="textembedding-gecko",
+        ),
+        VertexAIEmbeddings(model="textembedding-gecko"),
+    ]:
+        assert embeddings.model_name == "textembedding-gecko"


 @pytest.mark.release
18 changes: 15 additions & 3 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -52,12 +52,24 @@
 )


-def test_model_name() -> None:
+def test_init() -> None:
     for llm in [
-        ChatVertexAI(model_name="gemini-pro", project="test-project"),
-        ChatVertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+        ChatVertexAI(
+            model_name="gemini-pro",
+            project="test-project",
+            max_output_tokens=10,
+            stop=["bar"],
+        ),
+        ChatVertexAI(
+            model="gemini-pro",
+            project="test-project",
+            max_tokens=10,
+            stop_sequences=["bar"],
+        ),
     ]:
         assert llm.model_name == "gemini-pro"
+        assert llm.max_output_tokens == 10
+        assert llm.stop == ["bar"]


 def test_tuned_model_name() -> None:
5 changes: 3 additions & 2 deletions libs/vertexai/tests/unit_tests/test_llm.py
@@ -7,10 +7,11 @@

 def test_model_name() -> None:
     for llm in [
-        VertexAI(model_name="gemini-pro", project="test-project"),
-        VertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+        VertexAI(model_name="gemini-pro", project="test-project", max_output_tokens=10),
+        VertexAI(model="gemini-pro", project="test-project", max_tokens=10),
     ]:
         assert llm.model_name == "gemini-pro"
+        assert llm.max_output_tokens == 10


 def test_tuned_model_name() -> None:
