diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index e3843cbc..e5d19099 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -10,6 +10,7 @@ from constants import ( AUTH_MOD_NOOP, AUTH_MOD_K8S, + AUTH_MOD_JWK_TOKEN, DATA_COLLECTOR_COLLECTION_INTERVAL, DATA_COLLECTOR_CONNECTION_TIMEOUT, ) @@ -17,6 +18,7 @@ from models.config import ( AuthenticationConfiguration, Configuration, + JwkConfiguration, LlamaStackConfiguration, ServiceConfiguration, UserDataCollection, @@ -682,6 +684,52 @@ def test_authentication_configuration() -> None: assert auth_config.k8s_cluster_api is None +def test_authentication_configuration_jwk_token() -> None: + """Test the AuthenticationConfiguration with JWK token.""" + + auth_config = AuthenticationConfiguration( + module=AUTH_MOD_JWK_TOKEN, + skip_tls_verification=False, + k8s_ca_cert_path=None, + k8s_cluster_api=None, + jwk_config=JwkConfiguration(url="http://foo.bar.baz"), + ) + assert auth_config is not None + assert auth_config.module == AUTH_MOD_JWK_TOKEN + assert auth_config.skip_tls_verification is False + assert auth_config.k8s_ca_cert_path is None + assert auth_config.k8s_cluster_api is None + + +def test_authentication_configuration_jwk_token_but_insufficient_config() -> None: + """Test the AuthenticationConfiguration with JWK token.""" + + with pytest.raises(ValidationError, match="JwkConfiguration"): + AuthenticationConfiguration( + module=AUTH_MOD_JWK_TOKEN, + skip_tls_verification=False, + k8s_ca_cert_path=None, + k8s_cluster_api=None, + jwk_config=JwkConfiguration(), + ) + + +def test_authentication_configuration_jwk_token_but_not_config() -> None: + """Test the AuthenticationConfiguration with JWK token.""" + + with pytest.raises( + ValidationError, + match="Value error, JWK configuration must be specified when using JWK token", + ): + AuthenticationConfiguration( + module=AUTH_MOD_JWK_TOKEN, + skip_tls_verification=False, + k8s_ca_cert_path=None, + k8s_cluster_api=None, + # no JwkConfiguration + ) + + def test_authentication_configuration_supported() -> None: """Test the AuthenticationConfiguration constructor.""" auth_config = AuthenticationConfiguration( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 69af9c00..980ac494 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -46,6 +46,9 @@ def test_get_llama_stack_library_client() -> None: ls_client = client.get_client() assert ls_client is not None + assert not ls_client.is_closed() + ls_client.close() + assert ls_client.is_closed() def test_get_llama_stack_remote_client() -> None: @@ -62,6 +65,9 @@ def test_get_llama_stack_remote_client() -> None: ls_client = client.get_client() assert ls_client is not None + assert not ls_client.is_closed() + ls_client.close() + assert ls_client.is_closed() def test_get_llama_stack_wrong_configuration() -> None: @@ -81,6 +87,7 @@ def test_get_llama_stack_wrong_configuration() -> None: client.load(cfg) +@pytest.mark.asyncio async def test_get_async_llama_stack_library_client() -> None: """Test the initialization of asynchronous Llama Stack client in library mode.""" cfg = LlamaStackConfiguration( @@ -93,8 +100,11 @@ async def test_get_async_llama_stack_library_client() -> None: await client.load(cfg) assert client is not None - ls_client = client.get_client() - assert ls_client is not None + async with client.get_client() as ls_client: + assert ls_client is not None + assert not ls_client.is_closed() + await ls_client.close() + assert ls_client.is_closed() async def test_get_async_llama_stack_remote_client() -> None: