Skip to content

Commit

Permalink
Remove one-time token behavior of JWT Credentials (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jon Wayne Parrott committed Feb 23, 2017
1 parent 254befe commit ab08689
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 108 deletions.
95 changes: 19 additions & 76 deletions google/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
import datetime
import json

from six.moves import urllib

from google.auth import _helpers
from google.auth import _service_account_info
from google.auth import credentials
Expand Down Expand Up @@ -246,11 +244,7 @@ class Credentials(credentials.Signing,
"""Credentials that use a JWT as the bearer token.
These credentials require an "audience" claim. This claim identifies the
intended recipient of the bearer token. You can set the audience when
you construct these credentials, however, these credentials can also set
the audience claim automatically if not specified. In this case, whenever
a request is made the credentials will automatically generate a one-time
JWT with the request URI as the audience.
intended recipient of the bearer token.
The constructor arguments determine the claims for the JWT that is
sent with requests. Usually, you'll construct these credentials with
Expand All @@ -260,13 +254,15 @@ class Credentials(credentials.Signing,
JSON file::
credentials = jwt.Credentials.from_service_account_file(
'service-account.json')
'service-account.json',
audience='https://speech.googleapis.com')
If you already have the service account file loaded and parsed::
service_account_info = json.load(open('service_account.json'))
credentials = jwt.Credentials.from_service_account_info(
service_account_info)
service_account_info,
audience='https://speech.googleapis.com')
Both helper methods pass on arguments to the constructor, so you can
specify the JWT claims::
Expand All @@ -280,7 +276,10 @@ class Credentials(credentials.Signing,
:class:`~google.auth.crypt.Signer` instance::
credentials = jwt.Credentials(
signer, issuer='your-issuer', subject='your-subject')
signer,
issuer='your-issuer',
subject='your-subject',
audience=''https://speech.googleapis.com'')
The claims are considered immutable. If you want to modify the claims,
you can easily create another instance using :meth:`with_claims`::
Expand All @@ -289,7 +288,7 @@ class Credentials(credentials.Signing,
audience='https://vision.googleapis.com')
"""

def __init__(self, signer, issuer=None, subject=None, audience=None,
def __init__(self, signer, issuer, subject, audience,
additional_claims=None,
token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS):
"""
Expand All @@ -298,8 +297,7 @@ def __init__(self, signer, issuer=None, subject=None, audience=None,
issuer (str): The `iss` claim.
subject (str): The `sub` claim.
audience (str): the `aud` claim. The intended audience for the
credentials. If not specified, a new JWT will be generated for
every request and will use the request URI as the audience.
credentials.
additional_claims (Mapping[str, str]): Any additional claims for
the JWT payload.
token_lifetime (int): The amount of time in seconds for
Expand Down Expand Up @@ -334,7 +332,8 @@ def _from_signer_and_info(cls, signer, info, **kwargs):
ValueError: If the info is not in the expected format.
"""
kwargs.setdefault('subject', info['client_email'])
return cls(signer, issuer=info['client_email'], **kwargs)
kwargs.setdefault('issuer', info['client_email'])
return cls(signer, **kwargs)

@classmethod
def from_service_account_info(cls, info, **kwargs):
Expand Down Expand Up @@ -381,9 +380,8 @@ def with_claims(self, issuer=None, subject=None, audience=None,
claim will be used.
subject (str): The `sub` claim. If unspecified the current subject
claim will be used.
audience (str): the `aud` claim. If not specified, a new
JWT will be generated for every request and will use
the request URI as the audience.
audience (str): the `aud` claim. If unspecified the current
audience claim will be used.
additional_claims (Mapping[str, str]): Any additional claims for
the JWT payload. This will be merged with the current
additional claims.
Expand All @@ -399,12 +397,9 @@ def with_claims(self, issuer=None, subject=None, audience=None,
additional_claims=self._additional_claims.copy().update(
additional_claims or {}))

def _make_jwt(self, audience=None):
def _make_jwt(self):
"""Make a signed JWT.
Args:
audience (str): Overrides the instance's current audience claim.
Returns:
Tuple[bytes, datetime]: The encoded JWT and the expiration.
"""
Expand All @@ -414,10 +409,10 @@ def _make_jwt(self, audience=None):

payload = {
'iss': self._issuer,
'sub': self._subject or self._issuer,
'sub': self._subject,
'iat': _helpers.datetime_to_secs(now),
'exp': _helpers.datetime_to_secs(expiry),
'aud': audience or self._audience,
'aud': self._audience,
}

payload.update(self._additional_claims)
Expand All @@ -426,22 +421,6 @@ def _make_jwt(self, audience=None):

return jwt, expiry

def _make_one_time_jwt(self, uri):
"""Makes a one-off JWT with the URI as the audience.
Args:
uri (str): The request URI.
Returns:
bytes: The encoded JWT.
"""
parts = urllib.parse.urlsplit(uri)
# Strip query string and fragment
audience = urllib.parse.urlunsplit(
(parts.scheme, parts.netloc, parts.path, None, None))
token, _ = self._make_jwt(audience=audience)
return token

def refresh(self, request):
"""Refreshes the access token.
Expand All @@ -452,15 +431,8 @@ def refresh(self, request):
# (pylint doesn't correctly recognize overridden methods.)
self.token, self.expiry = self._make_jwt()

@_helpers.copy_docstring(credentials.Signing)
def sign_bytes(self, message):
"""Signs the given message.
Args:
message (bytes): The message to sign.
Returns:
bytes: The message signature.
"""
return self._signer.sign(message)

@property
Expand All @@ -472,32 +444,3 @@ def signer_email(self):
@_helpers.copy_docstring(credentials.Signing)
def signer(self):
return self._signer

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.
If an audience is specified it will refresh the credentials if
necessary. If no audience is specified it will generate a one-time
token for the request URI. In either case, it will set the
authorization header in headers to the token.
Args:
request (Any): Unused.
method (str): The request's HTTP method.
url (str): The request's URI.
headers (Mapping): The request's headers.
"""
# pylint: disable=unused-argument
# (pylint doesn't correctly recognize overridden methods.)

# If this set of credentials has a pre-set audience, just ensure that
# there is a valid token and apply the auth headers.
if self._audience:
if not self.valid:
self.refresh(request)
self.apply(headers)
# Otherwise, generate a one-time token using the URL
# (without the query string and fragment) as the audience.
else:
token = self._make_one_time_jwt(url)
self.apply(headers, token=token)
9 changes: 7 additions & 2 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def from_service_account_file(cls, filename, **kwargs):
filename, require=['client_email', 'token_uri'])
return cls._from_signer_and_info(signer, info, **kwargs)

def to_jwt_credentials(self):
def to_jwt_credentials(self, audience):
"""Creates a :class:`google.auth.jwt.Credentials` instance from this
instance.
Expand All @@ -223,13 +223,18 @@ def to_jwt_credentials(self):
jwt_creds = jwt.Credentials.from_service_account_file(
'service_account.json')
Args:
audience (str): the `aud` claim. The intended audience for the
credentials.
Returns:
google.auth.jwt.Credentials: A new Credentials instance.
"""
return jwt.Credentials(
self._signer,
issuer=self._service_account_email,
subject=self._service_account_email)
subject=self._service_account_email,
audience=audience)

@property
def service_account_email(self):
Expand Down
5 changes: 4 additions & 1 deletion system_tests/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def test_grpc_request_with_regular_credentials(http_request):

def test_grpc_request_with_jwt_credentials(http_request):
credentials, project_id = google.auth.default()
credentials = credentials.to_jwt_credentials()
audience = 'https://{}/google.pubsub.v1.Publisher'.format(
publisher_client.PublisherClient.SERVICE_ADDRESS)
credentials = credentials.to_jwt_credentials(
audience=audience)

channel = google.auth.transport.grpc.secure_authorized_channel(
credentials,
Expand Down
6 changes: 4 additions & 2 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def test_from_service_account_file_args(self):
assert credentials._additional_claims == additional_claims

def test_to_jwt_credentials(self):
jwt_from_svc = self.credentials.to_jwt_credentials()
jwt_from_svc = self.credentials.to_jwt_credentials(
audience=mock.sentinel.audience)
jwt_from_info = jwt.Credentials.from_service_account_info(
SERVICE_ACCOUNT_INFO)
SERVICE_ACCOUNT_INFO,
audience=mock.sentinel.audience)

assert isinstance(jwt_from_svc, jwt.Credentials)
assert jwt_from_svc._signer.key_id == jwt_from_info._signer.key_id
Expand Down
51 changes: 24 additions & 27 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,20 @@ class TestCredentials:
@pytest.fixture(autouse=True)
def credentials_fixture(self, signer):
self.credentials = jwt.Credentials(
signer, self.SERVICE_ACCOUNT_EMAIL)
signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL,
self.AUDIENCE)

def test_from_service_account_info(self):
with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh:
info = json.load(fh)

credentials = jwt.Credentials.from_service_account_info(info)
credentials = jwt.Credentials.from_service_account_info(
info, audience=self.AUDIENCE)

assert credentials._signer.key_id == info['private_key_id']
assert credentials._issuer == info['client_email']
assert credentials._subject == info['client_email']
assert credentials._audience == self.AUDIENCE

def test_from_service_account_info_args(self):
info = SERVICE_ACCOUNT_INFO.copy()
Expand All @@ -235,11 +238,12 @@ def test_from_service_account_file(self):
info = SERVICE_ACCOUNT_INFO.copy()

credentials = jwt.Credentials.from_service_account_file(
SERVICE_ACCOUNT_JSON_FILE)
SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE)

assert credentials._signer.key_id == info['private_key_id']
assert credentials._issuer == info['client_email']
assert credentials._subject == info['client_email']
assert credentials._audience == self.AUDIENCE

def test_from_service_account_file_args(self):
info = SERVICE_ACCOUNT_INFO.copy()
Expand All @@ -259,6 +263,18 @@ def test_default_state(self):
# Expiration hasn't been set yet
assert not self.credentials.expired

def test_with_claims(self):
new_audience = 'new_audience'
new_credentials = self.credentials.with_claims(
audience=new_audience)

assert new_credentials._signer == self.credentials._signer
assert new_credentials._issuer == self.credentials._issuer
assert new_credentials._subject == self.credentials._subject
assert new_credentials._audience == new_audience
assert (new_credentials._additional_claims ==
self.credentials._additional_claims)

def test_sign_bytes(self):
to_sign = b'123'
signature = self.credentials.sign_bytes(to_sign)
Expand Down Expand Up @@ -292,43 +308,24 @@ def test_expired(self):
now.return_value = self.credentials.expiry + one_day
assert self.credentials.expired

def test_before_request_one_time_token(self):
def test_before_request(self):
headers = {}

self.credentials.refresh(None)
self.credentials.before_request(
mock.Mock(), 'GET', 'http://example.com?a=1#3', headers)

header_value = headers['authorization']
_, token = header_value.split(' ')

# This should be a one-off token, so it shouldn't be the same as the
# credentials' stored token.
assert token != self.credentials.token

payload = self._verify_token(token)
assert payload['aud'] == 'http://example.com'

def test_before_request_with_preset_audience(self):
headers = {}

credentials = self.credentials.with_claims(audience=self.AUDIENCE)
credentials.refresh(None)
credentials.before_request(
None, 'GET', 'http://example.com?a=1#3', headers)

header_value = headers['authorization']
_, token = header_value.split(' ')

# Since the audience is set, it should use the existing token.
assert token.encode('utf-8') == credentials.token
assert token.encode('utf-8') == self.credentials.token

payload = self._verify_token(token)
assert payload['aud'] == self.AUDIENCE

def test_before_request_refreshes(self):
credentials = self.credentials.with_claims(audience=self.AUDIENCE)
assert not credentials.valid
credentials.before_request(
assert not self.credentials.valid
self.credentials.before_request(
None, 'GET', 'http://example.com?a=1#3', {})
assert credentials.valid
assert self.credentials.valid

0 comments on commit ab08689

Please sign in to comment.