fixed VertexAIModelGarden class #13917

Merged · 1 commit · Nov 27, 2023
109 changes: 57 additions & 52 deletions libs/langchain/langchain/llms/vertexai.py
@@ -32,6 +32,8 @@
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf.struct_pb2 import Value
from vertexai.language_models._language_models import (
TextGenerationResponse,
_LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
"""Allowed optional args to be passed to the model."""
"Allowed optional args to be passed to the model."
prompt_arg: str = "prompt"
result_arg: str = "generated_text"
result_arg: Optional[str] = "generated_text"
"Set result_arg to None if output of the model is expected to be a string."
"Otherwise, if it's a dict, provided an argument that contains the result."

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -386,7 +390,7 @@ def validate_environment(cls, values: Dict) -> Dict:
except ImportError:
raise_vertex_import_error()

if values["project"] is None:
if not values["project"]:
raise ValueError(
"A GCP project should be provided to run inference on Model Garden!"
)
@@ -401,20 +405,18 @@
values["async_client"] = PredictionServiceAsyncClient(
client_options=client_options, client_info=client_info
)
values["endpoint_path"] = values["client"].endpoint_path(
project=values["project"],
location=values["location"],
endpoint=values["endpoint_id"],
)
return values

@property
def _llm_type(self) -> str:
return "vertexai_model_garden"

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
try:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
@@ -423,7 +425,6 @@ def _generate(
"protobuf package not found, please install it with"
" `pip install protobuf`"
)

instances = []
for prompt in prompts:
if self.allowed_model_args:
@@ -438,18 +439,53 @@ def _generate(
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
return predict_instances

endpoint = self.client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
response = self.client.predict(endpoint=endpoint, instances=predict_instances)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
return self._parse_response(response)

def _parse_response(self, predictions: "Prediction") -> LLMResult:
generations: List[List[Generation]] = []
for result in response.predictions:
for result in predictions.predictions:
generations.append(
[Generation(text=prediction[self.result_arg]) for prediction in result]
[
Generation(text=self._parse_prediction(prediction))
for prediction in result
]
)
return LLMResult(generations=generations)

def _parse_prediction(self, prediction: Any) -> str:
if isinstance(prediction, str):
return prediction

if self.result_arg:
try:
return prediction[self.result_arg]
except KeyError:
if isinstance(prediction, str):
error_desc = (
"Provided non-None `result_arg` (result_arg="
f"{self.result_arg}). But got prediction of type "
f"{type(prediction)} instead of dict. Most probably, you"
"need to set `result_arg=None` during VertexAIModelGarden "
"initialization."
)
raise ValueError(error_desc)
else:
raise ValueError(f"{self.result_arg} key not found in prediction!")

return prediction

async def _agenerate(
self,
prompts: List[str],
@@ -458,39 +494,8 @@ async def _agenerate(
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
try:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
except ImportError:
raise ImportError(
"protobuf package not found, please install it with"
" `pip install protobuf`"
)

instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)

predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]

endpoint = self.async_client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
instances = self._prepare_request(prompts, **kwargs)
response = await self.async_client.predict(
endpoint=endpoint, instances=predict_instances
endpoint=self.endpoint_path, instances=instances
)
generations: List[List[Generation]] = []
for result in response.predictions:
generations.append(
[Generation(text=prediction[self.result_arg]) for prediction in result]
)
return LLMResult(generations=generations)
return self._parse_response(response)
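For context on the result_arg change above, a minimal usage sketch. This is not code from the PR: the project, location, and endpoint IDs below are placeholders, and the import path simply follows the module touched here (libs/langchain/langchain/llms/vertexai.py).

from langchain.llms.vertexai import VertexAIModelGarden

# Endpoint whose predictions are dicts such as {"generated_text": "..."}:
# keep the default result_arg so the text is pulled out of that key.
falcon_llm = VertexAIModelGarden(
    project="my-gcp-project",        # placeholder project
    endpoint_id="1234567890",        # placeholder endpoint returning dicts
    result_arg="generated_text",
    location="europe-west4",
)

# Endpoint whose predictions are plain strings: set result_arg=None so
# _parse_prediction returns the prediction unchanged.
llama_llm = VertexAIModelGarden(
    project="my-gcp-project",        # placeholder project
    endpoint_id="0987654321",        # placeholder endpoint returning strings
    result_arg=None,
    location="europe-west4",
)

print(falcon_llm("What is the meaning of life?"))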
66 changes: 53 additions & 13 deletions libs/langchain/tests/integration_tests/llms/test_vertexai.py
@@ -8,6 +8,7 @@
`gcloud auth login` first).
"""
import os
from typing import Optional

import pytest
from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
assert output.generations[0][0].text == async_output.generations[0][0].text


def test_model_garden() -> None:
"""In order to run this test, you should provide an endpoint name.
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.

Example:
export ENDPOINT_ID=...
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ["ENDPOINT_ID"]
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm("What is the meaning of life?")
assert isinstance(output, str)
assert llm._llm_type == "vertexai_model_garden"


def test_model_garden_generate() -> None:
"""In order to run this test, you should provide an endpoint name.
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden_generate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.

Example:
export ENDPOINT_ID=...
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ["ENDPOINT_ID"]
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2


async def test_model_garden_agenerate() -> None:
endpoint_id = os.environ["ENDPOINT_ID"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
async def test_model_garden_agenerate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
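For completeness, a hedged sketch of the async path these tests exercise, mirroring the call in test_model_garden_agenerate; the project and endpoint ID are placeholders, not values from this PR.

import asyncio

from langchain.llms.vertexai import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-gcp-project",      # placeholder project
    endpoint_id="0987654321",      # placeholder endpoint returning plain strings
    result_arg=None,
    location="europe-west4",
)

async def main() -> None:
    # agenerate now reuses the shared _prepare_request/_parse_response helpers
    # introduced in this PR, so its output shape matches the sync path.
    result = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
    for generations in result.generations:
        print(generations[0].text)

asyncio.run(main())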