Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/llama_stack/core/routers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter

from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import ModelWithOwner
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
Expand Down Expand Up @@ -93,15 +96,41 @@ async def _get_model_provider(self, model_id: str, expected_model_type: str) ->
provider = await self.routing_table.get_provider_impl(model.identifier)
return provider, model.provider_resource_id

# Handles cases where clients use the provider format directly
return await self._get_provider_by_fallback(model_id, expected_model_type)

async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
    """
    Resolve a model referenced directly as ``provider_id/provider_resource_id``.

    The identifier is cut at its first ``/``. The provider part must be a key of
    ``self.routing_table.impls_by_provider_id`` and the authenticated user must be
    allowed to ``read`` the model under ``self.routing_table.policy``.

    :param model_id: identifier in ``provider_id/provider_resource_id`` form
    :param expected_model_type: model type recorded on the object used for the RBAC check
    :returns: ``(provider implementation, provider_resource_id)``
    :raises ModelNotFoundError: if ``model_id`` contains no ``/``, the provider is
        unknown, or the RBAC check denies read access
    """
    provider_id, separator, provider_resource_id = model_id.partition("/")
    if not separator:
        # No "/" at all, so this cannot be the provider/resource format.
        raise ModelNotFoundError(model_id)

    known_providers = self.routing_table.impls_by_provider_id
    if provider_id not in known_providers:
        logger.warning(f"Provider {provider_id} not found for model {model_id}")
        raise ModelNotFoundError(model_id)

    # The access-control check operates on a ModelWithOwner, so a throwaway one is
    # built purely to describe the requested resource (no owner, empty metadata).
    # NOTE(review): reviewers found this temporary object awkward and suggested
    # rethinking the RBAC API so it is not necessary; no alternative was found yet.
    rbac_subject = ModelWithOwner(
        identifier=model_id,
        provider_id=provider_id,
        provider_resource_id=provider_resource_id,
        model_type=expected_model_type,
        metadata={},
    )

    user = get_authenticated_user()
    if is_action_allowed(self.routing_table.policy, "read", rbac_subject, user):
        return known_providers[provider_id], provider_resource_id

    # Denials are reported as "not found", the same error used for unknown models.
    logger.debug(
        f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
    )
    raise ModelNotFoundError(model_id)

async def rerank(
Expand Down
29 changes: 24 additions & 5 deletions src/llama_stack/core/routing_tables/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import time
from typing import Any

from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData, get_authenticated_user
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack_api import (
Expand Down Expand Up @@ -66,6 +67,7 @@ async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
return []

dynamic_models = []
user = get_authenticated_user()

for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data
Expand Down Expand Up @@ -93,15 +95,32 @@ async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
if not models:
continue

# Ensure models have fully qualified identifiers with provider_id prefix
# Ensure models have fully qualified identifiers and apply RBAC filtering
for model in models:
# Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}"

dynamic_models.append(model)

logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
# Convert to ModelWithOwner for RBAC check
temp_model = ModelWithOwner(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@derekhiggins we ask providers to return Model, but we always need ModelWithOwner. Can we move the ResourceWithOwner up a level in the class hierarchy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd have to look into how difficult it would be, but it wouldn't be a bad idea. I'd like to add a few simple RBAC policies to the integration auth tests next; then I would be more comfortable with refactors like this.

identifier=model.identifier,
provider_id=provider_id,
provider_resource_id=model.provider_resource_id,
model_type=model.model_type,
metadata=model.metadata,
)

# Apply RBAC check - only include models user has read permission for
if is_action_allowed(self.policy, "read", temp_model, user):
dynamic_models.append(model)
else:
logger.debug(
f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}"
)

logger.debug(
f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data"
)

except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
Expand Down
178 changes: 177 additions & 1 deletion tests/unit/server/test_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack_api import Api, ModelType
from llama_stack_api import Api, Model, ModelNotFoundError, ModelType


class AsyncMock(MagicMock):
Expand Down Expand Up @@ -557,3 +558,178 @@ def test_condition_reprs(condition):
from llama_stack.core.access_control.conditions import parse_condition

assert condition == str(parse_condition(condition))


@pytest.fixture
def restricted_user():
    """Provide a non-admin user whose only role is ``user``."""
    principal = "restricted-user"
    attributes = {"roles": ["user"]}
    return User(principal, attributes)


@pytest.fixture
def admin_user():
    """Provide a user that carries the ``admin`` role."""
    principal = "admin-user"
    attributes = {"roles": ["admin"]}
    return User(principal, attributes)


@pytest.fixture
def rbac_policy():
    """Build a two-rule policy: admins may do everything, owners may read.

    The rules are returned in that order (admin rule first, owner rule second).
    """
    from llama_stack.core.access_control.datatypes import Action, Scope

    # Rule 1: a user with "admin" among their roles is permitted every action.
    admin_full_access = AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user with admin in roles"],
    )

    # Rule 2: regular users get read access only to resources they own.
    owner_read_only = AccessRule(
        permit=Scope(actions=[Action.READ]),
        when=["user is owner"],
    )

    return [admin_full_access, owner_read_only]


class TestInferenceRouterRBACBypass:
    """Regression tests for the RBAC bypass in the inference router's fallback path."""

    @pytest.fixture
    def mock_routing_table(self):
        """Return a mocked routing table with one provider and an empty policy."""
        table = AsyncMock()
        table.policy = []
        table.impls_by_provider_id = {"test-provider": AsyncMock()}
        return table

    @patch("llama_stack.core.routers.inference.get_authenticated_user")
    async def test_registry_path_and_fallback_path_consistent(
        self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
    ):
        """Registry lookups and provider-format fallbacks must enforce RBAC alike."""
        registry_id = "admin-model"
        fallback_id = "test-provider/admin-resource"

        mock_routing_table.policy = rbac_policy
        registry_lookup = mock_routing_table.get_object_by_identifier

        # A registered model whose owner is the admin user.
        admin_owned = ModelWithOwner(
            identifier="admin-model",
            provider_id="test-provider",
            provider_resource_id="admin-resource",
            model_type=ModelType.llm,
            type="model",
            metadata={},
            owner=admin_user,
        )

        router = InferenceRouter(
            routing_table=mock_routing_table,
            store=None,
        )

        # Cases 1 and 2: the restricted user is rejected through the registry id
        # and through the provider-format id. The registry lookup returns None in
        # both cases, which is how an RBAC-blocked registry entry is simulated.
        mock_get_user.return_value = restricted_user
        for requested_id in (registry_id, fallback_id):
            registry_lookup.return_value = None
            with pytest.raises(ModelNotFoundError):
                await router._get_model_provider(requested_id, "llm")

        # Case 3: the admin resolves the model through the registry.
        mock_get_user.return_value = admin_user
        registry_lookup.return_value = admin_owned
        registry_provider = AsyncMock()
        mock_routing_table.get_provider_impl.return_value = registry_provider

        resolved_provider, resolved_resource = await router._get_model_provider(registry_id, "llm")
        assert resolved_provider == registry_provider
        assert resolved_resource == "admin-resource"

        # Case 4: the admin resolves the same resource through the fallback path.
        registry_lookup.return_value = None
        resolved_provider, resolved_resource = await router._get_model_provider(fallback_id, "llm")
        assert resolved_provider == mock_routing_table.impls_by_provider_id["test-provider"]
        assert resolved_resource == "admin-resource"


class TestModelListingRBACBypass:
    """Regression tests for the RBAC bypass in dynamic model listing via provider_data."""

    @patch("llama_stack.core.routing_tables.models.instantiate_class_type")
    @patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
    @patch("llama_stack.core.routing_tables.models.get_authenticated_user")
    @patch("llama_stack.core.routing_tables.common.get_authenticated_user")
    async def test_dynamic_models_respect_rbac(
        self,
        mock_get_user_common,
        mock_get_user_models,
        mock_provider_data,
        mock_instantiate_class,
        cached_disk_dist_registry,
        rbac_policy,
        admin_user,
        restricted_user,
    ):
        """Models fetched via provider_data must be filtered by the RBAC policy."""
        from llama_stack.core.request_headers import NeedsRequestProviderData

        # Provider stub that supports provider_data.
        provider_spec = MagicMock()
        provider_spec.api = Api.inference
        provider_spec.provider_data_validator = "dict"
        mock_provider = Mock(spec=NeedsRequestProviderData)
        mock_provider.__provider_spec__ = provider_spec

        # The validator class always validates successfully.
        mock_instantiate_class.return_value = MagicMock(return_value={})

        # Dynamic models come from provider_data and carry no owner.
        unowned_models = [
            Model(
                identifier=name,
                provider_id="test-provider",
                provider_resource_id=name,
                model_type=ModelType.llm,
                metadata={},
            )
            for name in ("dynamic-model-1", "dynamic-model-2")
        ]
        mock_provider.list_models = AsyncMock(return_value=unowned_models)

        # Routing table with the policy applied; nothing is pre-registered.
        routing_table = ModelsRoutingTable(
            impls_by_provider_id={"test-provider": mock_provider},
            dist_registry=cached_disk_dist_registry,
            policy=rbac_policy,
        )

        # The request carries provider_data credentials for this provider.
        mock_provider_data.get.return_value = {"api_key": "test-key"}

        def act_as(user):
            # Both patched lookups must report the same authenticated user.
            mock_get_user_common.return_value = user
            mock_get_user_models.return_value = user

        async def visible_model_ids():
            listing = await routing_table.list_models()
            return [entry.identifier for entry in listing.data]

        # Admin: the admin rule permits every action, so unowned models are listed.
        act_as(admin_user)
        admin_view = await visible_model_ids()
        assert "test-provider/dynamic-model-1" in admin_view
        assert "test-provider/dynamic-model-2" in admin_view

        # Restricted user: not an admin and owns nothing, so nothing may be listed.
        # Before the fix these models were returned without any RBAC check.
        act_as(restricted_user)
        restricted_view = await visible_model_ids()
        assert len(restricted_view) == 0
Loading