Skip to content

Commit 0387bb9

Browse files
authored
feat: MDS connections use mTLS (#1856)
Use mTLS/HTTPS when connecting to MDS **Feature Gating** The `GCE_METADATA_MTLS_MODE` environment variable is introduced, which can be set to strict, none, or default. The `should_use_mds_mtls` function determines whether to use mTLS based on the environment variable and the existence of the certificate files in well-known location ((https://docs.cloud.google.com/compute/docs/metadata/overview#https-mds-certificates). **Description of changes** A custom `MdsMtlsAdapter` is implemented to handle the SSL context for mTLS. MdsMtlsAdapter loads MDS mTLS certificates from well-known location. MdsMtlsAdapter is mounted into the provided request.Session. **Behavior** If mode == none: Continue to use HTTP. If mode == default: Use HTTPS if certificates exist. If HTTPS/mTLS fails, falls back to HTTP. If mode == strict: Use HTTPS always, even if certificates don't exist (will result in error). **Integrating with existing code** compute_engine/_metadata.py: - The metadata server URL construction is now dynamic, supporting both http and https schemes based on whether mTLS is enabled. - ping and get functions are updated to use mTLS when it's enabled.
1 parent 5b96011 commit 0387bb9

File tree

5 files changed

+749
-36
lines changed

5 files changed

+749
-36
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,72 @@
2424
import os
2525
from urllib.parse import urljoin
2626

27+
import requests
28+
2729
from google.auth import _helpers
2830
from google.auth import environment_vars
2931
from google.auth import exceptions
3032
from google.auth import metrics
3133
from google.auth import transport
3234
from google.auth._exponential_backoff import ExponentialBackoff
35+
from google.auth.compute_engine import _mtls
36+
3337

3438
_LOGGER = logging.getLogger(__name__)
3539

40+
_GCE_DEFAULT_MDS_IP = "169.254.169.254"
41+
_GCE_DEFAULT_HOST = "metadata.google.internal"
42+
_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP]
43+
3644
# Environment variable GCE_METADATA_HOST is originally named
3745
# GCE_METADATA_ROOT. For compatibility reasons, here it checks
3846
# the new variable first; if not set, the system falls back
3947
# to the old variable.
4048
_GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None)
4149
if not _GCE_METADATA_HOST:
4250
_GCE_METADATA_HOST = os.getenv(
43-
environment_vars.GCE_METADATA_ROOT, "metadata.google.internal"
51+
environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST
52+
)
53+
54+
55+
def _validate_gce_mds_configured_environment():
56+
"""Validates the GCE metadata server environment configuration for mTLS.
57+
58+
mTLS is only supported when connecting to the default metadata server hosts.
59+
If we are in strict mode (which requires mTLS), ensure that the metadata host
60+
has not been overridden to a custom value (which means mTLS will fail).
61+
62+
Raises:
63+
google.auth.exceptions.MutualTLSChannelError: if the environment
64+
configuration is invalid for mTLS.
65+
"""
66+
mode = _mtls._parse_mds_mode()
67+
if mode == _mtls.MdsMtlsMode.STRICT:
68+
# mTLS is only supported when connecting to the default metadata host.
69+
# Raise an exception if we are in strict mode (which requires mTLS)
70+
# but the metadata host has been overridden to a custom MDS. (which means mTLS will fail)
71+
if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS:
72+
raise exceptions.MutualTLSChannelError(
73+
"Mutual TLS is required, but the metadata host has been overridden. "
74+
"mTLS is only supported when connecting to the default metadata host."
75+
)
76+
77+
78+
def _get_metadata_root(use_mtls: bool):
79+
"""Returns the metadata server root URL."""
80+
81+
scheme = "https" if use_mtls else "http"
82+
return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST)
83+
84+
85+
def _get_metadata_ip_root(use_mtls: bool):
86+
"""Returns the metadata server IP root URL."""
87+
scheme = "https" if use_mtls else "http"
88+
return "{}://{}".format(
89+
scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP)
4490
)
45-
_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST)
4691

47-
# This is used to ping the metadata server, it avoids the cost of a DNS
48-
# lookup.
49-
_METADATA_IP_ROOT = "http://{}".format(
50-
os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254")
51-
)
92+
5293
_METADATA_FLAVOR_HEADER = "metadata-flavor"
5394
_METADATA_FLAVOR_VALUE = "Google"
5495
_METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE}
@@ -102,6 +143,33 @@ def detect_gce_residency_linux():
102143
return content.startswith(_GOOGLE)
103144

104145

146+
def _prepare_request_for_mds(request, use_mtls=False) -> None:
147+
"""Prepares a request for the metadata server.
148+
149+
This will check if mTLS should be used and mount the mTLS adapter if needed.
150+
151+
Args:
152+
request (google.auth.transport.Request): A callable used to make
153+
HTTP requests.
154+
use_mtls (bool): Whether to use mTLS for the request.
155+
156+
Returns:
157+
google.auth.transport.Request: A request object to use.
158+
If mTLS is enabled, the request will have the mTLS adapter mounted.
159+
Otherwise, the original request will be returned unchanged.
160+
"""
161+
# Only modify the request if mTLS is enabled.
162+
if use_mtls:
163+
# Ensure the request has a session to mount the adapter to.
164+
if not request.session:
165+
request.session = requests.Session()
166+
167+
adapter = _mtls.MdsMtlsAdapter()
168+
# Mount the adapter for all default GCE metadata hosts.
169+
for host in _GCE_DEFAULT_MDS_HOSTS:
170+
request.session.mount(f"https://{host}/", adapter)
171+
172+
105173
def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
106174
"""Checks to see if the metadata server is available.
107175
@@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
115183
Returns:
116184
bool: True if the metadata server is reachable, False otherwise.
117185
"""
186+
use_mtls = _mtls.should_use_mds_mtls()
187+
_prepare_request_for_mds(request, use_mtls=use_mtls)
118188
# NOTE: The explicit ``timeout`` is a workaround. The underlying
119189
# issue is that resolving an unknown host on some networks will take
120190
# 20-30 seconds; making this timeout short fixes the issue, but
@@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
129199
for attempt in backoff:
130200
try:
131201
response = request(
132-
url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout
202+
url=_get_metadata_ip_root(use_mtls),
203+
method="GET",
204+
headers=headers,
205+
timeout=timeout,
133206
)
134207

135208
metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER)
@@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
153226
def get(
154227
request,
155228
path,
156-
root=_METADATA_ROOT,
229+
root=None,
157230
params=None,
158231
recursive=False,
159232
retry_count=5,
@@ -168,7 +241,8 @@ def get(
168241
HTTP requests.
169242
path (str): The resource to retrieve. For example,
170243
``'instance/service-accounts/default'``.
171-
root (str): The full path to the metadata server root.
244+
root (Optional[str]): The full path to the metadata server root. If not
245+
provided, the default root will be used.
172246
params (Optional[Mapping[str, str]]): A mapping of query parameter
173247
keys to values.
174248
recursive (bool): Whether to do a recursive query of metadata. See
@@ -189,7 +263,24 @@ def get(
189263
Raises:
190264
google.auth.exceptions.TransportError: if an error occurred while
191265
retrieving metadata.
266+
google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment
267+
configuration is invalid for mTLS (for example, the metadata host
268+
has been overridden in strict mTLS mode).
269+
192270
"""
271+
use_mtls = _mtls.should_use_mds_mtls()
272+
# Prepare the request object for mTLS if needed.
273+
# This will create a new request object with the mTLS session.
274+
_prepare_request_for_mds(request, use_mtls=use_mtls)
275+
276+
if root is None:
277+
root = _get_metadata_root(use_mtls)
278+
279+
# mTLS is only supported when connecting to the default metadata host.
280+
# If we are in strict mode (which requires mTLS), ensure that the metadata host
281+
# has not been overridden to a non-default host value (which means mTLS will fail).
282+
_validate_gce_mds_configured_environment()
283+
193284
base_url = urljoin(root, path)
194285
query_params = {} if params is None else params
195286

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
"""Mutual TLS for Google Compute Engine metadata server."""
18+
19+
from dataclasses import dataclass, field
20+
import enum
21+
import logging
22+
import os
23+
from pathlib import Path
24+
import ssl
25+
from urllib.parse import urlparse, urlunparse
26+
27+
import requests
28+
from requests.adapters import HTTPAdapter
29+
30+
from google.auth import environment_vars, exceptions
31+
32+
33+
_LOGGER = logging.getLogger(__name__)
34+
35+
_WINDOWS_OS_NAME = "nt"
36+
37+
# MDS mTLS certificate paths based on OS.
38+
# Documentation to well known locations can be found at:
39+
# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates
40+
_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine")
41+
_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls")
42+
43+
44+
def _get_mds_root_crt_path():
45+
if os.name == _WINDOWS_OS_NAME:
46+
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"
47+
else:
48+
return _MTLS_COMPONENTS_BASE_PATH / "root.crt"
49+
50+
51+
def _get_mds_client_combined_cert_path():
52+
if os.name == _WINDOWS_OS_NAME:
53+
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key"
54+
else:
55+
return _MTLS_COMPONENTS_BASE_PATH / "client.key"
56+
57+
58+
@dataclass
59+
class MdsMtlsConfig:
60+
ca_cert_path: Path = field(
61+
default_factory=_get_mds_root_crt_path
62+
) # path to CA certificate
63+
client_combined_cert_path: Path = field(
64+
default_factory=_get_mds_client_combined_cert_path
65+
) # path to file containing client certificate and key
66+
67+
68+
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
69+
"""Checks if the mTLS certificates exist."""
70+
return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists(
71+
mds_mtls_config.client_combined_cert_path
72+
)
73+
74+
75+
class MdsMtlsMode(enum.Enum):
76+
"""MDS mTLS mode. Used to configure connection behavior when connecting to MDS.
77+
78+
STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned.
79+
NONE: Never use mTLS. Requests will use regular HTTP.
80+
DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP.
81+
"""
82+
83+
STRICT = "strict"
84+
NONE = "none"
85+
DEFAULT = "default"
86+
87+
88+
def _parse_mds_mode():
89+
"""Parses the GCE_METADATA_MTLS_MODE environment variable."""
90+
mode_str = os.environ.get(
91+
environment_vars.GCE_METADATA_MTLS_MODE, "default"
92+
).lower()
93+
try:
94+
return MdsMtlsMode(mode_str)
95+
except ValueError:
96+
raise ValueError(
97+
"Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
98+
)
99+
100+
101+
def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
102+
"""Determines if mTLS should be used for the metadata server."""
103+
mode = _parse_mds_mode()
104+
if mode == MdsMtlsMode.STRICT:
105+
if not _certs_exist(mds_mtls_config):
106+
raise exceptions.MutualTLSChannelError(
107+
"mTLS certificates not found in strict mode."
108+
)
109+
return True
110+
elif mode == MdsMtlsMode.NONE:
111+
return False
112+
else: # Default mode
113+
return _certs_exist(mds_mtls_config)
114+
115+
116+
class MdsMtlsAdapter(HTTPAdapter):
117+
"""An HTTP adapter that uses mTLS for the metadata server."""
118+
119+
def __init__(
120+
self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs
121+
):
122+
self.ssl_context = ssl.create_default_context()
123+
self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
124+
self.ssl_context.load_cert_chain(
125+
certfile=mds_mtls_config.client_combined_cert_path
126+
)
127+
super(MdsMtlsAdapter, self).__init__(*args, **kwargs)
128+
129+
def init_poolmanager(self, *args, **kwargs):
130+
kwargs["ssl_context"] = self.ssl_context
131+
return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs)
132+
133+
def proxy_manager_for(self, *args, **kwargs):
134+
kwargs["ssl_context"] = self.ssl_context
135+
return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)
136+
137+
def send(self, request, **kwargs):
138+
# If we are in strict mode, always use mTLS (no HTTP fallback)
139+
if _parse_mds_mode() == MdsMtlsMode.STRICT:
140+
return super(MdsMtlsAdapter, self).send(request, **kwargs)
141+
142+
# In default mode, attempt mTLS first, then fallback to HTTP on failure
143+
try:
144+
response = super(MdsMtlsAdapter, self).send(request, **kwargs)
145+
response.raise_for_status()
146+
return response
147+
except (
148+
ssl.SSLError,
149+
requests.exceptions.SSLError,
150+
requests.exceptions.HTTPError,
151+
) as e:
152+
_LOGGER.warning(
153+
"mTLS connection to Compute Engine Metadata server failed. "
154+
"Falling back to standard HTTP. Reason: %s",
155+
e,
156+
)
157+
# Fallback to standard HTTP
158+
parsed_original_url = urlparse(request.url)
159+
http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http"))
160+
request.url = http_fallback_url
161+
162+
# Use a standard HTTPAdapter for the fallback
163+
http_adapter = HTTPAdapter()
164+
return http_adapter.send(request, **kwargs)

google/auth/environment_vars.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@
6060
"""Environment variable providing an alternate ip:port to be used for ip-only
6161
GCE metadata requests."""
6262

63+
GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE"
64+
"""Environment variable controlling the mTLS behavior for GCE metadata requests.
65+
66+
Can be one of "strict", "none", or "default".
67+
"""
68+
6369
GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
6470
"""Environment variable controlling whether to use client certificate or not.
6571

0 commit comments

Comments
 (0)