diff --git a/docs/changelog.rst b/docs/changelog.rst index 6f97097..fa8630b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,6 +12,8 @@ UNRELEASED forced to HTTPS, because the `the OIDC spec`_ is actually only a strong recommendation (:issue:`35`). You can use ``OVERWRITE_REDIRECT_URI`` if you want to force it to HTTPS (or any other URL). +- Handle token expiration when there is no ``refresh_token`` or no token URL + (:issue:`39`) 2.0.3 (2023-09-08) diff --git a/flask_oidc/__init__.py b/flask_oidc/__init__.py index 5e984e4..5ff8986 100644 --- a/flask_oidc/__init__.py +++ b/flask_oidc/__init__.py @@ -30,6 +30,7 @@ from urllib.parse import quote_plus from authlib.common.errors import AuthlibBaseError +from authlib.integrations.base_client import InvalidTokenError from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_oauth2 import ResourceProtector from authlib.oauth2.rfc6749 import OAuth2Token @@ -216,7 +217,12 @@ def check_token_expiry(self): def ensure_active_token(self, token: OAuth2Token): metadata = self.oauth.oidc.load_server_metadata() with self.oauth.oidc._get_oauth_client(**metadata) as session: - return session.ensure_active_token(token) + result = session.ensure_active_token(token) + if result is None: + # See the ensure_active_token method in + # authlib.integrations.requests_client.oauth2_session:OAuth2Auth + raise InvalidTokenError() + return result def _update_token(name, token, refresh_token=None, access_token=None): session["oidc_auth_token"] = g.oidc_id_token = token diff --git a/tests/test_flask_oidc.py b/tests/test_flask_oidc.py index 6d02d6a..b5d5a87 100644 --- a/tests/test_flask_oidc.py +++ b/tests/test_flask_oidc.py @@ -26,6 +26,12 @@ def callback_url_for(response): return f"{query['redirect_uri'][0]}?state={query['state'][0]}&code=mock_auth_code" +def _set_token(client, token): + with client.session_transaction() as session: + session["oidc_auth_token"] = token + session["oidc_auth_profile"] = {"nickname": "dummy"} + + def test_signin(test_app, client, mocked_responses, dummy_token): """Happy path authentication test.""" mocked_responses.post("https://test/openidc/Token", json=dummy_token) @@ -90,9 +96,7 @@ def test_expired_token(client, dummy_token, mocked_responses): refresh_call = mocked_responses.post("https://test/openidc/Token", json=new_token) dummy_token["expires_at"] = int(time.time()) - with client.session_transaction() as session: - session["oidc_auth_token"] = dummy_token - session["oidc_auth_profile"] = {"nickname": "dummy"} + _set_token(client, dummy_token) resp = client.get("/") @@ -115,16 +119,12 @@ def test_expired_token(client, dummy_token, mocked_responses): def test_expired_token_cant_renew(client, dummy_token, mocked_responses): - new_token = dummy_token.copy() - new_token["access_token"] = "this-is-new" refresh_call = mocked_responses.post( "https://test/openidc/Token", json={"error": "dummy"}, status=401 ) dummy_token["expires_at"] = int(time.time()) - with client.session_transaction() as session: - session["oidc_auth_token"] = dummy_token - session["oidc_auth_profile"] = {"nickname": "dummy"} + _set_token(client, dummy_token) resp = client.get("/") @@ -137,10 +137,23 @@ def test_expired_token_cant_renew(client, dummy_token, mocked_responses): assert "oidc_auth_token" not in flask.session +def test_expired_token_no_refresh_token(client, dummy_token): + del dummy_token["refresh_token"] + dummy_token["expires_at"] = int(time.time()) + _set_token(client, dummy_token) + + resp = client.get("/") + + assert resp.status_code == 302 + assert resp.location == "/logout?reason=expired" + resp = client.get(resp.location) + assert resp.status_code == 302 + assert resp.location == "http://localhost/" + assert "oidc_auth_token" not in flask.session + + def test_bad_token(client): - with client.session_transaction() as session: - session["oidc_auth_token"] = "bad_token" - session["oidc_auth_profile"] = {"nickname": "dummy"} + _set_token(client, "bad_token") resp = client.get("/") assert resp.status_code == 500 assert "Internal Server Error" in resp.get_data(as_text=True)