Skip to content

Commit

Permalink
Change logic of tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
jonra1993 committed Oct 8, 2022
1 parent c66bd21 commit b9ff90c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 52 deletions.
4 changes: 2 additions & 2 deletions fastapi-alembic-sqlmodel-async/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import AsyncGenerator, List
from uuid import UUID
from fastapi import Depends, HTTPException, status
from app.utils.token import get_valid_tokens
from app.schemas.user_schema import IUserRead
from app.utils.minio_client import MinioClient
from app.schemas.user_schema import IUserCreate
Expand Down Expand Up @@ -59,8 +60,7 @@ async def current_user(
detail="Could not validate credentials",
)
user_id = payload["sub"]
access_token_key = f"user:{user_id}:{TokenType.ACCESS}"
valid_access_tokens = await redis_client.smembers(access_token_key)
valid_access_tokens = await get_valid_tokens(redis_client, user_id, TokenType.ACCESS)
if valid_access_tokens and token not in valid_access_tokens:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand Down
103 changes: 53 additions & 50 deletions fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/login.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import timedelta
from typing import Any, Optional
from typing import Any
from fastapi import APIRouter, Body, Depends, HTTPException
from app.utils.token import get_valid_tokens
from app.utils.token import delete_tokens
from app.utils.token import add_token_to_redis
from app.schemas.common_schema import TokenType
from app.core.security import get_password_hash
from app.core.security import verify_password
Expand All @@ -17,31 +20,10 @@
from app.schemas.token_schema import TokenRead, Token, RefreshToken
from app.schemas.common_schema import IMetaGeneral, IPostResponseBase, create_response
from aioredis import Redis
from enum import Enum

router = APIRouter()


async def add_token_to_redis(
redis_client: Redis,
user: User,
token: str,
token_type: TokenType,
expire_time: Optional[int] = None,
):
token_key = f"user:{user.id}:{token_type}"
await redis_client.sadd(token_key, token)
if expire_time:
await redis_client.expire(token_key, expire_time)


async def delete_tokens(redis_client: Redis, user: User, token_type: TokenType):
token_key = f"user:{user.id}:{token_type}"
valid_tokens = await redis_client.smembers(token_key)
if valid_tokens is not None:
await redis_client.delete(token_key)


@router.post("", response_model=IPostResponseBase[Token])
async def login(
email: EmailStr = Body(...),
Expand Down Expand Up @@ -71,20 +53,28 @@ async def login(
refresh_token=refresh_token,
user=user,
)
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
valid_access_tokens = await get_valid_tokens(
redis_client, user.id, TokenType.ACCESS
)
await add_token_to_redis(
redis_client,
user,
refresh_token,
TokenType.REFRESH,
settings.REFRESH_TOKEN_EXPIRE_MINUTES,
if valid_access_tokens:
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
)
valid_refresh_tokens = await get_valid_tokens(
redis_client, user.id, TokenType.ACCESS
)
if valid_refresh_tokens:
await add_token_to_redis(
redis_client,
user,
refresh_token,
TokenType.REFRESH,
settings.REFRESH_TOKEN_EXPIRE_MINUTES,
)

return create_response(meta=meta_data, data=data, message="Login correctly")

Expand All @@ -103,9 +93,12 @@ async def change_password(
if not verify_password(current_password, current_user.hashed_password):
raise HTTPException(status_code=400, detail="Invalid Current Password")

hashed_password = get_password_hash(new_password)
if verify_password(new_password, current_user.hashed_password):
raise HTTPException(status_code=400, detail="New Password should be different that the current one")

new_hashed_password = get_password_hash(new_password)
await crud.user.update(
obj_current=current_user, obj_new={"hashed_password": hashed_password}
obj_current=current_user, obj_new={"hashed_password": new_hashed_password}
)

access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
Expand Down Expand Up @@ -162,8 +155,9 @@ async def get_refresh_token(

if payload["type"] == "refresh":
user_id = payload["sub"]
refresh_token_key = f"user:{user_id}:{TokenType.REFRESH}"
valid_refresh_tokens = await redis_client.smembers(refresh_token_key)
valid_refresh_tokens = await get_valid_tokens(
redis_client, user_id, TokenType.REFRESH
)
if valid_refresh_tokens and body.refresh_token not in valid_refresh_tokens:
raise HTTPException(status_code=403, detail="Refresh token invalid")

Expand All @@ -173,13 +167,18 @@ async def get_refresh_token(
access_token = security.create_access_token(
payload["sub"], expires_delta=access_token_expires
)
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
valid_access_get_valid_tokens = await get_valid_tokens(
redis_client, user.id, TokenType.ACCESS
)
if valid_access_get_valid_tokens:
print("her")
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
)
return create_response(
data=TokenRead(access_token=access_token, token_type="bearer"),
message="Access token generated correctly",
Expand Down Expand Up @@ -209,13 +208,17 @@ async def login_access_token(
access_token = security.create_access_token(
user.id, expires_delta=access_token_expires
)
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
valid_access_tokens = await get_valid_tokens(
redis_client, user.id, TokenType.ACCESS
)
if valid_access_tokens:
await add_token_to_redis(
redis_client,
user,
access_token,
TokenType.ACCESS,
settings.ACCESS_TOKEN_EXPIRE_MINUTES,
)
return {
"access_token": access_token,
"token_type": "bearer",
Expand Down
32 changes: 32 additions & 0 deletions fastapi-alembic-sqlmodel-async/app/utils/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional
from uuid import UUID
from aioredis import Redis
from app.models.user_model import User
from app.schemas.common_schema import TokenType


async def add_token_to_redis(
redis_client: Redis,
user: User,
token: str,
token_type: TokenType,
expire_time: Optional[int] = None,
):
token_key = f"user:{user.id}:{token_type}"
valid_tokens = await get_valid_tokens(redis_client,user.id,token_type)
await redis_client.sadd(token_key, token)
if not valid_tokens:
await redis_client.expire(token_key, expire_time)


async def get_valid_tokens(redis_client: Redis, user_id: UUID, token_type: TokenType):
token_key = f"user:{user_id}:{token_type}"
valid_tokens = await redis_client.smembers(token_key)
return valid_tokens


async def delete_tokens(redis_client: Redis, user: User, token_type: TokenType):
token_key = f"user:{user.id}:{token_type}"
valid_tokens = await redis_client.smembers(token_key)
if valid_tokens is not None:
await redis_client.delete(token_key)

0 comments on commit b9ff90c

Please sign in to comment.