From 3ca7bc0de7b134e1511fb7672f2b92c6bf4ef6a3 Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Mon, 27 Nov 2023 17:05:55 +0100
Subject: [PATCH] fixed VertexAIModelGarden class

---
 libs/langchain/langchain/llms/vertexai.py     | 109 +++++++++---------
 .../integration_tests/llms/test_vertexai.py   |  66 ++++++++---
 2 files changed, 110 insertions(+), 65 deletions(-)

diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index 4a3c67b7818a..994cfdf6bcf7 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -32,6 +32,8 @@
         PredictionServiceAsyncClient,
         PredictionServiceClient,
     )
+    from google.cloud.aiplatform.models import Prediction
+    from google.protobuf.struct_pb2 import Value
     from vertexai.language_models._language_models import (
         TextGenerationResponse,
         _LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     endpoint_id: str
     "A name of an endpoint where the model has been deployed."
     allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
+    "Allowed optional args to be passed to the model."
     prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
+    result_arg: Optional[str] = "generated_text"
+    "Set result_arg to None if output of the model is expected to be a string."
+    "Otherwise, if it's a dict, provide an argument that contains the result."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -386,7 +390,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         except ImportError:
             raise_vertex_import_error()
 
-        if values["project"] is None:
+        if not values["project"]:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
@@ -401,20 +405,18 @@ def validate_environment(cls, values: Dict) -> Dict:
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
+        values["endpoint_path"] = values["client"].endpoint_path(
+            project=values["project"],
+            location=values["location"],
+            endpoint=values["endpoint_id"],
+        )
         return values
 
     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"
 
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
+    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
         try:
             from google.protobuf import json_format
             from google.protobuf.struct_pb2 import Value
@@ -423,7 +425,6 @@ def _generate(
                 "protobuf package not found, please install it with"
                 " `pip install protobuf`"
             )
-
         instances = []
         for prompt in prompts:
             if self.allowed_model_args:
@@ -438,18 +439,53 @@ def _generate(
         predict_instances = [
             json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
         ]
+        return predict_instances
 
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        instances = self._prepare_request(prompts, **kwargs)
+        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+        return self._parse_response(response)
+
+    def _parse_response(self, predictions: "Prediction") -> LLMResult:
         generations: List[List[Generation]] = []
-        for result in response.predictions:
+        for result in predictions.predictions:
             generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
+                [
+                    Generation(text=self._parse_prediction(prediction))
+                    for prediction in result
+                ]
             )
         return LLMResult(generations=generations)
 
+    def _parse_prediction(self, prediction: Any) -> str:
+        if isinstance(prediction, str):
+            return prediction
+
+        if self.result_arg:
+            try:
+                return prediction[self.result_arg]
+            except KeyError:
+                if isinstance(prediction, str):
+                    error_desc = (
+                        "Provided non-None `result_arg` (result_arg="
+                        f"{self.result_arg}). But got prediction of type "
+                        f"{type(prediction)} instead of dict. Most probably, you "
+                        "need to set `result_arg=None` during VertexAIModelGarden "
+                        "initialization."
+                    )
+                    raise ValueError(error_desc)
+                else:
+                    raise ValueError(f"{self.result_arg} key not found in prediction!")
+
+        return prediction
+
     async def _agenerate(
         self,
         prompts: List[str],
@@ -458,39 +494,8 @@ async def _agenerate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
+        instances = self._prepare_request(prompts, **kwargs)
         response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
+            endpoint=self.endpoint_path, instances=instances
         )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
-            )
-        return LLMResult(generations=generations)
+        return self._parse_response(response)
diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
index ef9c8fb1b538..6ddb7044874f 100644
--- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py
+++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
@@ -8,6 +8,7 @@
 `gcloud auth login` first).
 """
 import os
+from typing import Optional
 
 import pytest
 from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text
 
 
-def test_model_garden() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm("What is the meaning of life?")
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai_model_garden"
 
 
-def test_model_garden_generate() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden_generate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
 
 
-async def test_model_garden_agenerate() -> None:
-    endpoint_id = os.environ["ENDPOINT_ID"]
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+async def test_model_garden_agenerate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
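
Below the patch, a minimal usage sketch (not part of the diff) of the two result-parsing modes this change enables: result_arg pointing at a key of a dict prediction, versus result_arg=None for endpoints whose predictions are plain strings. The PROJECT, FALCON_ENDPOINT_ID, and LLAMA_ENDPOINT_ID environment variables and the europe-west4 location mirror the integration tests above and are placeholders for your own deployments.

# Usage sketch only -- not part of the patch. Assumes the env variables used
# by the integration tests and Model Garden endpoints deployed in europe-west4.
import os

from langchain.llms.vertexai import VertexAIModelGarden

project = os.environ["PROJECT"]

# Endpoint whose predictions are dicts such as {"generated_text": "..."}:
# result_arg selects the key that holds the generated text.
falcon_llm = VertexAIModelGarden(
    endpoint_id=os.environ["FALCON_ENDPOINT_ID"],
    project=project,
    location="europe-west4",
    result_arg="generated_text",
)

# Endpoint whose predictions are plain strings: set result_arg=None so the
# prediction itself is returned as the generated text.
llama_llm = VertexAIModelGarden(
    endpoint_id=os.environ["LLAMA_ENDPOINT_ID"],
    project=project,
    location="europe-west4",
    result_arg=None,
)

print(falcon_llm("What is the meaning of life?"))
print(llama_llm.generate(["How much is 2+2?"]).generations[0][0].text)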