Skip to content

Commit

Permalink
adding max_tokens defaults back in
Browse files Browse the repository at this point in the history
  • Loading branch information
emjay07 committed Mar 19, 2024
1 parent 3ec97ae commit b9183de
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AnthropicImageQueryDriver(BaseImageQueryDriver):
),
kw_only=True,
)
max_output_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
max_output_tokens: Optional[int] = field(default=4096, kw_only=True, metadata={"serializable": True})

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
content = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver):
max_output_tokens: Max output tokens to return.
"""

max_output_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
max_output_tokens: Optional[int] = field(default=4096, kw_only=True, metadata={"serializable": True})

def construct_image_query_request_parameters(self, query: str, images: list[ImageArtifact]) -> dict:
content = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_try_query(self, mock_client):
expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string)

mock_client.return_value.messages.create.assert_called_once_with(
model=driver.model, messages=[expected_message]
model=driver.model, max_tokens=4096, messages=[expected_message]
)

assert text_artifact.value == "['content-block']"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def test_construct_image_query_request_parameters(self):
assert params["anthropic_version"] == "bedrock-2023-05-31"
assert "messages" in params
assert len(params["messages"]) == 1
assert "max_tokens" not in params
assert "max_tokens" in params
assert params["max_tokens"] == 4096

def test_construct_image_query_request_parameters_max_tokens(self):
model_driver = BedrockClaudeImageQueryModelDriver(max_output_tokens=1024)
Expand Down

0 comments on commit b9183de

Please sign in to comment.