diff --git a/run.yaml b/run.yaml index c7535406..d77492af 100644 --- a/run.yaml +++ b/run.yaml @@ -131,7 +131,10 @@ server: tls_cafile: null tls_certfile: null tls_keyfile: null -shields: [] +shields: + - shield_id: llama-guard-shield + provider_id: llama-guard + provider_shield_id: "gpt-3.5-turbo" # Model to use for safety checks vector_dbs: - vector_db_id: my_knowledge_base embedding_model: sentence-transformers/all-mpnet-base-v2 diff --git a/tests/e2e/configs/run-azure.yaml b/tests/e2e/configs/run-azure.yaml index a50301ad..fd8a8c79 100644 --- a/tests/e2e/configs/run-azure.yaml +++ b/tests/e2e/configs/run-azure.yaml @@ -120,7 +120,10 @@ server: tls_cafile: null tls_certfile: null tls_keyfile: null -shields: [] +shields: + - shield_id: llama-guard-shield + provider_id: llama-guard + provider_shield_id: "gpt-4o-mini" models: - model_id: gpt-4o-mini model_type: llm diff --git a/tests/e2e/configs/run-ci.yaml b/tests/e2e/configs/run-ci.yaml index bf1e9cc1..4b53da36 100644 --- a/tests/e2e/configs/run-ci.yaml +++ b/tests/e2e/configs/run-ci.yaml @@ -131,7 +131,10 @@ server: tls_cafile: null tls_certfile: null tls_keyfile: null -shields: [] +shields: + - shield_id: llama-guard-shield + provider_id: llama-guard + provider_shield_id: "gpt-4-turbo" vector_dbs: - vector_db_id: my_knowledge_base embedding_model: sentence-transformers/all-mpnet-base-v2 diff --git a/tests/e2e/configs/run-rhaiis.yaml b/tests/e2e/configs/run-rhaiis.yaml index 77398a13..67240e2e 100644 --- a/tests/e2e/configs/run-rhaiis.yaml +++ b/tests/e2e/configs/run-rhaiis.yaml @@ -123,7 +123,10 @@ server: tls_cafile: null tls_certfile: null tls_keyfile: null -shields: [] +shields: + - shield_id: llama-guard-shield + provider_id: llama-guard + provider_shield_id: "meta-llama/Llama-3.1-8B-Instruct" models: - metadata: embedding_dimension: 768 # Depends on chosen model diff --git a/tests/e2e/features/info.feature b/tests/e2e/features/info.feature index 2f67e53a..d6efebf0 100644 --- a/tests/e2e/features/info.feature +++ b/tests/e2e/features/info.feature @@ -37,7 +37,7 @@ Feature: Info tests And The body of the response has proper model structure - Scenario: Check if models endpoint is working + Scenario: Check if models endpoint reports error when llama-stack in unreachable Given The system is in default state And The llama-stack connection is disrupted When I access REST API endpoint "models" using HTTP GET method @@ -47,6 +47,22 @@ Feature: Info tests {"detail": {"response": "Unable to connect to Llama Stack", "cause": "Connection error."}} """ + Scenario: Check if shields endpoint is working + Given The system is in default state + When I access REST API endpoint "shields" using HTTP GET method + Then The status code of the response is 200 + And The body of the response has proper shield structure + + + Scenario: Check if shields endpoint reports error when llama-stack in unreachable + Given The system is in default state + And The llama-stack connection is disrupted + When I access REST API endpoint "shields" using HTTP GET method + Then The status code of the response is 500 + And The body of the response is the following + """ + {"detail": {"response": "Unable to connect to Llama Stack", "cause": "Connection error."}} + """ Scenario: Check if metrics endpoint is working Given The system is in default state diff --git a/tests/e2e/features/steps/info.py b/tests/e2e/features/steps/info.py index 2dbd1e6c..0c12f1ff 100644 --- a/tests/e2e/features/steps/info.py +++ b/tests/e2e/features/steps/info.py @@ -63,3 +63,37 @@ def check_model_structure(context: Context) -> None: assert ( llm_model["identifier"] == f"{expected_provider}/{expected_model}" ), f"identifier should be '{expected_provider}/{expected_model}'" + + +@then("The body of the response has proper shield structure") +def check_shield_structure(context: Context) -> None: + """Check that the first shield has the correct structure and required fields.""" + response_json = context.response.json() + assert response_json is not None, "Response is not valid JSON" + + assert "shields" in response_json, "Response missing 'shields' field" + shields = response_json["shields"] + assert len(shields) > 0, "Response has empty list of shields" + + # Find first shield + found_shield = None + for shield in shields: + if shield.get("type") == "shield": + found_shield = shield + break + + assert found_shield is not None, "No shield found in response" + + expected_model = context.default_model + + # Validate structure and values + assert found_shield["type"] == "shield", "type should be 'shield'" + assert ( + found_shield["provider_id"] == "llama-guard" + ), "provider_id should be 'llama-guard'" + assert ( + found_shield["provider_resource_id"] == expected_model + ), f"provider_resource_id should be '{expected_model}'" + assert ( + found_shield["identifier"] == "llama-guard-shield" + ), "identifier should be 'llama-guard-shield'"