diff --git a/app/auth/token_store.py b/app/auth/token_store.py index c02b852..adb6303 100644 --- a/app/auth/token_store.py +++ b/app/auth/token_store.py @@ -21,8 +21,10 @@ async def add(token_id: str, sibling_token_id: str, ttl: int) -> bool: Returns: bool: True if the token was successfully added to the store, False otherwise. """ + print(f"Adding token '{token_id}' to store, value '{sibling_token_id}'") return await cache.add(key=token_id, value=sibling_token_id, ttl=ttl) + async def exists(token_id: str) -> bool: """ Checks if a token exists in the cache store. @@ -58,4 +60,5 @@ async def remove(token_id: str) -> bool: Returns: bool: True if the token was successfully removed from the store, False otherwise. """ + print(f"Removing token '{token_id}' from token store") return await cache.delete(key=token_id) diff --git a/app/token.py b/app/token.py index 21b5b26..3165b5c 100644 --- a/app/token.py +++ b/app/token.py @@ -45,18 +45,17 @@ async def get_token_payload(token: Annotated[str, Depends(oauth2_scheme)]) -> Di session_id = payload.get("sid") if not await session_store.exists(user_id=user_id, session_id=session_id): - raise HTTPException(status_code=status_code, - detail="Session was revoked or expired") + raise HTTPException(status_code=status_code, detail="Session was revoked or expired") # check if the token is in active token cache token_id = payload.get("jti") if not await token_store.exists(token_id=token_id): # for security measures, revoke token's session as well + print(f"Token '{token_id}' was revoked or expired, removing session '{session_id}' as security measure") await session_store.remove(user_id=user_id, session_id=session_id) - raise HTTPException(status_code=status_code, - detail="Token was revoked or expired") + raise HTTPException(status_code=status_code, detail="Token was revoked or expired") return payload diff --git a/tests/auth/test_session_store.py b/tests/auth/test_session_store.py index d8c7a34..a18ca8f 100644 --- a/tests/auth/test_session_store.py +++ b/tests/auth/test_session_store.py @@ -4,7 +4,8 @@ import random from typing import Optional from datetime import datetime, timedelta -from app.auth.session_store import add, exists, update_last_activity, remove, retrieve_by_userid, retrieve +from app.auth import session_store +# from app.auth.session_store import add, exists, update_last_activity, remove, retrieve_by_userid, retrieve from app.auth.schemas import SessionInfo user_id = "user-" + str(uuid.uuid4())[5:] @@ -58,7 +59,7 @@ async def create_session_in_cache( value = SessionInfo(**session_info_dict) # Add the session to the cache and return the result - return await add(user_id, session_id, value, ttl) + return await session_store.add(user_id, session_id, value, ttl) @pytest.mark.asyncio async def test_add(): @@ -83,7 +84,7 @@ async def test_exists(): The expected result is True since we added the session in the previous test. """ # Call the function with a user_id and session_id - result = await exists(user_id, session_id) + result = await session_store.exists(user_id, session_id) # Assert that the result is True, indicating the session exists assert result is True @@ -103,21 +104,21 @@ async def test_expiration(): await create_session_in_cache(user_id, test_session_id, ttl=3) # check is the session exists - result = await exists(user_id, test_session_id) + result = await session_store.exists(user_id, test_session_id) # Assert that the result is True, indicating the session exists assert result is True # Wait 2s await asyncio.sleep(2) # check is the session exists - result = await exists(user_id, test_session_id) + result = await session_store.exists(user_id, test_session_id) # Assert that the result is True, indicating the session still exists assert result is True # Wait for the session to expire - await asyncio.sleep(5) + await asyncio.sleep(1) # check is the session exists - result = await exists(user_id, test_session_id) + result = await session_store.exists(user_id, test_session_id) # Assert that the result is False, indicating the session has expired assert result is False @@ -131,7 +132,7 @@ async def test_update_last_activity(): The expected result is True since we are updating the session. """ # Check if the session exists, if not create a new session - if not await exists(user_id, session_id): + if not await session_store.exists(user_id, session_id): print("Session not exists, creating session in cache...") await create_session_in_cache(user_id=user_id, session_id=session_id) @@ -139,7 +140,7 @@ async def test_update_last_activity(): token_payload = {"sub": user_id, "sid": session_id} # update existing session last activity - result = await update_last_activity(token_payload) + result = await session_store.update_last_activity(token_payload) # Assert that the result is True, indicating the session was successfully updated assert result is True @@ -148,7 +149,7 @@ async def test_update_last_activity(): token_payload = {"sub": user_id, "sid": "session-" + str(uuid.uuid4())[8:]} # update non-existing session last activity - result = await update_last_activity(token_payload) + result = await session_store.update_last_activity(token_payload) # Assert that the result is False, indicating the session was not updated assert result is False @@ -162,14 +163,14 @@ async def test_remove(): The expected result is 1 since we are removing one session. """ # Call the function with a user_id and session_id - result = await remove(user_id, session_id) + result = await session_store.remove(user_id, session_id) # Assert that the result is 1, indicating one session was successfully removed assert result == 1 # expect raise ValueError because no user_id is provided with pytest.raises(ValueError): - await remove("", session_id) + await session_store.remove("", session_id) @pytest.mark.asyncio async def test_remove_all_user_sessions(): @@ -188,7 +189,7 @@ async def test_remove_all_user_sessions(): await create_session_in_cache(user_id) # Call the remove function to delete all sessions of the user - result = await remove(user_id) + result = await session_store.remove(user_id) # Assert that the result is greater than 2, indicating that more than one session was removed assert result >= 2 @@ -206,13 +207,13 @@ async def test_retrieve_by_userid(): session_id2 = "session-" + str(uuid.uuid4())[8:] # Check if the sessions exist, if not create new sessions - if not await exists(user_id, session_id): + if not await session_store.exists(user_id, session_id): print("Session not exists, creating 2 sessions in cache...") await create_session_in_cache(user_id=user_id, session_id=session_id) await create_session_in_cache(user_id=user_id, session_id=session_id2) # Call the function with the user_id - sessions = await retrieve_by_userid(user_id) + sessions = await session_store.retrieve_by_userid(user_id) # Assert that the sessions dictionary is not empty assert sessions @@ -237,12 +238,12 @@ async def test_retrieve(): It first ensures the session exists, then retrieves the session and checks the result. """ # Check if the session exists, if not create a new session - if not await exists(user_id, session_id): + if not await session_store.exists(user_id, session_id): print("Session not exists, creating session in cache...") await create_session_in_cache(user_id=user_id, session_id=session_id) # Call the function with the user_id and session_id - session = await retrieve(user_id, session_id) + session = await session_store.retrieve(user_id, session_id) # print(session) # Assert that the session is not None @@ -253,3 +254,43 @@ async def test_retrieve(): # Assert that the session contains the correct session_id assert session.session_id == session_id + +@pytest.mark.asyncio +async def test_update(): + """ + Test the update function. + + This test checks if a session can be successfully updated in the cache. + It first ensures the session exists, then updates the session and checks the result. + """ + # Check if the session exists, if not create a new session + if not await session_store.exists(user_id, session_id): + print("Session not exists, creating session in cache...") + await create_session_in_cache(user_id=user_id, session_id=session_id) + + # Generate a new user agent for testing + new_user_agent = "Edge/12.0.0" + + # Retrieve the session from the cache + session: SessionInfo = await session_store.retrieve(user_id, session_id) + + # Update the user agent + session.user_agent = new_user_agent + + # Update the session in the cache + result = await session_store.update(user_id, session_id, session, 5) + + # Assert that the result is True, indicating the session was successfully updated + assert result is True + + # Retrieve the session from the cache + session: SessionInfo = await session_store.retrieve(user_id, session_id) + + # Assert that the user agent has been updated + assert session.user_agent == new_user_agent + + # Wait for the session to expire + await asyncio.sleep(5) + result = await session_store.exists(user_id, session_id) + # Assert that the result is False, indicating the session has expired + assert result is False