Skip to content

Commit

Permalink
[Tests] Send only string headers to align to new requests limitation (#…
Browse files Browse the repository at this point in the history
…2039)

(cherry picked from commit 362d6c9)
  • Loading branch information
Hedingber committed Jun 12, 2022
1 parent 1a3ceef commit 2db36db
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 51 deletions.
9 changes: 9 additions & 0 deletions mlrun/api/utils/clients/iguazio.py
@@ -1,5 +1,6 @@
import copy
import datetime
import enum
import http
import json
import typing
Expand Down Expand Up @@ -416,6 +417,14 @@ def _send_request_to_api(
kwargs.setdefault("headers", {})[
mlrun.api.schemas.HeaderNames.projects_role
] = "mlrun"

# requests no longer supports header values to be enum (https://github.com/psf/requests/pull/6154)
# convert to strings. Do the same for params for niceness
for kwarg in ["headers", "params"]:
dict_ = kwargs.get(kwarg, {})
for key in dict_.keys():
if isinstance(dict_[key], enum.Enum):
dict_[key] = dict_[key].value
response = self._session.request(method, url, verify=False, **kwargs)
if not response.ok:
log_kwargs = copy.deepcopy(kwargs)
Expand Down
9 changes: 9 additions & 0 deletions mlrun/api/utils/clients/nuclio.py
@@ -1,4 +1,5 @@
import copy
import enum
import http
import typing

Expand Down Expand Up @@ -189,6 +190,14 @@ def _send_request_to_api(self, method, path, **kwargs):
url = f"{self._api_url}/api/{path}"
if kwargs.get("timeout") is None:
kwargs["timeout"] = 20

# requests no longer supports header values to be enum (https://github.com/psf/requests/pull/6154)
# convert to strings. Do the same for params for niceness
for kwarg in ["headers", "params"]:
dict_ = kwargs.get(kwarg, {})
for key in dict_.keys():
if isinstance(dict_[key], enum.Enum):
dict_[key] = dict_[key].value
response = self._session.request(method, url, verify=False, **kwargs)
if not response.ok:
log_kwargs = copy.deepcopy(kwargs)
Expand Down
48 changes: 9 additions & 39 deletions mlrun/db/httpdb.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import http
import os
import tempfile
Expand Down Expand Up @@ -203,6 +203,14 @@ def api_call(
{mlrun.api.schemas.HeaderNames.client_version: self.client_version}
)

# requests no longer supports header values to be enum (https://github.com/psf/requests/pull/6154)
# convert to strings. Do the same for params for niceness
for dict_ in [headers, params]:
if dict_ is not None:
for key in dict_.keys():
if isinstance(dict_[key], enum.Enum):
dict_[key] = dict_[key].value

if not self.session:
self.session = requests.Session()
self.session.mount("http://", http_adapter)
Expand Down Expand Up @@ -689,8 +697,6 @@ def list_artifacts(
"""

project = project or config.default_project
if category and isinstance(category, schemas.ArtifactCategories):
category = category.value

params = {
"name": name,
Expand Down Expand Up @@ -818,8 +824,6 @@ def list_runtime_resources(
resources are per Function, for which the identifier is the Function's name.
:param group_by: Object to group results by. Allowed values are `job` and `project`.
"""
if isinstance(group_by, mlrun.api.schemas.ListRuntimeResourcesGroupByField):
group_by = group_by.value
params = {
"label_selector": label_selector,
"group-by": group_by,
Expand Down Expand Up @@ -1427,8 +1431,6 @@ def list_pipelines(
raise mlrun.errors.MLRunInvalidArgumentError(
"Filtering by project can not be used together with pagination, or sorting"
)
if isinstance(format_, mlrun.api.schemas.PipelinesFormat):
format_ = format_.value
params = {
"namespace": namespace,
"sort_by": sort_by,
Expand Down Expand Up @@ -1456,8 +1458,6 @@ def get_pipeline(
):
"""Retrieve details of a specific pipeline using its run ID (as provided when the pipeline was executed)."""

if isinstance(format_, mlrun.api.schemas.PipelinesFormat):
format_ = format_.value
try:
params = {}
if namespace:
Expand Down Expand Up @@ -1610,12 +1610,6 @@ def _generate_partition_by_params(
order,
max_partitions=None,
):
if isinstance(partition_by, partition_by_cls):
partition_by = partition_by.value
if isinstance(sort_by, schemas.SortField):
sort_by = sort_by.value
if isinstance(order, schemas.OrderType):
order = order.value

partition_params = {
"partition-by": partition_by,
Expand Down Expand Up @@ -1761,8 +1755,6 @@ def patch_feature_set(
"""
project = project or config.default_project
reference = self._resolve_reference(tag, uid)
if isinstance(patch_mode, schemas.PatchMode):
patch_mode = patch_mode.value
headers = {schemas.HeaderNames.patch_mode: patch_mode}
path = f"projects/{project}/feature-sets/{name}/references/{reference}"
error_message = f"Failed updating feature-set {project}/{name}"
Expand Down Expand Up @@ -1964,8 +1956,6 @@ def patch_feature_vector(
"""
reference = self._resolve_reference(tag, uid)
project = project or config.default_project
if isinstance(patch_mode, schemas.PatchMode):
patch_mode = patch_mode.value
headers = {schemas.HeaderNames.patch_mode: patch_mode}
path = f"projects/{project}/feature-vectors/{name}/references/{reference}"
error_message = f"Failed updating feature-vector {project}/{name}"
Expand Down Expand Up @@ -2013,10 +2003,6 @@ def list_projects(
:param state: Filter by project's state. Can be either ``online`` or ``archived``.
"""

if isinstance(state, mlrun.api.schemas.ProjectState):
state = state.value
if isinstance(format_, mlrun.api.schemas.ProjectsFormat):
format_ = format_.value
params = {
"owner": owner,
"state": state,
Expand Down Expand Up @@ -2067,8 +2053,6 @@ def delete_project(
"""

path = f"projects/{name}"
if isinstance(deletion_strategy, schemas.DeletionStrategy):
deletion_strategy = deletion_strategy.value
headers = {schemas.HeaderNames.deletion_strategy: deletion_strategy}
error_message = f"Failed deleting project {name}"
response = self.api_call("DELETE", path, error_message, headers=headers)
Expand Down Expand Up @@ -2113,8 +2097,6 @@ def patch_project(
"""

path = f"projects/{name}"
if isinstance(patch_mode, schemas.PatchMode):
patch_mode = patch_mode.value
headers = {schemas.HeaderNames.patch_mode: patch_mode}
error_message = f"Failed patching project {name}"
response = self.api_call(
Expand Down Expand Up @@ -2236,8 +2218,6 @@ def create_project_secrets(
secrets=secrets
)
"""
if isinstance(provider, schemas.SecretProviderName):
provider = provider.value
path = f"projects/{project}/secrets"
secrets_input = schemas.SecretsData(secrets=secrets, provider=provider)
body = secrets_input.dict()
Expand Down Expand Up @@ -2272,9 +2252,6 @@ def list_project_secrets(
to this specific project. ``kubernetes`` provider only supports an empty list.
"""

if isinstance(provider, schemas.SecretProviderName):
provider = provider.value

if provider == schemas.SecretProviderName.vault.value and not token:
raise MLRunInvalidArgumentError(
"A vault token must be provided when accessing vault secrets"
Expand Down Expand Up @@ -2314,9 +2291,6 @@ def list_project_secret_keys(
Must be a valid Vault token, with permissions to retrieve secrets of the project in question.
"""

if isinstance(provider, schemas.SecretProviderName):
provider = provider.value

if provider == schemas.SecretProviderName.vault.value and not token:
raise MLRunInvalidArgumentError(
"A vault token must be provided when accessing vault secrets"
Expand Down Expand Up @@ -2354,8 +2328,6 @@ def delete_project_secrets(
:param secrets: A list of secret names to delete. An empty list will delete all secrets assigned
to this specific project.
"""
if isinstance(provider, schemas.SecretProviderName):
provider = provider.value

path = f"projects/{project}/secrets"
params = {"provider": provider, "secret": secrets}
Expand Down Expand Up @@ -2386,8 +2358,6 @@ def create_user_secrets(
:param provider: The name of the secrets-provider to work with. Currently only ``vault`` is supported.
:param secrets: A set of secret values to store within the Vault.
"""
if isinstance(provider, schemas.SecretProviderName):
provider = provider.value
path = "user-secrets"
secrets_creation_request = schemas.UserSecretCreationRequest(
user=user,
Expand Down
24 changes: 12 additions & 12 deletions tests/api/api/test_projects.py
Expand Up @@ -117,7 +117,7 @@ def test_delete_project_with_resources(
response = client.delete(
f"projects/{project_to_remove}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value
},
)
assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value
Expand All @@ -126,7 +126,7 @@ def test_delete_project_with_resources(
response = client.delete(
f"projects/{project_to_remove}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value
},
)
assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value
Expand All @@ -135,7 +135,7 @@ def test_delete_project_with_resources(
response = client.delete(
f"projects/{project_to_remove}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_delete_project_with_resources(
response = client.delete(
f"projects/{project_to_remove}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand All @@ -183,7 +183,7 @@ def test_delete_project_with_resources(
response = client.delete(
f"projects/{project_to_remove}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_delete_project_deletion_strategy_check(
response = client.delete(
f"projects/{project.metadata.name}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand All @@ -467,7 +467,7 @@ def test_delete_project_deletion_strategy_check(
response = client.delete(
f"projects/{project.metadata.name}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value
},
)
assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value
Expand Down Expand Up @@ -533,7 +533,7 @@ def test_delete_project_not_deleting_versioned_objects_multiple_times(
response = client.delete(
f"projects/{project_name}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand Down Expand Up @@ -576,7 +576,7 @@ def test_delete_project_deletion_strategy_check_external_resource(
response = client.delete(
f"projects/{project.metadata.name}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value
},
)
assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value
Expand All @@ -586,7 +586,7 @@ def test_delete_project_deletion_strategy_check_external_resource(
response = client.delete(
f"projects/{project.metadata.name}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value
},
)
assert response
Expand Down Expand Up @@ -771,7 +771,7 @@ def test_projects_crud(
response = client.delete(
f"projects/{name1}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value
},
)
assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value
Expand All @@ -780,7 +780,7 @@ def test_projects_crud(
response = client.delete(
f"projects/{name1}",
headers={
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading
mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value
},
)
assert response.status_code == HTTPStatus.NO_CONTENT.value
Expand Down
29 changes: 29 additions & 0 deletions tests/rundb/test_unit_httpdb.py
@@ -0,0 +1,29 @@
# test_httpdb.py actually holds integration tests (that should be migrated to tests/integration/sdk_api/httpdb)
# currently we are running it in the integration tests CI step so adding this file for unit tests for the httpdb
import enum
import unittest.mock

import mlrun.db.httpdb


class SomeEnumClass(str, enum.Enum):
value1 = "value1"
value2 = "value2"


def test_api_call_enum_conversion():
db = mlrun.db.httpdb.HTTPRunDB("fake-url")
db.session = unittest.mock.Mock()

# ensure not exploding when no headers/params
db.api_call("GET", "some-path")

db.api_call(
"GET",
"some-path",
headers={"enum-value": SomeEnumClass.value1, "string-value": "value"},
params={"enum-value": SomeEnumClass.value2, "string-value": "value"},
)
for dict_key in ["headers", "params"]:
for value in db.session.request.call_args_list[1][1][dict_key].values():
assert type(value) == str

0 comments on commit 2db36db

Please sign in to comment.