Skip to content

Commit

Permalink
feat: adding typehints (#683)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas de Jong <thomas.dejong@elements.nl>
  • Loading branch information
abczzz13 and thomasdejongelements committed Mar 10, 2023
1 parent 960ab2b commit 8258b5f
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 126 deletions.
30 changes: 20 additions & 10 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from typing import Optional, Set, Tuple, TypeVar

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
from django.utils.translation import gettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, authentication
from rest_framework.request import Request

from .exceptions import AuthenticationFailed, InvalidToken, TokenError
from .models import TokenUser
from .settings import api_settings
from .tokens import Token

AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES

if not isinstance(api_settings.AUTH_HEADER_TYPES, (list, tuple)):
AUTH_HEADER_TYPES = (AUTH_HEADER_TYPES,)

AUTH_HEADER_TYPE_BYTES = {h.encode(HTTP_HEADER_ENCODING) for h in AUTH_HEADER_TYPES}
AUTH_HEADER_TYPE_BYTES: Set[bytes] = {
h.encode(HTTP_HEADER_ENCODING) for h in AUTH_HEADER_TYPES
}

AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)


class JWTAuthentication(authentication.BaseAuthentication):
Expand All @@ -22,11 +32,11 @@ class JWTAuthentication(authentication.BaseAuthentication):
www_authenticate_realm = "api"
media_type = "application/json"

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.user_model = get_user_model()

def authenticate(self, request):
def authenticate(self, request: Request) -> Optional[Tuple[AuthUser, Token]]:
header = self.get_header(request)
if header is None:
return None
Expand All @@ -39,13 +49,13 @@ def authenticate(self, request):

return self.get_user(validated_token), validated_token

def authenticate_header(self, request):
def authenticate_header(self, request: Request) -> str:
return '{} realm="{}"'.format(
AUTH_HEADER_TYPES[0],
self.www_authenticate_realm,
)

def get_header(self, request):
def get_header(self, request: Request) -> bytes:
"""
Extracts the header containing the JSON web token from the given
request.
Expand All @@ -58,7 +68,7 @@ def get_header(self, request):

return header

def get_raw_token(self, header):
def get_raw_token(self, header: bytes) -> Optional[bytes]:
"""
Extracts an unvalidated JSON web token from the given "Authorization"
header value.
Expand All @@ -81,7 +91,7 @@ def get_raw_token(self, header):

return parts[1]

def get_validated_token(self, raw_token):
def get_validated_token(self, raw_token: bytes) -> Token:
"""
Validates an encoded JSON web token and returns a validated token
wrapper object.
Expand All @@ -106,7 +116,7 @@ def get_validated_token(self, raw_token):
}
)

def get_user(self, validated_token):
def get_user(self, validated_token: Token) -> AuthUser:
"""
Attempts to find and return a user using the given validated token.
"""
Expand All @@ -132,7 +142,7 @@ class JWTStatelessUserAuthentication(JWTAuthentication):
token provided in a request header without performing a database lookup to obtain a user instance.
"""

def get_user(self, validated_token):
def get_user(self, validated_token: Token) -> AuthUser:
"""
Returns a stateless user object which is backed by the given validated
token.
Expand All @@ -148,7 +158,7 @@ def get_user(self, validated_token):
JWTTokenUserAuthentication = JWTStatelessUserAuthentication


def default_user_authentication_rule(user):
def default_user_authentication_rule(user: AuthUser) -> bool:
# Prior to Django 1.10, inactive users could be authenticated with the
# default `ModelBackend`. As of Django 1.10, the `ModelBackend`
# prevents inactive users from authenticating. App designers can still
Expand Down
28 changes: 15 additions & 13 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
from collections.abc import Iterable
from datetime import timedelta
from typing import Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union

import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms

from .exceptions import TokenBackendError
from .tokens import Token
from .utils import format_lazy

try:
Expand All @@ -32,15 +34,15 @@
class TokenBackend:
def __init__(
self,
algorithm,
signing_key=None,
verifying_key="",
audience=None,
issuer=None,
jwk_url: str = None,
leeway: Union[float, int, timedelta] = None,
algorithm: str,
signing_key: Optional[str] = None,
verifying_key: str = "",
audience: Union[str, Iterable, None] = None,
issuer: Optional[str] = None,
jwk_url: Optional[str] = None,
leeway: Union[float, int, timedelta, None] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
):
) -> None:
self._validate_algorithm(algorithm)

self.algorithm = algorithm
Expand All @@ -57,7 +59,7 @@ def __init__(
self.leeway = leeway
self.json_encoder = json_encoder

def _validate_algorithm(self, algorithm):
def _validate_algorithm(self, algorithm: str) -> None:
"""
Ensure that the nominated algorithm is recognized, and that cryptography is installed for those
algorithms that require it
Expand Down Expand Up @@ -91,7 +93,7 @@ def get_leeway(self) -> timedelta:
)
)

def get_verifying_key(self, token):
def get_verifying_key(self, token: Token) -> Optional[str]:
if self.algorithm.startswith("HS"):
return self.signing_key

Expand All @@ -103,7 +105,7 @@ def get_verifying_key(self, token):

return self.verifying_key

def encode(self, payload):
def encode(self, payload: Dict[str, Any]) -> str:
"""
Returns an encoded token for the given payload dictionary.
"""
Expand All @@ -125,7 +127,7 @@ def encode(self, payload):
# For PyJWT >= 2.0.0a1
return token

def decode(self, token, verify=True):
def decode(self, token: Token, verify: bool = True) -> Dict[str, Any]:
"""
Performs a validation of the given token and returns its payload
dictionary.
Expand Down
13 changes: 11 additions & 2 deletions rest_framework_simplejwt/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, Optional, Union

from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, status

Expand All @@ -11,7 +13,14 @@ class TokenBackendError(Exception):


class DetailDictMixin:
def __init__(self, detail=None, code=None):
default_detail: str
default_code: str

def __init__(
self,
detail: Union[Dict[str, Any], str, None] = None,
code: Optional[str] = None,
) -> None:
"""
Builds a detail dictionary for the error to give more information to API
users.
Expand All @@ -26,7 +35,7 @@ def __init__(self, detail=None, code=None):
if code is not None:
detail_dict["code"] = code

super().__init__(detail_dict)
super().__init__(detail_dict) # type: ignore


class AuthenticationFailed(DetailDictMixin, exceptions.AuthenticationFailed):
Expand Down
57 changes: 32 additions & 25 deletions rest_framework_simplejwt/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union

from django.contrib.auth import models as auth_models
from django.db.models.manager import EmptyManager
from django.utils.functional import cached_property

from .settings import api_settings

if TYPE_CHECKING:
from .tokens import Token


class TokenUser:
"""
Expand All @@ -22,87 +27,89 @@ class instead of a `User` model instance. Instances of this class act as
_groups = EmptyManager(auth_models.Group)
_user_permissions = EmptyManager(auth_models.Permission)

def __init__(self, token):
def __init__(self, token: "Token") -> None:
self.token = token

def __str__(self):
def __str__(self) -> str:
return f"TokenUser {self.id}"

@cached_property
def id(self):
def id(self) -> Union[int, str]:
return self.token[api_settings.USER_ID_CLAIM]

@cached_property
def pk(self):
def pk(self) -> Union[int, str]:
return self.id

@cached_property
def username(self):
def username(self) -> str:
return self.token.get("username", "")

@cached_property
def is_staff(self):
def is_staff(self) -> bool:
return self.token.get("is_staff", False)

@cached_property
def is_superuser(self):
def is_superuser(self) -> bool:
return self.token.get("is_superuser", False)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, TokenUser):
return NotImplemented
return self.id == other.id

def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __hash__(self):
def __hash__(self) -> int:
return hash(self.id)

def save(self):
def save(self) -> None:
raise NotImplementedError("Token users have no DB representation")

def delete(self):
def delete(self) -> None:
raise NotImplementedError("Token users have no DB representation")

def set_password(self, raw_password):
def set_password(self, raw_password: str) -> None:
raise NotImplementedError("Token users have no DB representation")

def check_password(self, raw_password):
def check_password(self, raw_password: str) -> None:
raise NotImplementedError("Token users have no DB representation")

@property
def groups(self):
def groups(self) -> auth_models.Group:
return self._groups

@property
def user_permissions(self):
def user_permissions(self) -> auth_models.Permission:
return self._user_permissions

def get_group_permissions(self, obj=None):
def get_group_permissions(self, obj: Optional[object] = None) -> set:
return set()

def get_all_permissions(self, obj=None):
def get_all_permissions(self, obj: Optional[object] = None) -> set:
return set()

def has_perm(self, perm, obj=None):
def has_perm(self, perm: str, obj: Optional[object] = None) -> bool:
return False

def has_perms(self, perm_list, obj=None):
def has_perms(self, perm_list: List[str], obj: Optional[object] = None) -> bool:
return False

def has_module_perms(self, module):
def has_module_perms(self, module: str) -> bool:
return False

@property
def is_anonymous(self):
def is_anonymous(self) -> bool:
return False

@property
def is_authenticated(self):
def is_authenticated(self) -> bool:
return True

def get_username(self):
def get_username(self) -> str:
return self.username

def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Optional[Any]:
"""This acts as a backup attribute getter for custom claims defined in Token serializers."""
return self.token.get(attr, None)

0 comments on commit 8258b5f

Please sign in to comment.