Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
50 changes: 27 additions & 23 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
Expand All @@ -22,6 +17,7 @@
from .auth_schemes import AuthSchemeType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
from .credential_manager import CredentialManager
from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger

if TYPE_CHECKING:
Expand All @@ -48,9 +44,13 @@ async def exchange_auth_token(
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
return await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)

# Restore secret if needed
credential = self.auth_config.exchanged_auth_credential

with CredentialManager.restore_client_secret(credential):
res = await exchanger.exchange(credential, self.auth_config.auth_scheme)
return res

async def parse_and_store_auth_response(self, state: State) -> None:

Expand Down Expand Up @@ -182,21 +182,25 @@ def generate_auth_uri(
)
scopes = list(scopes.keys())

client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, **params
)
client_id = auth_credential.oauth2.client_id

with CredentialManager.restore_client_secret(auth_credential):
client_secret = auth_credential.oauth2.client_secret
client = OAuth2Session(
client_id,
client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, **params
)

exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
Expand Down
11 changes: 9 additions & 2 deletions src/google/adk/auth/auth_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,16 @@ def get_credential_key(self):
)

auth_credential = self.raw_auth_credential
if auth_credential and auth_credential.model_extra:
if auth_credential and (
auth_credential.model_extra or auth_credential.oauth2
):
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
if auth_credential.model_extra:
auth_credential.model_extra.clear()
# Normalize secret to ensure stable key regardless of redaction
if auth_credential.oauth2:
auth_credential.oauth2.client_secret = None

credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
Expand Down
94 changes: 89 additions & 5 deletions src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from __future__ import annotations

import contextlib
import logging
from typing import Optional

from fastapi.openapi.models import OAuth2

from ..agents.callback_context import CallbackContext
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
Expand Down Expand Up @@ -75,11 +75,23 @@ class CredentialManager:
```
"""

# A map to store client secrets in memory. Key is client_id, value is client_secret
_CLIENT_SECRETS: dict[str, str] = {}

def __init__(
self,
auth_config: AuthConfig,
):
self._auth_config = auth_config
# We deep copy the auth_config to avoid modifying the original object passed
# by the user. This allows for safe redaction of sensitive information without
# causing side effects.

self._auth_config = auth_config.model_copy(deep=True)

# Secure the client secret
self._secure_client_secret(self._auth_config.raw_auth_credential)
self._secure_client_secret(self._auth_config.exchanged_auth_credential)

self._exchanger_registry = CredentialExchangerRegistry()
self._refresher_registry = CredentialRefresherRegistry()
self._discovery_manager = OAuth2DiscoveryManager()
Expand All @@ -97,6 +109,8 @@ def __init__(
)

# TODO: Move ServiceAccountCredentialExchanger to the auth module
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger

self._exchanger_registry.register(
AuthCredentialTypes.SERVICE_ACCOUNT,
ServiceAccountCredentialExchanger(),
Expand All @@ -110,6 +124,36 @@ def __init__(
AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher
)

def _secure_client_secret(self, credential: Optional[AuthCredential]):
"""Extracts client secret to memory and redacts it from the credential."""
if (
credential
and credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_secret
and credential.oauth2.client_secret != "<redacted>"
):
logger.info(
f"Securing client secret for client_id: {credential.oauth2.client_id}"
)
# Store in memory map
CredentialManager._CLIENT_SECRETS[credential.oauth2.client_id] = (
credential.oauth2.client_secret
)
# Redact from config
credential.oauth2.client_secret = "<redacted>"
else:
if credential and credential.oauth2:
logger.debug(
f"Not securing secret for client_id {credential.oauth2.client_id}:"
f" secret is {credential.oauth2.client_secret}"
)

@staticmethod
def get_client_secret(client_id: str) -> Optional[str]:
"""Retrieves the client secret for a given client_id."""
return CredentialManager._CLIENT_SECRETS.get(client_id)

def register_credential_exchanger(
self,
credential_type: AuthCredentialTypes,
Expand All @@ -124,6 +168,9 @@ def register_credential_exchanger(
self._exchanger_registry.register(credential_type, exchanger_instance)

async def request_credential(self, callback_context: CallbackContext) -> None:
# We send the auth_config (which is already redacted in __init__) to the client
# Note: we need to ensure we don't send any stale exchanged credentials if they are not valid
# But usually CredentialManager manages that.
callback_context.request_credential(self._auth_config)

async def get_auth_credential(
Expand Down Expand Up @@ -205,6 +252,40 @@ async def _load_from_auth_response(
"""Load credential from auth response in callback context."""
return callback_context.get_auth_response(self._auth_config)

@staticmethod
@contextlib.contextmanager
def restore_client_secret(credential: AuthCredential, secret: str = None):
"""Context manager to temporarily restore client secret in a credential.

Args:
credential: The credential to restore secret for.
secret: Optional secret to use. If not provided, looks up by client_id.
"""
if not credential or not credential.oauth2:
yield
return

restored = False
if secret:
credential.oauth2.client_secret = secret
restored = True
elif (
credential.oauth2.client_id
and credential.oauth2.client_secret == "<redacted>"
):
stored_secret = CredentialManager.get_client_secret(
credential.oauth2.client_id
)
if stored_secret:
credential.oauth2.client_secret = stored_secret
restored = True

try:
yield
finally:
if restored:
credential.oauth2.client_secret = "<redacted>"

async def _exchange_credential(
self, credential: AuthCredential
) -> tuple[AuthCredential, bool]:
Expand All @@ -213,14 +294,17 @@ async def _exchange_credential(
if not exchanger:
return credential, False

from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger

if isinstance(exchanger, ServiceAccountCredentialExchanger):
exchanged_credential = exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
)
else:
exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)
with self.restore_client_secret(credential):
exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)

return exchanged_credential, True

Expand Down
Loading