diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 4195867de..7d60a7410 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,6 +1,7 @@ import json from typing import Any, Dict, List, Literal, Optional, Union +import langchain_core from botocore.client import Config from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document @@ -147,7 +148,10 @@ def create_client(cls, values: Dict[str, Any]) -> Any: endpoint_url=values.get("endpoint_url"), config=values.get("config") or Config( - connect_timeout=120, read_timeout=120, retries={"max_attempts": 0} + connect_timeout=120, + read_timeout=120, + retries={"max_attempts": 0}, + user_agent_extra=f"langchain/{langchain_core.__version__}", ), service_name="bedrock-agent-runtime", ) diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py index e3baf108a..d12e7fd9e 100644 --- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -1,6 +1,6 @@ # type: ignore from typing import Any, List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from langchain_core.documents import Document @@ -649,3 +649,46 @@ def test_guardrail_config_with_retrieval_config(mock_client, mock_retriever_conf } }, ) + + +@patch("langchain_aws.retrievers.bedrock.create_aws_client") +def test_user_agent_extra_in_config(mock_create_client): + """Test that user_agent_extra is properly set in the Config when creating client.""" + mock_client = MagicMock() + mock_create_client.return_value = mock_client + + AmazonKnowledgeBasesRetriever( + knowledge_base_id="test_kb_id", + ) + + # Verify create_aws_client was called + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + + # Check that config parameter contains user_agent_extra + config = call_args.kwargs["config"] + assert hasattr(config, "user_agent_extra") + assert config.user_agent_extra.startswith("langchain/") + + +@patch("langchain_aws.retrievers.bedrock.create_aws_client") +def test_custom_config_preserves_user_agent_extra(mock_create_client): + """Test that custom config doesn't override user_agent_extra.""" + from botocore.client import Config + + mock_client = MagicMock() + mock_create_client.return_value = mock_client + + custom_config = Config(region_name="us-west-2") + + AmazonKnowledgeBasesRetriever( + knowledge_base_id="test_kb_id", + config=custom_config, + ) + + # Verify create_aws_client was called with the custom config + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + + # The custom config should be passed through + assert call_args.kwargs["config"] == custom_config