Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/app/endpoints/authorized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Handler for REST API call to authorized endpoint."""

import asyncio
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove unused asyncio import

The asyncio import is only used for the problematic asyncio.run() call. After fixing the async/sync pattern, this import should be removed.

-import asyncio
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import asyncio
🤖 Prompt for AI Agents
In src/app/endpoints/authorized.py at line 3, remove the unused import of
asyncio since it is only needed for the asyncio.run() call, which will be
removed when fixing the async/sync pattern. Delete the line importing asyncio to
clean up unused imports.

import logging
from typing import Any

from fastapi import APIRouter, Request

from auth import get_auth_dependency
from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse

logger = logging.getLogger(__name__)
router = APIRouter(tags=["authorized"])
auth_dependency = get_auth_dependency()


authorized_responses: dict[int | str, dict[str, Any]] = {
200: {
"description": "The user is logged-in and authorized to access OLS",
"model": AuthorizedResponse,
},
400: {
"description": "Missing or invalid credentials provided by client",
"model": UnauthorizedResponse,
},
403: {
"description": "User is not authorized",
"model": ForbiddenResponse,
},
}


@router.post("/authorized", responses=authorized_responses)
def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
"""Handle request to the /authorized endpoint."""
# Ignore the user token, we should not return it in the response
user_id, user_name, _ = asyncio.run(auth_dependency(_request))
return AuthorizedResponse(user_id=user_id, username=user_name)
Comment on lines +33 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Fix async/sync pattern and use proper dependency injection

The current implementation has several issues:

  1. Async/sync mismatch: Using asyncio.run() in a sync handler can cause problems in FastAPI's async context
  2. Bypassing dependency injection: The auth dependency should be injected using FastAPI's Depends() mechanism
  3. Error handling: No explicit error handling for authentication failures

Apply this diff to fix the implementation:

-@router.post("/authorized", responses=authorized_responses)
-def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
-    """Handle request to the /authorized endpoint."""
-    # Ignore the user token, we should not return it in the response
-    user_id, user_name, _ = asyncio.run(auth_dependency(_request))
-    return AuthorizedResponse(user_id=user_id, username=user_name)
+@router.post("/authorized", responses=authorized_responses)
+async def authorized_endpoint_handler(
+    auth_result: tuple[str, str, str] = Depends(auth_dependency)
+) -> AuthorizedResponse:
+    """Handle request to the /authorized endpoint."""
+    # Ignore the user token, we should not return it in the response
+    user_id, user_name, _ = auth_result
+    return AuthorizedResponse(user_id=user_id, username=user_name)

This approach:

  • Makes the handler async to properly work with FastAPI
  • Uses FastAPI's dependency injection system
  • Allows FastAPI to handle authentication errors automatically
  • Eliminates the problematic asyncio.run() call

Let me verify how the auth dependency is typically used in other endpoints:


🏁 Script executed:

#!/bin/bash
# Search for other usages of auth_dependency in the codebase
rg -A 5 -B 5 "auth_dependency" --type py

Length of output: 19722


🏁 Script executed:

#!/bin/bash
rg -A 20 -B 5 "def get_auth_dependency" src/auth/__init__.py

Length of output: 901


🏁 Script executed:

#!/bin/bash
rg -A 20 -B 0 "def test_authorized_endpoint" -n tests/unit/app/endpoints/test_authorized.py

Length of output: 796


🏁 Script executed:

#!/bin/bash
rg -A 5 -B 2 "class AuthInterface" --type py

Length of output: 531


Use FastAPI’s dependency injection instead of asyncio.run and update tests

The /authorized handler should let FastAPI call the async auth dependency rather than spinning up a new event loop. Please:

• In src/app/endpoints/authorized.py:

  • Remove the Request parameter and asyncio.run(...) call
  • Inject the dependency with Depends(auth_dependency)
  • Mark the handler async def (so FastAPI can await the auth)
  • Import Depends from fastapi

• In your unit tests (tests/unit/app/endpoints/test_authorized.py):

  • Stop calling the handler directly with a raw Request
  • Use FastAPI’s TestClient (or AsyncClient) and app.dependency_overrides to stub out auth_dependency
  • Assert on the full HTTP response instead of the handler return value

Example diff for the endpoint:

--- a/src/app/endpoints/authorized.py
+++ b/src/app/endpoints/authorized.py
@@
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Depends

@@
-@router.post("/authorized", responses=authorized_responses)
-def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
-    """Handle request to the /authorized endpoint."""
-    # Ignore the user token, we should not return it in the response
-    user_id, user_name, _ = asyncio.run(auth_dependency(_request))
-    return AuthorizedResponse(user_id=user_id, username=user_name)
+@router.post("/authorized", responses=authorized_responses)
+async def authorized_endpoint_handler(
+    auth_result: tuple[str, str, str] = Depends(auth_dependency)
+) -> AuthorizedResponse:
+    """Handle request to the /authorized endpoint."""
+    user_id, user_name, _ = auth_result
+    return AuthorizedResponse(user_id=user_id, username=user_name)

And a sketch for updating the test:

from fastapi.testclient import TestClient
from app.main import app
from app.endpoints.authorized import auth_dependency

def test_authorized_endpoint():
    app.dependency_overrides[auth_dependency] = lambda req: ("test-id", "test-user", None)
    client = TestClient(app)
    response = client.post("/authorized")
    assert response.status_code == 200
    assert response.json() == {"user_id": "test-id", "username": "test-user"}
🤖 Prompt for AI Agents
In src/app/endpoints/authorized.py lines 33-38, replace the synchronous handler
using asyncio.run with an async def function that injects auth_dependency via
FastAPI's Depends. Remove the Request parameter and import Depends from fastapi.
Then, in tests/unit/app/endpoints/test_authorized.py, stop calling the handler
directly with a raw Request; instead, use FastAPI's TestClient and override
auth_dependency with app.dependency_overrides to provide a stub. Finally, make
assertions on the full HTTP response from the client.post call rather than the
handler's return value.

15 changes: 14 additions & 1 deletion src/app/endpoints/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

from auth import get_auth_dependency
from configuration import configuration
from models.responses import FeedbackResponse, StatusResponse
from models.responses import (
FeedbackResponse,
StatusResponse,
UnauthorizedResponse,
ForbiddenResponse,
)
from models.requests import FeedbackRequest
from utils.suid import get_suid
from utils.common import retrieve_user_id
Expand All @@ -22,6 +27,14 @@
# Response for the feedback endpoint
feedback_response: dict[int | str, dict[str, Any]] = {
200: {"response": "Feedback received and stored"},
400: {
"description": "Missing or invalid credentials provided by client",
"model": UnauthorizedResponse,
},
403: {
"description": "User is not authorized",
"model": ForbiddenResponse,
},
}


Expand Down
10 changes: 9 additions & 1 deletion src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from client import LlamaStackClientHolder
from configuration import configuration
from models.responses import QueryResponse
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.requests import QueryRequest, Attachment
import constants
from auth import get_auth_dependency
Expand All @@ -44,6 +44,14 @@
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM ansert",
},
400: {
"description": "Missing or invalid credentials provided by client",
"model": UnauthorizedResponse,
},
403: {
"description": "User is not authorized",
"model": ForbiddenResponse,
},
503: {
"detail": {
"response": "Unable to connect to Llama Stack",
Expand Down
2 changes: 2 additions & 0 deletions src/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
config,
feedback,
streaming_query,
authorized,
)


Expand All @@ -28,3 +29,4 @@ def include_routers(app: FastAPI) -> None:
app.include_router(config.router, prefix="/v1")
app.include_router(feedback.router, prefix="/v1")
app.include_router(streaming_query.router, prefix="/v1")
app.include_router(authorized.router, prefix="/v1")
56 changes: 56 additions & 0 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,59 @@ class StatusResponse(BaseModel):
]
}
}


class AuthorizedResponse(BaseModel):
"""Model representing a response to an authorization request.

Attributes:
user_id: The ID of the logged in user.
username: The name of the logged in user.
"""

user_id: str
username: str

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"user_id": "123e4567-e89b-12d3-a456-426614174000",
"username": "user1",
}
]
}
}


class UnauthorizedResponse(BaseModel):
"""Model representing response for missing or invalid credentials."""

detail: str

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"detail": "Unauthorized: No auth header found",
},
]
}
}


class ForbiddenResponse(UnauthorizedResponse):
"""Model representing response for forbidden access."""

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"detail": "Forbidden: User is not authorized to access this resource",
},
]
}
}
54 changes: 54 additions & 0 deletions tests/unit/app/endpoints/test_authorized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from unittest.mock import AsyncMock

import pytest
from fastapi import Request, HTTPException

from app.endpoints.authorized import authorized_endpoint_handler


def test_authorized_endpoint(mocker):
"""Test the authorized endpoint handler."""
# Mock the auth dependency to return a user ID and username
auth_dependency_mock = AsyncMock()
auth_dependency_mock.return_value = ("test-id", "test-user", None)
mocker.patch(
"app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock
)

request = Request(
scope={
"type": "http",
"query_string": b"",
}
)

response = authorized_endpoint_handler(request)

assert response.model_dump() == {
"user_id": "test-id",
"username": "test-user",
}


def test_authorized_unauthorized(mocker):
"""Test the authorized endpoint handler with a custom user ID."""
auth_dependency_mock = AsyncMock()
auth_dependency_mock.side_effect = HTTPException(
status_code=403, detail="User is not authorized"
)
mocker.patch(
"app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock
)

request = Request(
scope={
"type": "http",
"query_string": b"",
}
)

with pytest.raises(HTTPException) as exc_info:
authorized_endpoint_handler(request)

assert exc_info.value.status_code == 403
assert exc_info.value.detail == "User is not authorized"
4 changes: 3 additions & 1 deletion tests/unit/app/test_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
config,
feedback,
streaming_query,
authorized,
) # noqa:E402


Expand All @@ -34,7 +35,7 @@ def test_include_routers() -> None:
include_routers(app)

# are all routers added?
assert len(app.routers) == 8
assert len(app.routers) == 9
assert root.router in app.routers
assert info.router in app.routers
assert models.router in app.routers
Expand All @@ -43,3 +44,4 @@ def test_include_routers() -> None:
assert config.router in app.routers
assert feedback.router in app.routers
assert streaming_query.router in app.routers
assert authorized.router in app.routers
31 changes: 30 additions & 1 deletion tests/unit/models/test_responses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from models.responses import QueryResponse, StatusResponse
from models.responses import (
QueryResponse,
StatusResponse,
AuthorizedResponse,
UnauthorizedResponse,
)


class TestQueryResponse:
Expand Down Expand Up @@ -28,3 +33,27 @@ def test_constructor(self) -> None:
sr = StatusResponse(functionality="feedback", status={"enabled": True})
assert sr.functionality == "feedback"
assert sr.status == {"enabled": True}


class TestAuthorizedResponse:
"""Test cases for the AuthorizedResponse model."""

def test_constructor(self) -> None:
"""Test the AuthorizedResponse constructor."""
ar = AuthorizedResponse(
user_id="123e4567-e89b-12d3-a456-426614174000",
username="testuser",
)
assert ar.user_id == "123e4567-e89b-12d3-a456-426614174000"
assert ar.username == "testuser"


class TestUnauthorizedResponse:
"""Test cases for the UnauthorizedResponse model."""

def test_constructor(self) -> None:
"""Test the UnauthorizedResponse constructor."""
ur = UnauthorizedResponse(
detail="Missing or invalid credentials provided by client"
)
assert ur.detail == "Missing or invalid credentials provided by client"