Skip to content

Commit

Permalink
feat: fetch id token from GCE metadata server (#462)
Browse files Browse the repository at this point in the history
feat: fetch id token from GCE metadata server
  • Loading branch information
arithmetic1728 committed Mar 23, 2020
1 parent 8374e21 commit 97e7700
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 35 deletions.
163 changes: 128 additions & 35 deletions google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,24 @@ class IDTokenCredentials(credentials.Credentials, credentials.Signing):
These credentials relies on the default service account of a GCE instance.
In order for this to work, the GCE instance must have been started with
ID token can be requested from `GCE metadata server identity endpoint`_, IAM
token endpoint or other token endpoints you specify. If metadata server
identity endpoint is not used, the GCE instance must have been started with
a service account that has access to the IAM Cloud API.
.. _GCE metadata server identity endpoint:
https://cloud.google.com/compute/docs/instances/verifying-instance-identity
"""

def __init__(
self,
request,
target_audience,
token_uri=_DEFAULT_TOKEN_URI,
token_uri=None,
additional_claims=None,
service_account_email=None,
signer=None,
use_metadata_identity_endpoint=False,
):
"""
Args:
Expand All @@ -154,29 +160,54 @@ def __init__(
signer (google.auth.crypt.Signer): The signer used to sign JWTs.
In case the signer is specified, the request argument will be
ignored.
use_metadata_identity_endpoint (bool): Whether to use GCE metadata
identity endpoint. For backward compatibility the default value
is False. If set to True, ``token_uri``, ``additional_claims``,
``service_account_email``, ``signer`` argument should not be set;
otherwise ValueError will be raised.
Raises:
ValueError:
If ``use_metadata_identity_endpoint`` is set to True, and one of
``token_uri``, ``additional_claims``, ``service_account_email``,
``signer`` arguments is set.
"""
super(IDTokenCredentials, self).__init__()

if service_account_email is None:
sa_info = _metadata.get_service_account_info(request)
service_account_email = sa_info["email"]
self._service_account_email = service_account_email

if signer is None:
signer = iam.Signer(
request=request,
credentials=Credentials(),
service_account_email=service_account_email,
)
self._signer = signer

self._token_uri = token_uri
self._use_metadata_identity_endpoint = use_metadata_identity_endpoint
self._target_audience = target_audience

if additional_claims is not None:
self._additional_claims = additional_claims
if use_metadata_identity_endpoint:
if token_uri or additional_claims or service_account_email or signer:
raise ValueError(
"If use_metadata_identity_endpoint is set, token_uri, "
"additional_claims, service_account_email, signer arguments"
" must not be set"
)
self._token_uri = None
self._additional_claims = None
self._signer = None

if service_account_email is None:
sa_info = _metadata.get_service_account_info(request)
self._service_account_email = sa_info["email"]
else:
self._additional_claims = {}
self._service_account_email = service_account_email

if not use_metadata_identity_endpoint:
if signer is None:
signer = iam.Signer(
request=request,
credentials=Credentials(),
service_account_email=self._service_account_email,
)
self._signer = signer
self._token_uri = token_uri or _DEFAULT_TOKEN_URI

if additional_claims is not None:
self._additional_claims = additional_claims
else:
self._additional_claims = {}

def with_target_audience(self, target_audience):
"""Create a copy of these credentials with the specified target
Expand All @@ -190,14 +221,22 @@ def with_target_audience(self, target_audience):
"""
# since the signer is already instantiated,
# the request is not needed
return self.__class__(
None,
service_account_email=self._service_account_email,
token_uri=self._token_uri,
target_audience=target_audience,
additional_claims=self._additional_claims.copy(),
signer=self.signer,
)
if self._use_metadata_identity_endpoint:
return self.__class__(
None,
target_audience=target_audience,
use_metadata_identity_endpoint=True,
)
else:
return self.__class__(
None,
service_account_email=self._service_account_email,
token_uri=self._token_uri,
target_audience=target_audience,
additional_claims=self._additional_claims.copy(),
signer=self.signer,
use_metadata_identity_endpoint=False,
)

def _make_authorization_grant_assertion(self):
"""Create the OAuth 2.0 assertion.
Expand Down Expand Up @@ -228,22 +267,76 @@ def _make_authorization_grant_assertion(self):

return token

@_helpers.copy_docstring(credentials.Credentials)
def _call_metadata_identity_endpoint(self, request):
"""Request ID token from metadata identity endpoint.
Args:
request (google.auth.transport.Request): The object used to make
HTTP requests.
Raises:
google.auth.exceptions.RefreshError: If the Compute Engine metadata
service can't be reached or if the instance has no credentials.
ValueError: If extracting expiry from the obtained ID token fails.
"""
try:
id_token = _metadata.get(
request,
"instance/service-accounts/default/identity?audience={}&format=full".format(
self._target_audience
),
)
except exceptions.TransportError as caught_exc:
new_exc = exceptions.RefreshError(caught_exc)
six.raise_from(new_exc, caught_exc)

_, payload, _, _ = jwt._unverified_decode(id_token)
return id_token, payload["exp"]

def refresh(self, request):
assertion = self._make_authorization_grant_assertion()
access_token, expiry, _ = _client.id_token_jwt_grant(
request, self._token_uri, assertion
)
self.token = access_token
self.expiry = expiry
"""Refreshes the ID token.
Args:
request (google.auth.transport.Request): The object used to make
HTTP requests.
Raises:
google.auth.exceptions.RefreshError: If the credentials could
not be refreshed.
ValueError: If extracting expiry from the obtained ID token fails.
"""
if self._use_metadata_identity_endpoint:
self.token, self.expiry = self._call_metadata_identity_endpoint(request)
else:
assertion = self._make_authorization_grant_assertion()
access_token, expiry, _ = _client.id_token_jwt_grant(
request, self._token_uri, assertion
)
self.token = access_token
self.expiry = expiry

@property
@_helpers.copy_docstring(credentials.Signing)
def signer(self):
return self._signer

@_helpers.copy_docstring(credentials.Signing)
def sign_bytes(self, message):
"""Signs the given message.
Args:
message (bytes): The message to sign.
Returns:
bytes: The message's cryptographic signature.
Raises:
ValueError:
Signer is not available if metadata identity endpoint is used.
"""
if self._use_metadata_identity_endpoint:
raise ValueError(
"Signer is not available if metadata identity endpoint is used"
)
return self._signer.sign(message)

@property
Expand Down
12 changes: 12 additions & 0 deletions system_tests/test_compute_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.auth import compute_engine
from google.auth import _helpers
from google.auth import exceptions
from google.auth import jwt
from google.auth.compute_engine import _metadata


Expand Down Expand Up @@ -48,3 +49,14 @@ def test_default(verify_refresh):
assert project_id is not None
assert isinstance(credentials, compute_engine.Credentials)
verify_refresh(credentials)


def test_id_token_from_metadata(http_request):
credentials = compute_engine.IDTokenCredentials(
http_request, "target_audience", use_metadata_identity_endpoint=True
)
credentials.refresh(http_request)

_, payload, _, _ = jwt._unverified_decode(credentials.token)
assert payload["aud"] == "target_audience"
assert payload["exp"] == credentials.expiry
139 changes: 139 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@
from google.auth.compute_engine import credentials
from google.auth.transport import requests

SAMPLE_ID_TOKEN_EXP = 1584393400

# header: {"alg": "RS256", "typ": "JWT", "kid": "1"}
# payload: {"iss": "issuer", "iat": 1584393348, "sub": "subject",
# "exp": 1584393400,"aud": "audience"}
SAMPLE_ID_TOKEN = (
b"eyJhbGciOiAiUlMyNTYiLCAidHlwIjogIkpXVCIsICJraWQiOiAiMSJ9."
b"eyJpc3MiOiAiaXNzdWVyIiwgImlhdCI6IDE1ODQzOTMzNDgsICJzdWIiO"
b"iAic3ViamVjdCIsICJleHAiOiAxNTg0MzkzNDAwLCAiYXVkIjogImF1ZG"
b"llbmNlIn0."
b"OquNjHKhTmlgCk361omRo18F_uY-7y0f_AmLbzW062Q1Zr61HAwHYP5FM"
b"316CK4_0cH8MUNGASsvZc3VqXAqub6PUTfhemH8pFEwBdAdG0LhrNkU0H"
b"WN1YpT55IiQ31esLdL5q-qDsOPpNZJUti1y1lAreM5nIn2srdWzGXGs4i"
b"TRQsn0XkNUCL4RErpciXmjfhMrPkcAjKA-mXQm2fa4jmTlEZFqFmUlym1"
b"ozJ0yf5grjN6AslN4OGvAv1pS-_Ko_pGBS6IQtSBC6vVKCUuBfaqNjykg"
b"bsxbLa6Fp0SYeYwO8ifEnkRvasVpc1WTQqfRB2JCj5pTBDzJpIpFCMmnQ"
)


class TestCredentials(object):
credentials = None
Expand Down Expand Up @@ -238,6 +256,26 @@ def test_additional_claims(self, sign, get, utcnow):
"foo": "bar",
}

def test_token_uri(self):
request = mock.create_autospec(transport.Request, instance=True)

self.credentials = credentials.IDTokenCredentials(
request=request,
signer=mock.Mock(),
service_account_email="foo@example.com",
target_audience="https://audience.com",
)
assert self.credentials._token_uri == credentials._DEFAULT_TOKEN_URI

self.credentials = credentials.IDTokenCredentials(
request=request,
signer=mock.Mock(),
service_account_email="foo@example.com",
target_audience="https://audience.com",
token_uri="https://example.com/token",
)
assert self.credentials._token_uri == "https://example.com/token"

@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.utcfromtimestamp(0),
Expand Down Expand Up @@ -469,3 +507,104 @@ def test_sign_bytes(self, sign, get):

# The JWT token signature is 'signature' encoded in base 64:
assert signature == b"signature"

@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
@mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
def test_get_id_token_from_metadata(self, get, get_service_account_info):
get.return_value = SAMPLE_ID_TOKEN
get_service_account_info.return_value = {"email": "foo@example.com"}

cred = credentials.IDTokenCredentials(
mock.Mock(), "audience", use_metadata_identity_endpoint=True
)
cred.refresh(request=mock.Mock())

assert cred.token == SAMPLE_ID_TOKEN
assert cred.expiry == SAMPLE_ID_TOKEN_EXP
assert cred._use_metadata_identity_endpoint
assert cred._signer is None
assert cred._token_uri is None
assert cred._service_account_email == "foo@example.com"
assert cred._target_audience == "audience"
with pytest.raises(ValueError):
cred.sign_bytes(b"bytes")

@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
def test_with_target_audience_for_metadata(self, get_service_account_info):
get_service_account_info.return_value = {"email": "foo@example.com"}

cred = credentials.IDTokenCredentials(
mock.Mock(), "audience", use_metadata_identity_endpoint=True
)
cred = cred.with_target_audience("new_audience")

assert cred._target_audience == "new_audience"
assert cred._use_metadata_identity_endpoint
assert cred._signer is None
assert cred._token_uri is None
assert cred._service_account_email == "foo@example.com"

@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
@mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
def test_invalid_id_token_from_metadata(self, get, get_service_account_info):
get.return_value = "invalid_id_token"
get_service_account_info.return_value = {"email": "foo@example.com"}

cred = credentials.IDTokenCredentials(
mock.Mock(), "audience", use_metadata_identity_endpoint=True
)

with pytest.raises(ValueError):
cred.refresh(request=mock.Mock())

@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
@mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
def test_transport_error_from_metadata(self, get, get_service_account_info):
get.side_effect = exceptions.TransportError("transport error")
get_service_account_info.return_value = {"email": "foo@example.com"}

cred = credentials.IDTokenCredentials(
mock.Mock(), "audience", use_metadata_identity_endpoint=True
)

with pytest.raises(exceptions.RefreshError) as excinfo:
cred.refresh(request=mock.Mock())
assert excinfo.match(r"transport error")

def test_get_id_token_from_metadata_constructor(self):
with pytest.raises(ValueError):
credentials.IDTokenCredentials(
mock.Mock(),
"audience",
use_metadata_identity_endpoint=True,
token_uri="token_uri",
)
with pytest.raises(ValueError):
credentials.IDTokenCredentials(
mock.Mock(),
"audience",
use_metadata_identity_endpoint=True,
signer=mock.Mock(),
)
with pytest.raises(ValueError):
credentials.IDTokenCredentials(
mock.Mock(),
"audience",
use_metadata_identity_endpoint=True,
additional_claims={"key", "value"},
)
with pytest.raises(ValueError):
credentials.IDTokenCredentials(
mock.Mock(),
"audience",
use_metadata_identity_endpoint=True,
service_account_email="foo@example.com",
)

0 comments on commit 97e7700

Please sign in to comment.