6 changes: 6 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -202,6 +202,12 @@ class _VertexAICommon(_VertexAIBase):
     "Underlying model name."
     temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
+    frequency_penalty: Optional[float] = None

Collaborator:
You need to extend _allow_model_args too; otherwise, args sent to invoke would be ignored.

Contributor Author:
Thanks. I added them to the gemma models -- that's the only place I see allowed_model_args used.

Collaborator:
I don't think Gemma supports these params.

I'm sorry I wasn't clear; that's the place that should be modified:

Contributor Author:
Sorry. I'm clearly missing something. frequency_penalty and presence_penalty are already there, no?

"Positive values penalize tokens that repeatedly appear in the generated text, "
"decreasing the probability of repeating content."
presence_penalty: Optional[float] = None
"Positive values penalize tokens that already appear in the generated text, "
"increasing the probability of generating more diverse content."
max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
"Token limit determines the maximum amount of text output from one prompt."
top_p: Optional[float] = None
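
For reference, a minimal usage sketch of the two new parameters, assuming the ChatVertexAI interface shown in this PR; the prompt and values are illustrative, and whether the per-call overrides are honored depends on the _allow_model_args handling discussed in the thread above:

from langchain_google_vertexai import ChatVertexAI

# Model-level defaults (values are illustrative).
llm = ChatVertexAI(
    model_name="gemini-pro",
    temperature=0.2,
    frequency_penalty=0.2,  # penalize tokens that keep recurring in the output
    presence_penalty=0.6,   # penalize tokens that have already appeared at all
)

# Per-call overrides, mirroring what the updated unit tests below exercise.
response = llm.invoke(
    "List three uses for a paperclip.",
    frequency_penalty=0.9,
    presence_penalty=0.8,
)
print(response.content)
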
16 changes: 14 additions & 2 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -971,15 +971,27 @@ class Multiply(BaseModel):
 
 
 def test_generation_config_gemini() -> None:
-    model = ChatVertexAI(model_name="gemini-pro", temperature=0.2, top_k=3)
+    model = ChatVertexAI(
+        model_name="gemini-pro",
+        temperature=0.2,
+        top_k=3,
+        frequency_penalty=0.2,
+        presence_penalty=0.6,
+    )
     generation_config = model._generation_config_gemini(
-        temperature=0.3, stop=["stop"], candidate_count=2
+        temperature=0.3,
+        stop=["stop"],
+        candidate_count=2,
+        frequency_penalty=0.9,
+        presence_penalty=0.8,
     )
     expected = GenerationConfig(
         stop_sequences=["stop"],
         temperature=0.3,
         top_k=3,
         candidate_count=2,
+        frequency_penalty=0.9,
+        presence_penalty=0.8,
     )
     assert generation_config == expected
 
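
The test above checks that call-time values (temperature=0.3, frequency_penalty=0.9, presence_penalty=0.8) take precedence over the constructor defaults (0.2, 0.2, 0.6) when the Gemini generation config is built. A simplified, hypothetical sketch of that merge behavior, not the actual _generation_config_gemini implementation:

from typing import Any, Dict


def merge_generation_config(defaults: Dict[str, Any], **call_kwargs: Any) -> Dict[str, Any]:
    # Start from the model-level defaults, dropping unset (None) values.
    config = {key: value for key, value in defaults.items() if value is not None}
    # Call-time kwargs win, as asserted in test_generation_config_gemini.
    config.update({key: value for key, value in call_kwargs.items() if value is not None})
    return config


merged = merge_generation_config(
    {"temperature": 0.2, "top_k": 3, "frequency_penalty": 0.2, "presence_penalty": 0.6},
    temperature=0.3,
    frequency_penalty=0.9,
    presence_penalty=0.8,
)
assert merged == {
    "temperature": 0.3,
    "top_k": 3,
    "frequency_penalty": 0.9,
    "presence_penalty": 0.8,
}
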
14 changes: 12 additions & 2 deletions libs/vertexai/tests/unit_tests/test_llm.py
@@ -62,6 +62,8 @@ def test_vertexai_args_passed() -> None:
         "temperature": 0,
         "top_k": 10,
         "top_p": 0.5,
+        "frequency_penalty": 0.2,
+        "presence_penalty": 0.3,
     }
 
     # Mock the library to ensure the args are passed correctly
@@ -78,7 +80,9 @@ def test_vertexai_args_passed() -> None:
         mock_prediction_service.return_value.generate_content = mock_generate_content
 
         llm = VertexAI(model_name="gemini-pro", **prompt_params)
-        response = llm.invoke(user_prompt, temperature=0.5)
+        response = llm.invoke(
+            user_prompt, temperature=0.5, frequency_penalty=0.5, presence_penalty=0.5
+        )
         assert response == response_text
         mock_generate_content.assert_called_once()
 
@@ -90,7 +94,13 @@ def test_vertexai_args_passed() -> None:
             == "Hello"
         )
         expected = GenerationConfig(
-            candidate_count=1, temperature=0.5, top_p=0.5, top_k=10, max_output_tokens=1
+            candidate_count=1,
+            temperature=0.5,
+            top_p=0.5,
+            top_k=10,
+            max_output_tokens=1,
+            frequency_penalty=0.5,
+            presence_penalty=0.5,
         )
         assert (
             mock_generate_content.call_args.kwargs["request"].generation_config