Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging # allow-direct-logging
import os
import warnings

import pytest


def pytest_sessionstart(session) -> None:
if "LLAMA_STACK_LOGGING" not in os.environ:
Expand All @@ -17,4 +20,10 @@ def pytest_sessionstart(session) -> None:
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)


@pytest.fixture(autouse=True)
def suppress_httpx_logs(caplog):
"""Suppress httpx INFO logs for all unit tests"""
caplog.set_level(logging.WARNING, logger="httpx")


pytest_plugins = ["tests.unit.fixtures"]
32 changes: 21 additions & 11 deletions tests/unit/server/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import base64
import json
import logging # allow-direct-logging
from unittest.mock import AsyncMock, Mock, patch

import pytest
Expand All @@ -27,6 +28,13 @@
)


@pytest.fixture
def suppress_auth_errors(caplog):
"""Suppress expected ERROR/WARNING logs for tests that deliberately trigger authentication errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")


class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
Expand Down Expand Up @@ -237,20 +245,20 @@ def test_valid_http_authentication(http_client, valid_api_key):


@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_http_authentication(http_client, invalid_api_key):
def test_invalid_http_authentication(http_client, invalid_api_key, suppress_auth_errors):
response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Authentication failed" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_http_auth_service_error(http_client, valid_api_key):
def test_http_auth_service_error(http_client, valid_api_key, suppress_auth_errors):
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
assert response.status_code == 401
assert "Authentication service error" in response.json()["error"]["message"]


def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint, suppress_auth_errors):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(200, {"message": "Authentication successful"})
mock_post.return_value = mock_response
Expand Down Expand Up @@ -420,7 +428,7 @@ def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_u


@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid JWT token" in response.json()["error"]["message"]
Expand Down Expand Up @@ -465,7 +473,7 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):


@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 401

Expand Down Expand Up @@ -726,21 +734,21 @@ def test_valid_introspection_authentication(introspection_client, valid_api_key)


@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
def test_inactive_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token not active" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
def test_invalid_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Not JSON" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
def test_failed_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token introspection failed: 500" in response.json()["error"]["message"]
Expand Down Expand Up @@ -957,20 +965,22 @@ def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_toke


@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token, suppress_auth_errors):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid token" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token, suppress_auth_errors):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
assert response.status_code == 401
assert "Token validation failed" in response.json()["error"]["message"]


def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
def test_kubernetes_auth_request_payload(
kubernetes_auth_client, valid_token, mock_kubernetes_api_server, suppress_auth_errors
):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(
200,
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/server/test_auth_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging # allow-direct-logging
from unittest.mock import AsyncMock, patch

import httpx
Expand All @@ -15,6 +16,13 @@
from llama_stack.core.server.auth import AuthenticationMiddleware


@pytest.fixture
def suppress_auth_errors(caplog):
"""Suppress expected ERROR logs for tests that deliberately trigger authentication errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")


class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
Expand Down Expand Up @@ -119,7 +127,7 @@ def test_authenticated_endpoint_with_valid_github_token(mock_client_class, githu


@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client, suppress_auth_errors):
"""Test accessing protected endpoint with invalid GitHub token"""
# Mock the GitHub API to return 401 Unauthorized
mock_client = AsyncMock()
Expand Down
15 changes: 11 additions & 4 deletions tests/unit/server/test_quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging # allow-direct-logging
from uuid import uuid4

import pytest
Expand All @@ -17,6 +18,12 @@
from llama_stack.providers.utils.kvstore import register_kvstore_backends


@pytest.fixture
def suppress_quota_warnings(caplog):
"""Suppress expected WARNING logs for SQLite backend and quota exceeded"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.quota")


class InjectClientIDMiddleware(BaseHTTPMiddleware):
"""
Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
Expand Down Expand Up @@ -70,13 +77,13 @@ async def test_endpoint():
return app


def test_authenticated_quota_allows_up_to_limit(auth_app):
def test_authenticated_quota_allows_up_to_limit(auth_app, suppress_quota_warnings):
client = TestClient(auth_app)
assert client.get("/test").status_code == 200
assert client.get("/test").status_code == 200


def test_authenticated_quota_blocks_after_limit(auth_app):
def test_authenticated_quota_blocks_after_limit(auth_app, suppress_quota_warnings):
client = TestClient(auth_app)
client.get("/test")
client.get("/test")
Expand All @@ -85,7 +92,7 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
assert resp.json()["error"]["message"] == "Quota exceeded"


def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
def test_anonymous_quota_allows_up_to_limit(tmp_path, request, suppress_quota_warnings):
inner_app = FastAPI()

@inner_app.get("/test")
Expand All @@ -107,7 +114,7 @@ async def test_endpoint():
assert client.get("/test").status_code == 200


def test_anonymous_quota_blocks_after_limit(tmp_path, request):
def test_anonymous_quota_blocks_after_limit(tmp_path, request, suppress_quota_warnings):
inner_app = FastAPI()

@inner_app.get("/test")
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/server/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
# the root directory of this source tree.

import asyncio
import logging # allow-direct-logging
from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator


@pytest.fixture
def suppress_sse_errors(caplog):
"""Suppress expected ERROR logs for tests that deliberately trigger SSE errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.server")


async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
Expand Down Expand Up @@ -70,7 +79,7 @@ async def async_event_gen():
assert len(seen_events) == 0


async def test_sse_generator_error_before_response_starts():
async def test_sse_generator_error_before_response_starts(suppress_sse_errors):
# Raise an error before the response starts
async def async_event_gen():
raise Exception("Test error")
Expand Down
Loading