Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion libs/aws/langchain_aws/retrievers/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union

import langchain_core
from botocore.client import Config
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
Expand Down Expand Up @@ -147,7 +148,10 @@ def create_client(cls, values: Dict[str, Any]) -> Any:
endpoint_url=values.get("endpoint_url"),
config=values.get("config")
or Config(
connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
connect_timeout=120,
read_timeout=120,
retries={"max_attempts": 0},
user_agent_extra=f"langchain/{langchain_core.__version__}",
),
service_name="bedrock-agent-runtime",
)
Expand Down
45 changes: 44 additions & 1 deletion libs/aws/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore
from typing import Any, List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.documents import Document
Expand Down Expand Up @@ -649,3 +649,46 @@ def test_guardrail_config_with_retrieval_config(mock_client, mock_retriever_conf
}
},
)


@patch("langchain_aws.retrievers.bedrock.create_aws_client")
def test_user_agent_extra_in_config(mock_create_client):
"""Test that user_agent_extra is properly set in the Config when creating client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client

AmazonKnowledgeBasesRetriever(
knowledge_base_id="test_kb_id",
)

# Verify create_aws_client was called
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args

# Check that config parameter contains user_agent_extra
config = call_args.kwargs["config"]
assert hasattr(config, "user_agent_extra")
assert config.user_agent_extra.startswith("langchain/")


@patch("langchain_aws.retrievers.bedrock.create_aws_client")
def test_custom_config_preserves_user_agent_extra(mock_create_client):
"""Test that custom config doesn't override user_agent_extra."""
from botocore.client import Config

mock_client = MagicMock()
mock_create_client.return_value = mock_client

custom_config = Config(region_name="us-west-2")

AmazonKnowledgeBasesRetriever(
knowledge_base_id="test_kb_id",
config=custom_config,
)

# Verify create_aws_client was called with the custom config
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args

# The custom config should be passed through
assert call_args.kwargs["config"] == custom_config