Skip to content

Commit

Permalink
Refactor session_store imports and update tests
Browse files Browse the repository at this point in the history
Added test_update
  • Loading branch information
Dennis Lee committed Mar 6, 2024
1 parent 4d68012 commit e3f29f0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 21 deletions.
3 changes: 3 additions & 0 deletions app/auth/token_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ async def add(token_id: str, sibling_token_id: str, ttl: int) -> bool:
Returns:
bool: True if the token was successfully added to the store, False otherwise.
"""
print(f"Adding token '{token_id}' to store, value '{sibling_token_id}'")
return await cache.add(key=token_id, value=sibling_token_id, ttl=ttl)


async def exists(token_id: str) -> bool:
"""
Checks if a token exists in the cache store.
Expand Down Expand Up @@ -58,4 +60,5 @@ async def remove(token_id: str) -> bool:
Returns:
bool: True if the token was successfully removed from the store, False otherwise.
"""
print(f"Removing token '{token_id}' from token store")
return await cache.delete(key=token_id)
7 changes: 3 additions & 4 deletions app/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@ async def get_token_payload(token: Annotated[str, Depends(oauth2_scheme)]) -> Di
session_id = payload.get("sid")

if not await session_store.exists(user_id=user_id, session_id=session_id):
raise HTTPException(status_code=status_code,
detail="Session was revoked or expired")
raise HTTPException(status_code=status_code, detail="Session was revoked or expired")

# check if the token is in active token cache
token_id = payload.get("jti")

if not await token_store.exists(token_id=token_id):
# for security measures, revoke token's session as well
print(f"Token '{token_id}' was revoked or expired, removing session '{session_id}' as security measure")
await session_store.remove(user_id=user_id, session_id=session_id)

raise HTTPException(status_code=status_code,
detail="Token was revoked or expired")
raise HTTPException(status_code=status_code, detail="Token was revoked or expired")

return payload

Expand Down
75 changes: 58 additions & 17 deletions tests/auth/test_session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import random
from typing import Optional
from datetime import datetime, timedelta
from app.auth.session_store import add, exists, update_last_activity, remove, retrieve_by_userid, retrieve
from app.auth import session_store
# from app.auth.session_store import add, exists, update_last_activity, remove, retrieve_by_userid, retrieve
from app.auth.schemas import SessionInfo

user_id = "user-" + str(uuid.uuid4())[5:]
Expand Down Expand Up @@ -58,7 +59,7 @@ async def create_session_in_cache(
value = SessionInfo(**session_info_dict)

# Add the session to the cache and return the result
return await add(user_id, session_id, value, ttl)
return await session_store.add(user_id, session_id, value, ttl)

@pytest.mark.asyncio
async def test_add():
Expand All @@ -83,7 +84,7 @@ async def test_exists():
The expected result is True since we added the session in the previous test.
"""
# Call the function with a user_id and session_id
result = await exists(user_id, session_id)
result = await session_store.exists(user_id, session_id)

# Assert that the result is True, indicating the session exists
assert result is True
Expand All @@ -103,21 +104,21 @@ async def test_expiration():
await create_session_in_cache(user_id, test_session_id, ttl=3)

# check is the session exists
result = await exists(user_id, test_session_id)
result = await session_store.exists(user_id, test_session_id)
# Assert that the result is True, indicating the session exists
assert result is True

# Wait 2s
await asyncio.sleep(2)
# check is the session exists
result = await exists(user_id, test_session_id)
result = await session_store.exists(user_id, test_session_id)
# Assert that the result is True, indicating the session still exists
assert result is True

# Wait for the session to expire
await asyncio.sleep(5)
await asyncio.sleep(1)
# check is the session exists
result = await exists(user_id, test_session_id)
result = await session_store.exists(user_id, test_session_id)
# Assert that the result is False, indicating the session has expired
assert result is False

Expand All @@ -131,15 +132,15 @@ async def test_update_last_activity():
The expected result is True since we are updating the session.
"""
# Check if the session exists, if not create a new session
if not await exists(user_id, session_id):
if not await session_store.exists(user_id, session_id):
print("Session not exists, creating session in cache...")
await create_session_in_cache(user_id=user_id, session_id=session_id)

# Create a valid token payload
token_payload = {"sub": user_id, "sid": session_id}

# update existing session last activity
result = await update_last_activity(token_payload)
result = await session_store.update_last_activity(token_payload)

# Assert that the result is True, indicating the session was successfully updated
assert result is True
Expand All @@ -148,7 +149,7 @@ async def test_update_last_activity():
token_payload = {"sub": user_id, "sid": "session-" + str(uuid.uuid4())[8:]}

# update non-existing session last activity
result = await update_last_activity(token_payload)
result = await session_store.update_last_activity(token_payload)

# Assert that the result is False, indicating the session was not updated
assert result is False
Expand All @@ -162,14 +163,14 @@ async def test_remove():
The expected result is 1 since we are removing one session.
"""
# Call the function with a user_id and session_id
result = await remove(user_id, session_id)
result = await session_store.remove(user_id, session_id)

# Assert that the result is 1, indicating one session was successfully removed
assert result == 1

# expect raise ValueError because no user_id is provided
with pytest.raises(ValueError):
await remove("", session_id)
await session_store.remove("", session_id)

@pytest.mark.asyncio
async def test_remove_all_user_sessions():
Expand All @@ -188,7 +189,7 @@ async def test_remove_all_user_sessions():
await create_session_in_cache(user_id)

# Call the remove function to delete all sessions of the user
result = await remove(user_id)
result = await session_store.remove(user_id)

# Assert that the result is greater than 2, indicating that more than one session was removed
assert result >= 2
Expand All @@ -206,13 +207,13 @@ async def test_retrieve_by_userid():
session_id2 = "session-" + str(uuid.uuid4())[8:]

# Check if the sessions exist, if not create new sessions
if not await exists(user_id, session_id):
if not await session_store.exists(user_id, session_id):
print("Session not exists, creating 2 sessions in cache...")
await create_session_in_cache(user_id=user_id, session_id=session_id)
await create_session_in_cache(user_id=user_id, session_id=session_id2)

# Call the function with the user_id
sessions = await retrieve_by_userid(user_id)
sessions = await session_store.retrieve_by_userid(user_id)

# Assert that the sessions dictionary is not empty
assert sessions
Expand All @@ -237,12 +238,12 @@ async def test_retrieve():
It first ensures the session exists, then retrieves the session and checks the result.
"""
# Check if the session exists, if not create a new session
if not await exists(user_id, session_id):
if not await session_store.exists(user_id, session_id):
print("Session not exists, creating session in cache...")
await create_session_in_cache(user_id=user_id, session_id=session_id)

# Call the function with the user_id and session_id
session = await retrieve(user_id, session_id)
session = await session_store.retrieve(user_id, session_id)
# print(session)

# Assert that the session is not None
Expand All @@ -253,3 +254,43 @@ async def test_retrieve():

# Assert that the session contains the correct session_id
assert session.session_id == session_id

@pytest.mark.asyncio
async def test_update():
"""
Test the update function.
This test checks if a session can be successfully updated in the cache.
It first ensures the session exists, then updates the session and checks the result.
"""
# Check if the session exists, if not create a new session
if not await session_store.exists(user_id, session_id):
print("Session not exists, creating session in cache...")
await create_session_in_cache(user_id=user_id, session_id=session_id)

# Generate a new user agent for testing
new_user_agent = "Edge/12.0.0"

# Retrieve the session from the cache
session: SessionInfo = await session_store.retrieve(user_id, session_id)

# Update the user agent
session.user_agent = new_user_agent

# Update the session in the cache
result = await session_store.update(user_id, session_id, session, 5)

# Assert that the result is True, indicating the session was successfully updated
assert result is True

# Retrieve the session from the cache
session: SessionInfo = await session_store.retrieve(user_id, session_id)

# Assert that the user agent has been updated
assert session.user_agent == new_user_agent

# Wait for the session to expire
await asyncio.sleep(5)
result = await session_store.exists(user_id, session_id)
# Assert that the result is False, indicating the session has expired
assert result is False

0 comments on commit e3f29f0

Please sign in to comment.