vertexai[patch]: support model param (#81)
* vertexai[patch]: support `model` param

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
baskaryan and lkuligin authored Mar 26, 2024
1 parent 58f8c6b commit 73f83ea
Showing 4 changed files with 28 additions and 174 deletions.
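For reference, the effect of this change is that `model` now works as a constructor alias for `model_name`. A minimal sketch of the intended usage, mirroring the unit tests below (the `project` value is a placeholder):

```python
from langchain_google_vertexai import ChatVertexAI, VertexAI

# Both spellings resolve to the same field: a pre-validation root validator
# remaps `model` to `model_name` before field validation runs.
chat = ChatVertexAI(model="gemini-pro", project="my-project")
llm = VertexAI(model_name="gemini-pro", project="my-project")
assert chat.model_name == llm.model_name == "gemini-pro"
```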
10 changes: 8 additions & 2 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -55,6 +55,12 @@ class _VertexAIBase(BaseModel):
     model_name: Optional[str] = None
     "Underlying model name."
 
+    @root_validator(pre=True)
+    def validate_params(cls, values: dict) -> dict:
+        if "model" in values and "model_name" not in values:
+            values["model_name"] = values.pop("model")
+        return values
+
 
 class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
@@ -137,9 +143,9 @@ def _default_params(self) -> Dict[str, Any]:
         updated_params = {}
         for param_name, param_value in params.items():
             default_value = default_params.get(param_name)
-            if param_value or default_value:
+            if param_value is not None or default_value is not None:
                 updated_params[param_name] = (
-                    param_value if param_value else default_value
+                    param_value if param_value is not None else default_value
                 )
         return updated_params
 
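The `is not None` change above is not cosmetic: with the old truthiness check, a falsy-but-valid value such as `temperature=0.0` was silently replaced by the default. A standalone sketch of the difference (hypothetical values, not library code):

```python
params = {"temperature": 0.0, "top_k": None}
default_params = {"temperature": 0.7, "top_k": 40}

# Old check: 0.0 is falsy, so the default wins and the user's value is lost.
old = {k: v if v else default_params.get(k) for k, v in params.items()}
assert old["temperature"] == 0.7

# New check: only a literal None falls back to the default.
new = {k: v if v is not None else default_params.get(k) for k, v in params.items()}
assert new["temperature"] == 0.0
assert new["top_k"] == 40
```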
175 changes: 3 additions & 172 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -1,17 +1,15 @@
from __future__ import annotations

from concurrent.futures import Executor
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import vertexai # type: ignore[import-untyped]
from google.cloud.aiplatform import telemetry
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import root_validator
from vertexai.generative_models import ( # type: ignore[import-untyped]
Candidate,
GenerativeModel,
@@ -25,29 +23,18 @@
TextGenerationResponse,
)
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)
from vertexai.preview.language_models import (
CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
TextGenerationModel as PreviewTextGenerationModel,
)

from langchain_google_vertexai._base import (
_PALM_DEFAULT_MAX_OUTPUT_TOKENS,
_PALM_DEFAULT_TEMPERATURE,
_PALM_DEFAULT_TOP_K,
_PALM_DEFAULT_TOP_P,
_VertexAICommon,
)
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai._utils import (
create_retry_decorator,
get_generation_info,
get_user_agent,
is_codey_model,
is_gemini_model,
)
@@ -120,162 +107,6 @@ async def _acompletion_with_retry_inner(
)


class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
model_name: Optional[str] = None
"Underlying model name."


class _VertexAICommon(_VertexAIBase):
client: Any = None #: :meta private:
client_preview: Any = None #: :meta private:
model_name: str
"Underlying model name."
temperature: Optional[float] = None
"Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: Optional[int] = None
"Token limit determines the maximum amount of text output from one prompt."
top_p: Optional[float] = None
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value. Top-p is ignored for Codey models."
top_k: Optional[int] = None
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens. Top-k is ignored for Codey models."
credentials: Any = Field(default=None, exclude=True)
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
n: int = 1
"""How many completions to generate for each prompt."""
streaming: bool = False
"""Whether to stream the results or not."""
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
"""The default safety settings to use for all generations.
For example:
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
""" # noqa: E501

@property
def _llm_type(self) -> str:
return "vertexai"

@property
def is_codey_model(self) -> bool:
return is_codey_model(self.model_name)

@property
def _is_gemini_model(self) -> bool:
return is_gemini_model(self.model_name)

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Gets the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}

@property
def _default_params(self) -> Dict[str, Any]:
if self._is_gemini_model:
default_params = {}
else:
default_params = {
"temperature": _PALM_DEFAULT_TEMPERATURE,
"max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
"top_p": _PALM_DEFAULT_TOP_P,
"top_k": _PALM_DEFAULT_TOP_K,
}
params = {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
if not self.is_codey_model:
params.update(
{
"top_k": self.top_k,
"top_p": self.top_p,
}
)
updated_params = {}
for param_name, param_value in params.items():
default_value = default_params.get(param_name)
if param_value is not None or default_value is not None:
updated_params[param_name] = (
param_value if param_value is not None else default_value
)
return updated_params

@property
def _user_agent(self) -> str:
"""Gets the User Agent."""
_, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
return user_agent

@classmethod
def _init_vertexai(cls, values: Dict) -> None:
vertexai.init(
project=values.get("project"),
location=values.get("location"),
credentials=values.get("credentials"),
)
return None

def _prepare_params(
self,
stop: Optional[List[str]] = None,
stream: bool = False,
**kwargs: Any,
) -> dict:
stop_sequences = stop or self.stop
params_mapping = {"n": "candidate_count"}
params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
params = {**self._default_params, "stop_sequences": stop_sequences, **params}
if stream or self.streaming:
params.pop("candidate_count")
return params

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
is_palm_chat_model = isinstance(
self.client_preview, PreviewChatModel
) or isinstance(self.client_preview, PreviewCodeChatModel)
if is_palm_chat_model:
result = self.client_preview.start_chat().count_tokens(text)
else:
result = self.client_preview.count_tokens([text])

return result.total_tokens


class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""

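Note: the `_VertexAIBase` and `_VertexAICommon` definitions removed here duplicated the ones in `_base.py`; `llms.py` now imports `_VertexAICommon` from `langchain_google_vertexai._base` (extended in the first file of this diff), which accounts for the bulk of the 172 deleted lines.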
9 changes: 9 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -33,6 +33,14 @@
 )
 
 
+def test_model_name() -> None:
+    for llm in [
+        ChatVertexAI(model_name="gemini-pro", project="test-project"),
+        ChatVertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+    ]:
+        assert llm.model_name == "gemini-pro"
+
+
 def test_parse_examples_correct() -> None:
     text_question = (
         "Hello, could you recommend a good movie for me to watch this evening, please?"
@@ -178,6 +186,7 @@ def test_default_params_palm() -> None:
         top_k=40,
         top_p=0.95,
         stop_sequences=None,
+        temperature=0.0,
     )


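Two details in these tests are easy to miss: the `# type: ignore[call-arg]` is needed because `model` is not a declared field, so static checkers flag the keyword even though the new pre-validator remaps it at runtime; and the added `temperature=0.0` expectation only holds with the `is not None` fix in `_base.py`, since the old truthiness check would have swallowed a zero temperature in favor of the PaLM default.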
8 changes: 8 additions & 0 deletions libs/vertexai/tests/unit_tests/test_llm.py
@@ -5,6 +5,14 @@
 from langchain_google_vertexai.llms import VertexAI
 
 
+def test_model_name() -> None:
+    for llm in [
+        VertexAI(model_name="gemini-pro", project="test-project"),
+        VertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+    ]:
+        assert llm.model_name == "gemini-pro"
+
+
 def test_vertexai_args_passed() -> None:
     response_text = "Goodbye"
     user_prompt = "Hello"
