Skip to content

Commit

Permalink
Update for flask_jwt_extended >= 4
Browse files Browse the repository at this point in the history
* Rename TokenBlacklist to TokenBlocklist for consistency with flask
jwt extended
* Update functions and hooks to match flask jwt extended 4.0.0
  • Loading branch information
karec committed Mar 1, 2021
1 parent e1c4a2a commit f1e9a29
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ def upgrade():
sa.UniqueConstraint("username"),
)
op.create_table(
"token_blacklist",
"token_blocklist",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("jti", sa.String(length=36), nullable=False),
sa.Column("token_type", sa.String(length=10), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("expires", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"],),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("jti"),
)
Expand All @@ -46,6 +49,6 @@ def upgrade():

def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("token_blacklist")
op.drop_table("token_blocklist")
op.drop_table("user")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class UserResource(Resource):
description: user does not exists
"""

method_decorators = [jwt_required]
method_decorators = [jwt_required()]

def get(self, user_id):
schema = UserSchema()
Expand Down Expand Up @@ -142,7 +142,7 @@ class UserList(Resource):
user: UserSchema
"""

method_decorators = [jwt_required]
method_decorators = [jwt_required()]

def get(self):
schema = UserSchema(many=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Various helpers for auth. Mainly about tokens blacklisting
"""Various helpers for auth. Mainly about tokens blocklisting
heavily inspired by
https://github.com/vimalloc/flask-jwt-extended/blob/master/examples/database_blacklist/blacklist_helpers.py
Heavily inspired by
https://github.com/vimalloc/flask-jwt-extended/blob/master/examples/blocklist_database.py
"""
from datetime import datetime

from flask_jwt_extended import decode_token
from sqlalchemy.orm.exc import NoResultFound

from {{cookiecutter.app_name}}.extensions import db
from {{cookiecutter.app_name}}.models import TokenBlacklist
from {{cookiecutter.app_name}}.models import TokenBlocklist


def add_token_to_database(encoded_token, identity_claim):
Expand All @@ -25,7 +25,7 @@ def add_token_to_database(encoded_token, identity_claim):
expires = datetime.fromtimestamp(decoded_token["exp"])
revoked = False

db_token = TokenBlacklist(
db_token = TokenBlocklist(
jti=jti,
token_type=token_type,
user_id=user_identity,
Expand All @@ -36,16 +36,16 @@ def add_token_to_database(encoded_token, identity_claim):
db.session.commit()


def is_token_revoked(decoded_token):
def is_token_revoked(jwt_payload):
"""
Checks if the given token is revoked or not. Because we are adding all the
tokens that we create into this database, if the token is not present
in the database we are going to consider it revoked, as we don't know where
it was created.
"""
jti = decoded_token["jti"]
jti = jwt_payload["jti"]
try:
token = TokenBlacklist.query.filter_by(jti=jti).one()
token = TokenBlocklist.query.filter_by(jti=jti).one()
return token.revoked
except NoResultFound:
return True
Expand All @@ -58,7 +58,7 @@ def revoke_token(token_jti, user):
if token is not found we raise an exception
"""
try:
token = TokenBlacklist.query.filter_by(jti=token_jti, user_id=user).one()
token = TokenBlocklist.query.filter_by(jti=token_jti, user_id=user).one()
token.revoked = True
db.session.commit()
except NoResultFound:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
create_access_token,
create_refresh_token,
jwt_required,
jwt_refresh_token_required,
get_jwt_identity,
get_raw_jwt,
get_jwt,
)

from {{cookiecutter.app_name}}.models import User
Expand Down Expand Up @@ -77,7 +76,7 @@ def login():


@blueprint.route("/refresh", methods=["POST"])
@jwt_refresh_token_required
@jwt_required(refresh=True)
def refresh():
"""Get an access token from a refresh token
Expand Down Expand Up @@ -113,7 +112,7 @@ def refresh():


@blueprint.route("/revoke_access", methods=["DELETE"])
@jwt_required
@jwt_required()
def revoke_access_token():
"""Revoke an access token
Expand All @@ -136,14 +135,14 @@ def revoke_access_token():
401:
description: unauthorized
"""
jti = get_raw_jwt()["jti"]
jti = get_jwt()["jti"]
user_identity = get_jwt_identity()
revoke_token(jti, user_identity)
return jsonify({"message": "token revoked"}), 200


@blueprint.route("/revoke_refresh", methods=["DELETE"])
@jwt_refresh_token_required
@jwt_required(refresh=True)
def revoke_refresh_token():
"""Revoke a refresh token, used mainly for logout
Expand All @@ -166,20 +165,21 @@ def revoke_refresh_token():
401:
description: unauthorized
"""
jti = get_raw_jwt()["jti"]
jti = get_jwt()["jti"]
user_identity = get_jwt_identity()
revoke_token(jti, user_identity)
return jsonify({"message": "token revoked"}), 200


@jwt.user_loader_callback_loader
def user_loader_callback(identity):
@jwt.user_lookup_loader
def user_loader_callback(jwt_headers, jwt_payload):
identity = jwt_payload["sub"]
return User.query.get(identity)


@jwt.token_in_blacklist_loader
def check_if_token_revoked(decoded_token):
return is_token_revoked(decoded_token)
@jwt.token_in_blocklist_loader
def check_if_token_revoked(jwt_headers, jwt_payload):
return is_token_revoked(jwt_payload)


@blueprint.before_app_first_request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
SQLALCHEMY_DATABASE_URI = os.getenv("DATABASE_URI")
SQLALCHEMY_TRACK_MODIFICATIONS = False

JWT_BLACKLIST_ENABLED = True
JWT_BLACKLIST_TOKEN_CHECKS = ["access", "refresh"]
{%- if cookiecutter.use_celery == "yes" %}
CELERY = {
"broker_url": os.getenv("CELERY_BROKER_URL"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from {{cookiecutter.app_name}}.models.user import User
from {{cookiecutter.app_name}}.models.blacklist import TokenBlacklist
from {{cookiecutter.app_name}}.models.blocklist import TokenBlocklist


__all__ = ["User", "TokenBlacklist"]
__all__ = ["User", "TokenBlocklist"]
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
"""Simple blacklist implementation using database
"""Simple blocklist implementation using database
Using database may not be your prefered solution to handle blacklist in your
Using database may not be your prefered solution to handle blocklist in your
final application, but remember that's just a cookiecutter template. Feel free
to dump this code and adapt it for your needs.
For this reason, we don't include advanced tokens management in this
example (view all tokens for a user, revoke from api, etc.)
If we choose to use database to handle blacklist in this example, it's mainly
If we choose to use database to handle blocklist in this example, it's mainly
because it will allow you to run the example without needing to setup anything else
like a redis or a memcached server.
This example is heavily inspired by
https://github.com/vimalloc/flask-jwt-extended/blob/master/examples/database_blacklist/
https://github.com/vimalloc/flask-jwt-extended/blob/master/examples/blocklist_database.py
"""
from {{cookiecutter.app_name}}.extensions import db


class TokenBlacklist(db.Model):
"""Blacklist representation"""
class TokenBlocklist(db.Model):
"""Blocklist representation"""

id = db.Column(db.Integer, primary_key=True)
jti = db.Column(db.String(36), nullable=False, unique=True)
Expand Down

0 comments on commit f1e9a29

Please sign in to comment.