Skip to content

Commit

Permalink
Reduce code duplications
Browse files Browse the repository at this point in the history
  • Loading branch information
Dennis Lee committed Mar 6, 2024
1 parent 92bd832 commit 1575038
Showing 1 changed file with 27 additions and 51 deletions.
78 changes: 27 additions & 51 deletions app/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,87 +19,63 @@
async def get_valid_access_token(
token: Annotated[str, Depends(oauth2_scheme)]) -> Dict[str, Any]:
"""
Decode and validate a access token, and return the payload if valid.
Retrieves and validates the access token.
Args:
token (str): The access token to decode and validate.
token (str): The access token to be validated.
Returns:
dict: The payload of the access token.
Raises:
HTTPException: If the token is invalid, expired, or revoked.
"""
status_code = status.HTTP_401_UNAUTHORIZED

try:
# Set the options for the JWT token validation
options = {
"require_exp": True, # expiration time
"require_sub": True, # user id
"require_jti": True, # token id
"require_sid": True, # session id
"require_email": True # email
}

# Decode, validate JWT token and get the payload
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options=options)
return await _decode_and_validate_token(token, "access")

# check if the token's session is in active session cache
user_id = payload.get("sub")
session_id = payload.get("sid")

if not await session_store.exists(user_id=user_id, session_id=session_id):
raise HTTPException(status_code=status_code, detail="Session was revoked or expired")

# check if the token is in active token cache
token_id = payload.get("jti")

if not await token_store.exists(token_id=token_id):
# for security measures, revoke token's session as well
print(f"Token '{token_id}' was revoked or expired, removing session" +
f"'{session_id}' as security measure")
await session_store.remove(user_id=user_id, session_id=session_id)
async def get_valid_refresh_token(
token: Annotated[str, Depends(oauth2_scheme)]) -> Dict[str, Any]:
"""
Retrieves and validates the refresh token.
raise HTTPException(status_code=status_code,
detail="Access token was revoked or expired")
Args:
token (str): The refresh token to be validated.
return payload
Returns:
dict: The payload of the refresh token.
except (JWTError, ValidationError) as exc:
# If there is an error while decoding the token, token is invalid
raise HTTPException(status_code=status_code, detail=str(exc)) from exc
"""
return await _decode_and_validate_token(token, "refresh")

async def get_valid_refresh_token(
token: Annotated[str, Depends(oauth2_scheme)]) -> Dict[str, Any]:
async def _decode_and_validate_token(token: str, token_type: str) -> Dict[str, Any]:
"""
Decode and validate a refresh token, and return the payload if valid.
Decode and validate a token, and return the payload if valid.
Args:
token (str): The refresh token to decode and validate.
token (str): The token to decode and validate.
token_type (str): The type of the token (access or refresh).
Returns:
dict: The payload of the refresh token.
dict: The payload of the token.
Raises:
HTTPException: If the token is invalid, expired, or revoked.
"""

status_code = status.HTTP_401_UNAUTHORIZED

try:
# Set the options for the JWT token validation
options = {
"require_exp": True, # expiration time
"require_sub": True, # user id
"require_jti": True, # token id
"require_sid": True # session id
"require_exp": True, # expiration time
"require_sub": True, # user id
"require_jti": True, # token id
"require_sid": True, # session id
}

if token_type == "access":
options["require_email"] = True # email

# Decode, validate JWT token and get the payload
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options=options)

if 'email' in payload:
if token_type == "refresh" and 'email' in payload:
raise HTTPException(status_code=status_code,
detail="Invalid token type, refresh token was expected")

Expand All @@ -115,7 +91,7 @@ async def get_valid_refresh_token(

if not await token_store.exists(token_id=token_id):
# for security measures, revoke token's session as well
print(f"Token '{token_id}' was revoked or expired, removing session" +
print(f"{token_type} token '{token_id}' was revoked or expired, removing session" +
f"'{session_id}' as security measure")
await session_store.remove(user_id=user_id, session_id=session_id)

Expand Down

0 comments on commit 1575038

Please sign in to comment.