Skip to content

Commit

Permalink
add refresh token logic (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianjing-li committed Jun 19, 2024
1 parent a2ec56a commit 9d7564c
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 7 deletions.
9 changes: 7 additions & 2 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def get_strategies() -> list[ListAuthStrategy]:
if hasattr(strategy_instance, "get_authorization_endpoint")
else None
),
"refresh_token_params": (
strategy_instance.get_refresh_token_params()
if hasattr(strategy_instance, "get_refresh_token_params")
else None
),
}
)

Expand Down Expand Up @@ -94,7 +99,7 @@ async def login(request: Request, login: Login, session: DBSessionDep):
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
)

token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, strategy_name)

return {"token": token}

Expand Down Expand Up @@ -188,6 +193,6 @@ async def authorize(
# Get or create user, then set session user
user = get_or_create_user(session, userinfo)

token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, strategy_name)

return {"token": token}
3 changes: 2 additions & 1 deletion src/backend/services/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self):

self.secret_key = secret_key

def create_and_encode_jwt(self, user: dict) -> str:
def create_and_encode_jwt(self, user: dict, strategy_name: str) -> str:
"""
Creates a payload based on user info and creates a JWT token.
Expand All @@ -41,6 +41,7 @@ def create_and_encode_jwt(self, user: dict) -> str:
"iat": now,
"exp": now + datetime.timedelta(hours=self.EXPIRY_HOURS),
"jti": str(uuid.uuid4()),
"strategy": strategy_name,
"context": user,
}

Expand Down
8 changes: 8 additions & 0 deletions src/backend/services/auth/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def get_authorization_endpoint(self, **kwargs: Any):
"""
...

@abstractmethod
def get_refresh_token_params(self, **kwargs: Any):
"""
Retrieves the OAuth app's refresh token query parameters,
returned in dict format.
"""
...

@abstractmethod
async def get_endpoints(self, **kwargs: Any):
"""
Expand Down
7 changes: 7 additions & 0 deletions src/backend/services/auth/strategies/google_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def get_client_id(self):
def get_authorization_endpoint(self):
return self.AUTHORIZATION_ENDPOINT

def get_refresh_token_params(self):
return {"access_type": "offline", "prompt": "consent"}

async def get_endpoints(self):
response = requests.get(self.WELL_KNOWN_ENDPOINT)
endpoints = response.json()
Expand Down Expand Up @@ -68,6 +71,10 @@ async def authorize(self, request: Request) -> dict | None:
authorization_response=str(request.url),
redirect_uri=self.REDIRECT_URI,
)

import pdb

pdb.set_trace()
user_info = self.client.get(self.USERINFO_ENDPOINT)

return user_info.json()
3 changes: 3 additions & 0 deletions src/backend/services/auth/strategies/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def get_client_id(self):
def get_authorization_endpoint(self):
return self.AUTHORIZATION_ENDPOINT

def get_refresh_token_params(self):
return None

async def get_endpoints(self):
response = requests.get(self.WELL_KNOWN_ENDPOINT)
endpoints = response.json()
Expand Down
6 changes: 3 additions & 3 deletions src/backend/tests/routers/auth/test_authorization_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_validate_authorization_valid_token(
session_client: TestClient,
):
user = {"user_id": "test"}
token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, "")

# Use /logout endpoint to test request validator
response = session_client.get(
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_validate_authorization_invalid_token():
def test_validate_authorization_expired_token():
user = {"user_id": "test"}
with freezegun.freeze_time("2024-01-01 00:00:00"):
token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, "")

request_mock = MagicMock(headers={"Authorization": f"Bearer {token}"})

Expand All @@ -85,7 +85,7 @@ def test_validate_authorization_blacklisted_token(
session_client: TestClient, session: Session
):
user = {"user_id": "test"}
token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, "")
decoded = JWTService().decode_jwt(token)

# Create a Blacklist entry
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/routers/auth/test_basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_login_no_payload(session_client: TestClient):

def test_logout_success(session_client: TestClient, session: Session):
user = {"user_id": "test"}
token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, "")
decoded = JWTService().decode_jwt(token)

response = session_client.get(
Expand Down

0 comments on commit 9d7564c

Please sign in to comment.