Skip to content

Commit

Permalink
remove passlib and add documents
Browse files Browse the repository at this point in the history
issue : pyca/bcrypt#684
passlib seems to be unmaintained!
so i rewrite the hash and verify password my self with pure bcrypt.
  • Loading branch information
houshmand-2005 committed Feb 26, 2024
1 parent 50d129b commit 15b8334
Showing 1 changed file with 120 additions and 32 deletions.
152 changes: 120 additions & 32 deletions backend/chat/utils/jwt.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,100 @@
from datetime import datetime, timedelta
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
import bcrypt
from chat import models
from chat.database import get_db
from chat.schema import TokenData, User
from chat.setting import setting
from chat.utils.exception import CredentialsException
from chat.database import SessionLocal, engine
import chat.models as models
from chat.schema import TokenData, User
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

models.Base.metadata.create_all(bind=engine)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def get_db() -> Session:
db = SessionLocal()
try:
yield db
finally:
db.close()
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verifies if the provided plain text password matches the stored hashed password.
Args:
plain_password: The plain text password entered by the user.
hashed_password: The stored hashed password from the database.
Returns:
True if the passwords match, False otherwise.
"""
encoded_hashed_password = hashed_password.encode("utf-8")
return bcrypt.checkpw(
plain_password.encode("utf-8"),
encoded_hashed_password,
)

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_password_hash(password: str) -> str:
"""
Generates a bcrypt hash for the provided password.
Args:
password: The plain text password to hash.
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
Returns:
The generated password hash.
"""
hashed_bytes = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())
return hashed_bytes.decode("utf-8")


def get_password_hash(password):
return pwd_context.hash(password)
def get_user(user_db: Session, username: str) -> models.User:
"""
Retrieve a user from the database based on the username.
Args:
user_db (Session): The database session.
username (str): The username of the user to retrieve.
def get_user(user_db: Session, username: str):
Returns:
Optional[models.User]: The user object if found, None otherwise.
"""
user = user_db.query(models.User).filter(models.User.username == username).first()
return user


def authenticate_user(user_db: Session, username: str, password: str):
def authenticate_user(
user_db: Session, username: str, password: str
) -> models.User | None:
"""
Authenticate a user based on the provided username and password.
Args:
user_db (Session): The database session.
username (str): The username of the user to authenticate.
password (str): The password of the user to authenticate.
Returns:
Optional[models.User]: The authenticated user object if successful, None otherwise.
"""
user = get_user(user_db, username)
if not user:
return False
return None
if not verify_password(password, user.password):
return False
return None
return user


def create_access_token(data: dict, expires_delta: timedelta | None = None):
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""
Create an access token with the provided data.
Args:
data (dict): The data to include in the token payload.
expires_delta (timedelta, optional): The expiration time delta for the token. Defaults to None.
Returns:
str: The generated access token.
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
Expand All @@ -71,7 +115,17 @@ async def get_current_user(
user_db: Session = Depends(
get_db,
),
):
) -> User:
"""
Get the current authenticated user from the provided token.
Args:
token (str): The JWT token representing the user.
user_db (Session, optional): The database session. Defaults to Depends(get_db).
Returns:
models.User: The current authenticated user.
"""
token_data = decode_jwt(token)
user = get_user(user_db, username=token_data.username)
if user is None:
Expand All @@ -80,14 +134,37 @@ async def get_current_user(


async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
"""
Get the current active authenticated user.
Args:
current_user (User): The current authenticated user.
Raises:
HTTPException: If the user is inactive.
Returns:
models.User: The current active authenticated user.
"""
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
raise HTTPException(
status_code=400, detail="Inactive user"
) # TODO Add this to Exceptions
return current_user


def get_admin_payload(token: str):
def get_admin_payload(token: str) -> dict | None:
"""
Decode the payload of the provided JWT token for admin user.
Args:
token (str): The JWT token to decode.
Returns:
Optional[dict]: The payload data containing username and id if the token is valid, None otherwise.
"""
try:
payload = jwt.decode(token, setting.SECRET_KEY, setting.ALGORITHM)
username: str = payload.get("username")
Expand All @@ -97,7 +174,18 @@ def get_admin_payload(token: str):
return


def decode_jwt(token: Annotated[str, Depends(oauth2_scheme)]) -> TokenData:
def decode_jwt(
token: Annotated[str, Depends(oauth2_scheme)]
) -> TokenData | CredentialsException:
"""
Decode the provided JWT token and extract the token data.
Args:
token (str): The JWT token to decode.
Returns:
TokenData: The token data containing username and id.
"""
try:
payload = jwt.decode(
token,
Expand Down

0 comments on commit 15b8334

Please sign in to comment.