diff --git a/fastapi-alembic-sqlmodel-async/app/api/deps.py b/fastapi-alembic-sqlmodel-async/app/api/deps.py index bd67bd4..4a9a778 100644 --- a/fastapi-alembic-sqlmodel-async/app/api/deps.py +++ b/fastapi-alembic-sqlmodel-async/app/api/deps.py @@ -1,6 +1,7 @@ from typing import AsyncGenerator, List from uuid import UUID from fastapi import Depends, HTTPException, status +from app.utils.token import get_valid_tokens from app.schemas.user_schema import IUserRead from app.utils.minio_client import MinioClient from app.schemas.user_schema import IUserCreate @@ -59,8 +60,7 @@ async def current_user( detail="Could not validate credentials", ) user_id = payload["sub"] - access_token_key = f"user:{user_id}:{TokenType.ACCESS}" - valid_access_tokens = await redis_client.smembers(access_token_key) + valid_access_tokens = await get_valid_tokens(redis_client, user_id, TokenType.ACCESS) if valid_access_tokens and token not in valid_access_tokens: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/login.py b/fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/login.py index 3f70113..b0eb5a2 100644 --- a/fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/login.py +++ b/fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/login.py @@ -1,6 +1,9 @@ from datetime import timedelta -from typing import Any, Optional +from typing import Any from fastapi import APIRouter, Body, Depends, HTTPException +from app.utils.token import get_valid_tokens +from app.utils.token import delete_tokens +from app.utils.token import add_token_to_redis from app.schemas.common_schema import TokenType from app.core.security import get_password_hash from app.core.security import verify_password @@ -17,31 +20,10 @@ from app.schemas.token_schema import TokenRead, Token, RefreshToken from app.schemas.common_schema import IMetaGeneral, IPostResponseBase, create_response from aioredis import Redis -from enum import Enum router = APIRouter() -async def add_token_to_redis( - redis_client: Redis, - user: User, - token: str, - token_type: TokenType, - expire_time: Optional[int] = None, -): - token_key = f"user:{user.id}:{token_type}" - await redis_client.sadd(token_key, token) - if expire_time: - await redis_client.expire(token_key, expire_time) - - -async def delete_tokens(redis_client: Redis, user: User, token_type: TokenType): - token_key = f"user:{user.id}:{token_type}" - valid_tokens = await redis_client.smembers(token_key) - if valid_tokens is not None: - await redis_client.delete(token_key) - - @router.post("", response_model=IPostResponseBase[Token]) async def login( email: EmailStr = Body(...), @@ -71,20 +53,28 @@ async def login( refresh_token=refresh_token, user=user, ) - await add_token_to_redis( - redis_client, - user, - access_token, - TokenType.ACCESS, - settings.ACCESS_TOKEN_EXPIRE_MINUTES, + valid_access_tokens = await get_valid_tokens( + redis_client, user.id, TokenType.ACCESS ) - await add_token_to_redis( - redis_client, - user, - refresh_token, - TokenType.REFRESH, - settings.REFRESH_TOKEN_EXPIRE_MINUTES, + if valid_access_tokens: + await add_token_to_redis( + redis_client, + user, + access_token, + TokenType.ACCESS, + settings.ACCESS_TOKEN_EXPIRE_MINUTES, + ) + valid_refresh_tokens = await get_valid_tokens( + redis_client, user.id, TokenType.ACCESS ) + if valid_refresh_tokens: + await add_token_to_redis( + redis_client, + user, + refresh_token, + TokenType.REFRESH, + settings.REFRESH_TOKEN_EXPIRE_MINUTES, + ) return create_response(meta=meta_data, data=data, message="Login correctly") @@ -103,9 +93,12 @@ async def change_password( if not verify_password(current_password, current_user.hashed_password): raise HTTPException(status_code=400, detail="Invalid Current Password") - hashed_password = get_password_hash(new_password) + if verify_password(new_password, current_user.hashed_password): + raise HTTPException(status_code=400, detail="New Password should be different that the current one") + + new_hashed_password = get_password_hash(new_password) await crud.user.update( - obj_current=current_user, obj_new={"hashed_password": hashed_password} + obj_current=current_user, obj_new={"hashed_password": new_hashed_password} ) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) @@ -162,8 +155,9 @@ async def get_refresh_token( if payload["type"] == "refresh": user_id = payload["sub"] - refresh_token_key = f"user:{user_id}:{TokenType.REFRESH}" - valid_refresh_tokens = await redis_client.smembers(refresh_token_key) + valid_refresh_tokens = await get_valid_tokens( + redis_client, user_id, TokenType.REFRESH + ) if valid_refresh_tokens and body.refresh_token not in valid_refresh_tokens: raise HTTPException(status_code=403, detail="Refresh token invalid") @@ -173,13 +167,18 @@ async def get_refresh_token( access_token = security.create_access_token( payload["sub"], expires_delta=access_token_expires ) - await add_token_to_redis( - redis_client, - user, - access_token, - TokenType.ACCESS, - settings.ACCESS_TOKEN_EXPIRE_MINUTES, + valid_access_get_valid_tokens = await get_valid_tokens( + redis_client, user.id, TokenType.ACCESS ) + if valid_access_get_valid_tokens: + print("her") + await add_token_to_redis( + redis_client, + user, + access_token, + TokenType.ACCESS, + settings.ACCESS_TOKEN_EXPIRE_MINUTES, + ) return create_response( data=TokenRead(access_token=access_token, token_type="bearer"), message="Access token generated correctly", @@ -209,13 +208,17 @@ async def login_access_token( access_token = security.create_access_token( user.id, expires_delta=access_token_expires ) - await add_token_to_redis( - redis_client, - user, - access_token, - TokenType.ACCESS, - settings.ACCESS_TOKEN_EXPIRE_MINUTES, + valid_access_tokens = await get_valid_tokens( + redis_client, user.id, TokenType.ACCESS ) + if valid_access_tokens: + await add_token_to_redis( + redis_client, + user, + access_token, + TokenType.ACCESS, + settings.ACCESS_TOKEN_EXPIRE_MINUTES, + ) return { "access_token": access_token, "token_type": "bearer", diff --git a/fastapi-alembic-sqlmodel-async/app/utils/token.py b/fastapi-alembic-sqlmodel-async/app/utils/token.py new file mode 100644 index 0000000..40077d0 --- /dev/null +++ b/fastapi-alembic-sqlmodel-async/app/utils/token.py @@ -0,0 +1,32 @@ +from typing import Optional +from uuid import UUID +from aioredis import Redis +from app.models.user_model import User +from app.schemas.common_schema import TokenType + + +async def add_token_to_redis( + redis_client: Redis, + user: User, + token: str, + token_type: TokenType, + expire_time: Optional[int] = None, +): + token_key = f"user:{user.id}:{token_type}" + valid_tokens = await get_valid_tokens(redis_client,user.id,token_type) + await redis_client.sadd(token_key, token) + if not valid_tokens: + await redis_client.expire(token_key, expire_time) + + +async def get_valid_tokens(redis_client: Redis, user_id: UUID, token_type: TokenType): + token_key = f"user:{user_id}:{token_type}" + valid_tokens = await redis_client.smembers(token_key) + return valid_tokens + + +async def delete_tokens(redis_client: Redis, user: User, token_type: TokenType): + token_key = f"user:{user.id}:{token_type}" + valid_tokens = await redis_client.smembers(token_key) + if valid_tokens is not None: + await redis_client.delete(token_key)