Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom header configuration in jwk client #823

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Added
- Add ``compute_hash_digest`` as a method of ``Algorithm`` objects, which uses
the underlying hash algorithm to compute a digest. If there is no appropriate
hash algorithm, a ``NotImplementedError`` will be raised
- Add optional ``headers`` argument to ``PyJWKClient``. If provided, the headers
will be included in requests that the client uses when fetching the JWK set.

`v2.6.0 <https://github.com/jpadilla/pyjwt/compare/2.5.0...2.6.0>`__
-----------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ Retrieve RSA signing keys from a JWKS endpoint
>>> token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA"
>>> kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
>>> url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
>>> jwks_client = PyJWKClient(url)
>>> optional_custom_headers = {"User-agent": "custom-user-agent"}
>>> jwks_client = PyJWKClient(url, headers=optional_custom_headers)
>>> signing_key = jwks_client.get_signing_key_from_jwt(token)
>>> data = jwt.decode(
... token,
Expand Down
5 changes: 4 additions & 1 deletion jwt/jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def __init__(
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
headers: dict = {},
):
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None
self.headers = headers

if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
Expand All @@ -41,7 +43,8 @@ def __init__(
def fetch_data(self) -> Any:
jwk_set: Any = None
try:
with urllib.request.urlopen(self.uri) as response:
r = urllib.request.Request(url=self.uri, headers=self.headers)
with urllib.request.urlopen(r) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
Expand Down
13 changes: 13 additions & 0 deletions tests/test_jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ def mocked_first_call_wrong_kid_second_call_correct_kid(

@crypto_required
class TestPyJWKClient:
def test_fetch_data_forwards_headers_to_correct_url(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as mock_request:
custom_headers = {"User-agent": "my-custom-agent"}
jwks_client = PyJWKClient(url, headers=custom_headers)
jwk_set = jwks_client.get_jwk_set()
request_params = mock_request.call_args[0][0]
assert request_params.full_url == url
assert request_params.headers == custom_headers

assert len(jwk_set.keys) == 1

def test_get_jwk_set(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

Expand Down