From 897aa1d5bc6c74cb772affc84625c30326aed034 Mon Sep 17 00:00:00 2001 From: Hedingber Date: Thu, 1 Jul 2021 01:29:37 +0300 Subject: [PATCH] [SDK] Send credentials in request also if inside the cluster (#1047) --- mlrun/api/api/deps.py | 34 +++++++--------- mlrun/api/api/endpoints/artifacts.py | 8 ++-- mlrun/api/api/endpoints/feature_store.py | 30 +++++++-------- mlrun/api/api/endpoints/functions.py | 44 +++++++++------------ mlrun/api/api/endpoints/model_endpoints.py | 10 +++-- mlrun/api/api/endpoints/projects.py | 16 ++++---- mlrun/api/api/endpoints/runs.py | 12 +++--- mlrun/api/api/endpoints/runtimes.py | 19 +++++---- mlrun/api/api/endpoints/schedules.py | 15 ++++---- mlrun/api/api/endpoints/submit.py | 8 ++-- mlrun/api/api/utils.py | 33 +++++++++++----- mlrun/api/schemas/__init__.py | 1 + mlrun/api/schemas/auth.py | 17 ++++++++ mlrun/api/utils/scheduler.py | 43 +++++++++++---------- mlrun/api/utils/singletons/scheduler.py | 6 ++- mlrun/db/httpdb.py | 10 ++++- mlrun/platforms/iguazio.py | 14 +++++-- mlrun/runtimes/pod.py | 7 ++++ tests/api/api/test_utils.py | 13 +++++-- tests/api/utils/test_scheduler.py | 45 +++++++++++++++++----- tests/platforms/test_iguazio.py | 10 +++++ 21 files changed, 247 insertions(+), 148 deletions(-) create mode 100644 mlrun/api/schemas/auth.py diff --git a/mlrun/api/api/deps.py b/mlrun/api/api/deps.py index c10b8877a55..b7bdcaf2497 100644 --- a/mlrun/api/api/deps.py +++ b/mlrun/api/api/deps.py @@ -1,10 +1,11 @@ +import typing from base64 import b64decode from http import HTTPStatus -from typing import Generator from fastapi import Request from sqlalchemy.orm import Session +import mlrun.api.schemas import mlrun.api.utils.authorizers.authorizer import mlrun.api.utils.authorizers.nop import mlrun.api.utils.authorizers.opa @@ -14,7 +15,7 @@ from mlrun.config import config -def get_db_session() -> Generator[Session, None, None]: +def get_db_session() -> typing.Generator[Session, None, None]: try: db_session = create_session() yield db_session @@ -27,16 +28,7 @@ class AuthVerifier: _bearer_prefix = "Bearer " def __init__(self, request: Request): - # Basic auth - self.username = None - self.password = None - # Bearer auth - self.token = None - # Iguazio auth - self.session = None - self.data_session = None - self.uid = None - self.gids = None + self.auth_info = mlrun.api.schemas.AuthInfo() self._authenticate_request(request) self._authorize_request(request) @@ -68,8 +60,8 @@ def _authenticate_request(self, request: Request): HTTPStatus.UNAUTHORIZED.value, reason="Username or password did not match", ) - self.username = username - self.password = password + self.auth_info.username = username + self.auth_info.password = password elif self._bearer_auth_required(): if not header.startswith(self._bearer_prefix): log_and_raise( @@ -80,20 +72,20 @@ def _authenticate_request(self, request: Request): log_and_raise( HTTPStatus.UNAUTHORIZED.value, reason="Token did not match" ) - self.token = token + self.auth_info.token = token elif self._iguazio_auth_required(): iguazio_client = mlrun.api.utils.clients.iguazio.Client() ( - self.username, - self.session, - self.uid, - self.gids, + self.auth_info.username, + self.auth_info.session, + self.auth_info.user_id, + self.auth_info.user_group_ids, planes, ) = iguazio_client.verify_request_session(request) if "x-data-session-override" in request.headers: - self.data_session = request.headers["x-data-session-override"] + self.auth_info.data_session = request.headers["x-data-session-override"] elif "data" in planes: - self.data_session = self.session + self.auth_info.data_session = self.auth_info.session @staticmethod def _basic_auth_required(): diff --git a/mlrun/api/api/endpoints/artifacts.py b/mlrun/api/api/endpoints/artifacts.py index 5ef449eddf6..590d57de570 100644 --- a/mlrun/api/api/endpoints/artifacts.py +++ b/mlrun/api/api/endpoints/artifacts.py @@ -1,7 +1,7 @@ from http import HTTPStatus -from typing import List, Optional +from typing import List -from fastapi import APIRouter, Cookie, Depends, Query, Request +from fastapi import APIRouter, Depends, Query, Request from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session @@ -24,7 +24,7 @@ async def store_artifact( key: str, tag: str = "", iter: int = 0, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -43,7 +43,7 @@ async def store_artifact( iter=iter, tag=tag, project=project, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return {} diff --git a/mlrun/api/api/endpoints/feature_store.py b/mlrun/api/api/endpoints/feature_store.py index 7d4be5f5146..7d65bc2260c 100644 --- a/mlrun/api/api/endpoints/feature_store.py +++ b/mlrun/api/api/endpoints/feature_store.py @@ -1,7 +1,7 @@ from http import HTTPStatus from typing import List, Optional -from fastapi import APIRouter, Cookie, Depends, Header, Query, Request, Response +from fastapi import APIRouter, Depends, Header, Query, Request, Response from sqlalchemy.orm import Session import mlrun.feature_store @@ -23,11 +23,11 @@ def create_feature_set( project: str, feature_set: schemas.FeatureSet, versioned: bool = True, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): feature_set_uid = get_db().create_feature_set( - db_session, project, feature_set, versioned, iguazio_session + db_session, project, feature_set, versioned, auth_verifier.auth_info.session ) return get_db().get_feature_set( @@ -49,7 +49,7 @@ def store_feature_set( reference: str, feature_set: schemas.FeatureSet, versioned: bool = True, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -61,7 +61,7 @@ def store_feature_set( tag, uid, versioned, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return get_db().get_feature_set(db_session, project, name, tag=tag, uid=uid) @@ -76,7 +76,7 @@ def patch_feature_set( patch_mode: schemas.PatchMode = Header( schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode ), - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -88,7 +88,7 @@ def patch_feature_set( tag, uid, patch_mode, - iguazio_session, + auth_verifier.auth_info.session, ) return Response(status_code=HTTPStatus.OK.value) @@ -199,7 +199,7 @@ def ingest_feature_set( schemas.FeatureSetIngestInput ] = schemas.FeatureSetIngestInput(), username: str = Header(None, alias="x-remote-user"), - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -207,7 +207,7 @@ def ingest_feature_set( feature_set = mlrun.feature_store.FeatureSet.from_dict(feature_set_record.dict()) # Need to override the default rundb since we're in the server. - feature_set._override_run_db(db_session, iguazio_session) + feature_set._override_run_db(db_session, auth_verifier.auth_info.session) data_source = data_targets = None if ingest_parameters.source: @@ -283,11 +283,11 @@ def create_feature_vector( project: str, feature_vector: schemas.FeatureVector, versioned: bool = True, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): feature_vector_uid = get_db().create_feature_vector( - db_session, project, feature_vector, versioned, iguazio_session + db_session, project, feature_vector, versioned, auth_verifier.auth_info.session ) return get_db().get_feature_vector( @@ -356,7 +356,7 @@ def store_feature_vector( reference: str, feature_vector: schemas.FeatureVector, versioned: bool = True, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -368,7 +368,7 @@ def store_feature_vector( tag, uid, versioned, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return get_db().get_feature_vector(db_session, project, name, uid=uid, tag=tag) @@ -383,7 +383,7 @@ def patch_feature_vector( patch_mode: schemas.PatchMode = Header( schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode ), - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -395,7 +395,7 @@ def patch_feature_vector( tag, uid, patch_mode, - iguazio_session, + auth_verifier.auth_info.session, ) return Response(status_code=HTTPStatus.OK.value) diff --git a/mlrun/api/api/endpoints/functions.py b/mlrun/api/api/endpoints/functions.py index d3605827e6e..8f2da7dba87 100644 --- a/mlrun/api/api/endpoints/functions.py +++ b/mlrun/api/api/endpoints/functions.py @@ -1,17 +1,9 @@ import traceback from distutils.util import strtobool from http import HTTPStatus -from typing import List, Optional - -from fastapi import ( - APIRouter, - BackgroundTasks, - Cookie, - Depends, - Query, - Request, - Response, -) +from typing import List + +from fastapi import APIRouter, BackgroundTasks, Depends, Query, Request, Response from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session @@ -40,7 +32,7 @@ async def store_function( name: str, tag: str = "", versioned: bool = False, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -59,7 +51,7 @@ async def store_function( project, tag=tag, versioned=versioned, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return { "hash_key": hash_key, @@ -111,7 +103,7 @@ def list_functions( @router.post("/build/function/") async def build_function( request: Request, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -128,11 +120,11 @@ async def build_function( fn, ready = await run_in_threadpool( _build_function, db_session, + auth_verifier.auth_info, function, with_mlrun, skip_deployed, mlrun_version_specifier, - iguazio_session, ) return { "data": fn.to_dict(), @@ -146,7 +138,7 @@ async def build_function( async def start_function( request: Request, background_tasks: BackgroundTasks, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -162,12 +154,12 @@ async def start_function( background_task = await run_in_threadpool( mlrun.api.utils.background_tasks.Handler().create_background_task, db_session, - iguazio_session, + auth_verifier.auth_info.session, function.metadata.project, background_tasks, _start_function, function, - iguazio_session, + auth_verifier.auth_info, ) return background_task @@ -200,7 +192,7 @@ def build_status( logs: bool = True, last_log_timestamp: float = 0.0, verbose: bool = False, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): fn = get_db().get_function(db_session, name, project, tag) @@ -238,7 +230,7 @@ def build_status( project, tag, versioned=versioned, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return Response( content=text, @@ -299,7 +291,7 @@ def build_status( project, tag, versioned=versioned, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return Response( @@ -316,21 +308,22 @@ def build_status( def _build_function( db_session, + auth_info: mlrun.api.schemas.AuthInfo, function, with_mlrun, skip_deployed, mlrun_version_specifier, - leader_session, ): fn = None ready = None try: fn = new_function(runtime=function) - run_db = get_run_db_instance(db_session, leader_session) + run_db = get_run_db_instance(db_session, auth_info.session) fn.set_db_connection(run_db) fn.save(versioned=False) if fn.kind in RuntimeKinds.nuclio_runtimes(): + mlrun.api.api.utils.ensure_function_has_auth_set(fn, auth_info) deploy_nuclio_function(fn) # deploy only start the process, the get status API is used to check readiness ready = False @@ -365,7 +358,7 @@ def _parse_start_function_body(db_session, data): return new_function(runtime=runtime) -def _start_function(function, leader_session: Optional[str] = None): +def _start_function(function, auth_info: mlrun.api.schemas.AuthInfo): db_session = mlrun.api.db.session.create_session() try: resource = runtime_resources_map.get(function.kind) @@ -375,8 +368,9 @@ def _start_function(function, leader_session: Optional[str] = None): reason="runtime error: 'start' not supported by this runtime", ) try: - run_db = get_run_db_instance(db_session, leader_session) + run_db = get_run_db_instance(db_session, auth_info.session) function.set_db_connection(run_db) + mlrun.api.api.utils.ensure_function_has_auth_set(function, auth_info) # resp = resource["start"](fn) # TODO: handle resp? resource["start"](function) function.save(versioned=False) diff --git a/mlrun/api/api/endpoints/model_endpoints.py b/mlrun/api/api/endpoints/model_endpoints.py index 4ddfc725878..535bb62b525 100644 --- a/mlrun/api/api/endpoints/model_endpoints.py +++ b/mlrun/api/api/endpoints/model_endpoints.py @@ -1,10 +1,10 @@ from http import HTTPStatus from typing import List, Optional -from fastapi import APIRouter, Cookie, Depends, Query, Request, Response +from fastapi import APIRouter, Depends, Query, Request, Response from sqlalchemy.orm import Session -import mlrun.api.api +import mlrun.api.api.deps from mlrun.api.crud.model_endpoints import ModelEndpoints, get_access_key from mlrun.api.schemas import ModelEndpoint, ModelEndpointList from mlrun.errors import MLRunConflictError @@ -21,7 +21,9 @@ def create_or_patch( project: str, endpoint_id: str, model_endpoint: ModelEndpoint, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: mlrun.api.api.deps.AuthVerifier = Depends( + mlrun.api.api.deps.AuthVerifier + ), db_session: Session = Depends(mlrun.api.api.deps.get_db_session), ) -> Response: """ @@ -41,7 +43,7 @@ def create_or_patch( db_session=db_session, access_key=access_key, model_endpoint=model_endpoint, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return Response(status_code=HTTPStatus.NO_CONTENT.value) diff --git a/mlrun/api/api/endpoints/projects.py b/mlrun/api/api/endpoints/projects.py index 333c5268e99..86f9b10b257 100644 --- a/mlrun/api/api/endpoints/projects.py +++ b/mlrun/api/api/endpoints/projects.py @@ -28,14 +28,14 @@ def create_project( # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - iguazio_session: typing.Optional[str] = fastapi.Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = fastapi.Depends(deps.AuthVerifier), db_session: Session = fastapi.Depends(deps.get_db_session), ): project, is_running_in_background = get_project_member().create_project( db_session, project, projects_role, - iguazio_session, + auth_verifier.auth_info.session, wait_for_completion=wait_for_completion, ) if is_running_in_background: @@ -61,7 +61,7 @@ def store_project( # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - iguazio_session: typing.Optional[str] = fastapi.Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = fastapi.Depends(deps.AuthVerifier), db_session: Session = fastapi.Depends(deps.get_db_session), ): project, is_running_in_background = get_project_member().store_project( @@ -69,7 +69,7 @@ def store_project( name, project, projects_role, - iguazio_session, + auth_verifier.auth_info.session, wait_for_completion=wait_for_completion, ) if is_running_in_background: @@ -96,7 +96,7 @@ def patch_project( # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - iguazio_session: typing.Optional[str] = fastapi.Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = fastapi.Depends(deps.AuthVerifier), db_session: Session = fastapi.Depends(deps.get_db_session), ): project, is_running_in_background = get_project_member().patch_project( @@ -105,7 +105,7 @@ def patch_project( project, patch_mode, projects_role, - iguazio_session, + auth_verifier.auth_info.session, wait_for_completion=wait_for_completion, ) if is_running_in_background: @@ -134,7 +134,7 @@ def delete_project( # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - iguazio_session: typing.Optional[str] = fastapi.Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = fastapi.Depends(deps.AuthVerifier), db_session: Session = fastapi.Depends(deps.get_db_session), ): is_running_in_background = get_project_member().delete_project( @@ -142,7 +142,7 @@ def delete_project( name, deletion_strategy, projects_role, - iguazio_session, + auth_verifier.auth_info.session, wait_for_completion=wait_for_completion, ) if is_running_in_background: diff --git a/mlrun/api/api/endpoints/runs.py b/mlrun/api/api/endpoints/runs.py index 4298776eae9..0668388d910 100644 --- a/mlrun/api/api/endpoints/runs.py +++ b/mlrun/api/api/endpoints/runs.py @@ -1,7 +1,7 @@ from http import HTTPStatus -from typing import List, Optional +from typing import List -from fastapi import APIRouter, Cookie, Depends, Query, Request +from fastapi import APIRouter, Depends, Query, Request from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session @@ -22,7 +22,7 @@ async def store_run( project: str, uid: str, iter: int = 0, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -39,7 +39,7 @@ async def store_run( uid, project, iter=iter, - leader_session=iguazio_session, + leader_session=auth_verifier.auth_info.session, ) return {} @@ -51,7 +51,7 @@ async def update_run( project: str, uid: str, iter: int = 0, - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -67,7 +67,7 @@ async def update_run( uid, iter, data, - iguazio_session, + auth_verifier.auth_info.session, ) return {} diff --git a/mlrun/api/api/endpoints/runtimes.py b/mlrun/api/api/endpoints/runtimes.py index 232d3393c8e..16c2cf6c38b 100644 --- a/mlrun/api/api/endpoints/runtimes.py +++ b/mlrun/api/api/endpoints/runtimes.py @@ -1,7 +1,7 @@ import typing from http import HTTPStatus -from fastapi import APIRouter, Cookie, Depends, Query, Response +from fastapi import APIRouter, Depends, Query, Response from sqlalchemy.orm import Session import mlrun.api.crud @@ -39,11 +39,11 @@ def delete_runtimes( label_selector: str = None, force: bool = False, grace_period: int = config.runtime_resources_deletion_grace_period, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): mlrun.api.crud.Runtimes().delete_runtimes( - db_session, label_selector, force, grace_period, iguazio_session + db_session, label_selector, force, grace_period, auth_verifier.auth_info.session ) return Response(status_code=HTTPStatus.NO_CONTENT.value) @@ -54,11 +54,16 @@ def delete_runtime( label_selector: str = None, force: bool = False, grace_period: int = config.runtime_resources_deletion_grace_period, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): mlrun.api.crud.Runtimes().delete_runtime( - db_session, kind, label_selector, force, grace_period, iguazio_session, + db_session, + kind, + label_selector, + force, + grace_period, + auth_verifier.auth_info.session, ) return Response(status_code=HTTPStatus.NO_CONTENT.value) @@ -71,7 +76,7 @@ def delete_runtime_object( label_selector: str = None, force: bool = False, grace_period: int = config.runtime_resources_deletion_grace_period, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): mlrun.api.crud.Runtimes().delete_runtime_object( @@ -81,6 +86,6 @@ def delete_runtime_object( label_selector, force, grace_period, - iguazio_session, + auth_verifier.auth_info.session, ) return Response(status_code=HTTPStatus.NO_CONTENT.value) diff --git a/mlrun/api/api/endpoints/schedules.py b/mlrun/api/api/endpoints/schedules.py index feee4b1dbef..feb295c3811 100644 --- a/mlrun/api/api/endpoints/schedules.py +++ b/mlrun/api/api/endpoints/schedules.py @@ -1,7 +1,6 @@ -import typing from http import HTTPStatus -from fastapi import APIRouter, Cookie, Depends, Response +from fastapi import APIRouter, Depends, Response from sqlalchemy.orm import Session from mlrun.api import schemas @@ -15,11 +14,12 @@ def create_schedule( project: str, schedule: schemas.ScheduleInput, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): get_scheduler().create_schedule( db_session, + auth_verifier.auth_info, project, schedule.name, schedule.kind, @@ -27,7 +27,6 @@ def create_schedule( schedule.cron_trigger, labels=schedule.labels, concurrency_limit=schedule.concurrency_limit, - leader_session=iguazio_session, ) return Response(status_code=HTTPStatus.CREATED.value) @@ -37,17 +36,17 @@ def update_schedule( project: str, name: str, schedule: schemas.ScheduleUpdate, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): get_scheduler().update_schedule( db_session, + auth_verifier.auth_info, project, name, schedule.scheduled_object, schedule.cron_trigger, labels=schedule.labels, - leader_session=iguazio_session, ) return Response(status_code=HTTPStatus.OK.value) @@ -84,11 +83,11 @@ def get_schedule( async def invoke_schedule( project: str, name: str, - iguazio_session: typing.Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): return await get_scheduler().invoke_schedule( - db_session, project, name, iguazio_session + db_session, auth_verifier.auth_info, project, name ) diff --git a/mlrun/api/api/endpoints/submit.py b/mlrun/api/api/endpoints/submit.py index 81a71a32d0b..13dfffba783 100644 --- a/mlrun/api/api/endpoints/submit.py +++ b/mlrun/api/api/endpoints/submit.py @@ -1,7 +1,7 @@ from http import HTTPStatus from typing import Optional -from fastapi import APIRouter, Cookie, Depends, Header, Request +from fastapi import APIRouter, Depends, Header, Request from sqlalchemy.orm import Session import mlrun.api.api.utils @@ -19,7 +19,7 @@ async def submit_job( request: Request, username: Optional[str] = Header(None, alias="x-remote-user"), - iguazio_session: Optional[str] = Cookie(None, alias="session"), + auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier), db_session: Session = Depends(deps.get_db_session), ): data = None @@ -40,5 +40,7 @@ async def submit_job( labels.setdefault("owner", username) logger.info("Submit run", data=data) - response = await mlrun.api.api.utils.submit_run(db_session, data, iguazio_session) + response = await mlrun.api.api.utils.submit_run( + db_session, auth_verifier.auth_info, data + ) return response diff --git a/mlrun/api/api/utils.py b/mlrun/api/api/utils.py index 81886b0fe8e..a8b8b773b5d 100644 --- a/mlrun/api/api/utils.py +++ b/mlrun/api/api/utils.py @@ -10,6 +10,7 @@ from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session +import mlrun.api.api.deps import mlrun.errors from mlrun.api import schemas from mlrun.api.db.sqldb.db import SQLDB @@ -75,7 +76,9 @@ def get_run_db_instance( return run_db -def _parse_submit_run_body(db_session: Session, data): +def _parse_submit_run_body( + db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data +): task = data.get("task") function_dict = data.get("function") function_url = data.get("functionUrl") @@ -115,20 +118,32 @@ def _parse_submit_run_body(db_session: Session, data): # assign values from it to the main function object function = enrich_function_from_dict(function, function_dict) + # if auth given in request ensure the function pod will have these auth env vars set, otherwise the job won't + # be able to communicate with the api + ensure_function_has_auth_set(function, auth_info) + return function, task -async def submit_run( - db_session: Session, data, leader_session: typing.Optional[str] = None -): +async def submit_run(db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data): _, _, _, response = await run_in_threadpool( - _submit_run, db_session, data, leader_session + _submit_run, db_session, auth_info, data ) return response +def ensure_function_has_auth_set(function, auth_info: mlrun.api.schemas.AuthInfo): + if auth_info and auth_info.session: + auth_env_vars = { + "V3IO_ACCESS_KEY": auth_info.session, + } + for key, value in auth_env_vars.items(): + if not function.is_env_exists(key): + function.set_env(key, value) + + def _submit_run( - db_session: Session, data, leader_session: typing.Optional[str] = None + db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data ) -> typing.Tuple[str, str, str, typing.Dict]: """ :return: Tuple with: @@ -140,8 +155,8 @@ def _submit_run( run_uid = None project = None try: - fn, task = _parse_submit_run_body(db_session, data) - run_db = get_run_db_instance(db_session, leader_session) + fn, task = _parse_submit_run_body(db_session, auth_info, data) + run_db = get_run_db_instance(db_session, auth_info.session) fn.set_db_connection(run_db, True) logger.info("Submitting run", function=fn.to_dict(), task=task) # fn.spec.rundb = "http://mlrun-api:8080" @@ -153,13 +168,13 @@ def _submit_run( schedule_labels = task["metadata"].get("labels") get_scheduler().create_schedule( db_session, + auth_info, task["metadata"]["project"], task["metadata"]["name"], schemas.ScheduleKinds.job, data, cron_trigger, schedule_labels, - leader_session=leader_session, ) project = task["metadata"]["project"] diff --git a/mlrun/api/schemas/__init__.py b/mlrun/api/schemas/__init__.py index d2fcce3bf12..d5279296c40 100644 --- a/mlrun/api/schemas/__init__.py +++ b/mlrun/api/schemas/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx from .artifact import ArtifactCategories +from .auth import AuthInfo from .background_task import ( BackgroundTask, BackgroundTaskMetadata, diff --git a/mlrun/api/schemas/auth.py b/mlrun/api/schemas/auth.py new file mode 100644 index 00000000000..c143937074c --- /dev/null +++ b/mlrun/api/schemas/auth.py @@ -0,0 +1,17 @@ +import typing + +import pydantic + + +class AuthInfo(pydantic.BaseModel): + # Basic + Iguazio auth + username: typing.Optional[str] = None + # Basic auth + password: typing.Optional[str] = None + # Bearer auth + token: typing.Optional[str] = None + # Iguazio auth + session: typing.Optional[str] = None + data_session: typing.Optional[str] = None + user_id: typing.Optional[str] = None + user_group_ids: typing.List[str] = [] diff --git a/mlrun/api/utils/scheduler.py b/mlrun/api/utils/scheduler.py index f553cd5cc6c..3e52b2b50e8 100644 --- a/mlrun/api/utils/scheduler.py +++ b/mlrun/api/utils/scheduler.py @@ -10,6 +10,7 @@ from apscheduler.triggers.cron import CronTrigger as APSchedulerCronTrigger from sqlalchemy.orm import Session +import mlrun.api.api.deps from mlrun.api import schemas from mlrun.api.db.session import close_session, create_session from mlrun.api.utils.singletons.db import get_db @@ -31,7 +32,7 @@ def __init__(self): self._min_allowed_interval = config.httpdb.scheduling.min_allowed_interval async def start( - self, db_session: Session, leader_session: Optional[str] = None, + self, db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, ): logger.info("Starting scheduler") self._scheduler.start() @@ -41,7 +42,7 @@ async def start( # don't fail the start on re-scheduling failure try: - self._reload_schedules(db_session, leader_session) + self._reload_schedules(db_session, auth_info) except Exception as exc: logger.warning("Failed reloading schedules", exc=exc) @@ -55,6 +56,7 @@ async def stop(self): def create_schedule( self, db_session: Session, + auth_info: mlrun.api.schemas.AuthInfo, project: str, name: str, kind: schemas.ScheduleKinds, @@ -62,7 +64,6 @@ def create_schedule( cron_trigger: Union[str, schemas.ScheduleCronTrigger], labels: Dict = None, concurrency_limit: int = config.httpdb.scheduling.default_concurrency_limit, - leader_session: Optional[str] = None, ): if isinstance(cron_trigger, str): cron_trigger = schemas.ScheduleCronTrigger.from_crontab(cron_trigger) @@ -80,7 +81,7 @@ def create_schedule( concurrency_limit=concurrency_limit, ) get_project_member().ensure_project( - db_session, project, leader_session=leader_session + db_session, project, leader_session=auth_info.session ) get_db().create_schedule( db_session, @@ -99,19 +100,19 @@ def create_schedule( scheduled_object, cron_trigger, concurrency_limit, - leader_session, + auth_info, ) def update_schedule( self, db_session: Session, + auth_info: mlrun.api.schemas.AuthInfo, project: str, name: str, scheduled_object: Union[Dict, Callable] = None, cron_trigger: Union[str, schemas.ScheduleCronTrigger] = None, labels: Dict = None, concurrency_limit: int = None, - leader_session: Optional[str] = None, ): if isinstance(cron_trigger, str): cron_trigger = schemas.ScheduleCronTrigger.from_crontab(cron_trigger) @@ -136,7 +137,7 @@ def update_schedule( cron_trigger, labels, concurrency_limit, - leader_session, + auth_info.session, ) db_schedule = get_db().get_schedule(db_session, project, name) updated_schedule = self._transform_and_enrich_db_schedule( @@ -150,7 +151,7 @@ def update_schedule( updated_schedule.scheduled_object, updated_schedule.cron_trigger, updated_schedule.concurrency_limit, - leader_session, + auth_info, ) def list_schedules( @@ -199,9 +200,9 @@ def delete_schedule(self, db_session: Session, project: str, name: str): async def invoke_schedule( self, db_session: Session, + auth_info: mlrun.api.schemas.AuthInfo, project: str, name: str, - leader_session: Optional[str] = None, ): logger.debug("Invoking schedule", project=project, name=name) db_schedule = await fastapi.concurrency.run_in_threadpool( @@ -213,7 +214,7 @@ async def invoke_schedule( project, name, db_schedule.concurrency_limit, - leader_session, + auth_info, ) return await function(*args, **kwargs) @@ -276,12 +277,12 @@ def _create_schedule_in_scheduler( scheduled_object: Any, cron_trigger: schemas.ScheduleCronTrigger, concurrency_limit: int, - leader_session: Optional[str] = None, + auth_info: mlrun.api.schemas.AuthInfo, ): job_id = self._resolve_job_id(project, name) logger.debug("Adding schedule to scheduler", job_id=job_id) function, args, kwargs = self._resolve_job_function( - kind, scheduled_object, project, name, concurrency_limit, leader_session + kind, scheduled_object, project, name, concurrency_limit, auth_info ) # we use max_instances as well as our logic in the run wrapper for concurrent jobs @@ -306,12 +307,12 @@ def _update_schedule_in_scheduler( scheduled_object: Any, cron_trigger: schemas.ScheduleCronTrigger, concurrency_limit: int, - leader_session: Optional[str] = None, + auth_info: mlrun.api.schemas.AuthInfo, ): job_id = self._resolve_job_id(project, name) logger.debug("Updating schedule in scheduler", job_id=job_id) function, args, kwargs = self._resolve_job_function( - kind, scheduled_object, project, name, concurrency_limit, leader_session + kind, scheduled_object, project, name, concurrency_limit, auth_info ) trigger = self.transform_schemas_cron_trigger_to_apscheduler_cron_trigger( cron_trigger @@ -328,7 +329,7 @@ def _update_schedule_in_scheduler( ) def _reload_schedules( - self, db_session: Session, leader_session: Optional[str] = None, + self, db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, ): logger.info("Reloading schedules") db_schedules = get_db().list_schedules(db_session) @@ -342,7 +343,7 @@ def _reload_schedules( db_schedule.scheduled_object, db_schedule.cron_trigger, db_schedule.concurrency_limit, - leader_session, + auth_info, ) except Exception as exc: logger.warn( @@ -392,7 +393,7 @@ def _resolve_job_function( project_name: str, schedule_name: str, schedule_concurrency_limit: int, - leader_session: Optional[str] = None, + auth_info: mlrun.api.schemas.AuthInfo, ) -> Tuple[Callable, Optional[Union[List, Tuple]], Optional[Dict]]: """ :return: a tuple (function, args, kwargs) to be used with the APScheduler.add_job @@ -407,7 +408,7 @@ def _resolve_job_function( project_name, schedule_name, schedule_concurrency_limit, - leader_session, + auth_info, ], {}, ) @@ -431,7 +432,7 @@ async def submit_run_wrapper( project_name, schedule_name, schedule_concurrency_limit, - leader_session: Optional[str] = None, + auth_info: mlrun.api.schemas.AuthInfo, ): # import here to avoid circular imports from mlrun.api.api.utils import submit_run @@ -468,7 +469,7 @@ async def submit_run_wrapper( ) return - response = await submit_run(db_session, scheduled_object, leader_session) + response = await submit_run(db_session, auth_info, scheduled_object) run_metadata = response["data"]["metadata"] run_uri = RunObject.create_uri( @@ -479,7 +480,7 @@ async def submit_run_wrapper( run_metadata["project"], schedule_name, last_run_uri=run_uri, - leader_session=leader_session, + leader_session=auth_info.session, ) close_session(db_session) diff --git a/mlrun/api/utils/singletons/scheduler.py b/mlrun/api/utils/singletons/scheduler.py index cecd88d128b..7f0fcf276e8 100644 --- a/mlrun/api/utils/singletons/scheduler.py +++ b/mlrun/api/utils/singletons/scheduler.py @@ -1,3 +1,4 @@ +import mlrun.api.schemas import mlrun.config from mlrun.api.db.sqldb.session import create_session from mlrun.api.utils.scheduler import Scheduler @@ -13,7 +14,10 @@ async def initialize_scheduler(): try: db_session = create_session() await scheduler.start( - db_session, mlrun.config.config.httpdb.projects.iguazio_access_key + db_session, + mlrun.api.schemas.AuthInfo( + session=mlrun.config.config.httpdb.projects.iguazio_access_key + ), ) finally: db_session.close() diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index d496e81940f..50ed93e229d 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -151,7 +151,15 @@ def api_call( if self.user: kw["auth"] = (self.user, self.password) elif self.token: - kw["headers"] = {"Authorization": "Bearer " + self.token} + # Iguazio auth doesn't support passing token through bearer, so use cookie instead + if mlrun.platforms.iguazio.is_iguazio_session(self.token): + session_cookie = f'j:{{"sid": "{self.token}"}}' + cookies = { + "session": session_cookie, + } + kw["cookies"] = cookies + else: + kw["headers"] = {"Authorization": "Bearer " + self.token} if not self.session: self.session = requests.Session() diff --git a/mlrun/platforms/iguazio.py b/mlrun/platforms/iguazio.py index a885348d864..3648b412def 100644 --- a/mlrun/platforms/iguazio.py +++ b/mlrun/platforms/iguazio.py @@ -474,7 +474,7 @@ def is_iguazio_endpoint(endpoint_url: str) -> bool: return ".default-tenant." in endpoint_url -def is_iguazio_control_session(value: str) -> bool: +def is_iguazio_session(value: str) -> bool: # TODO: find a better heuristic return len(value) > 20 and "-" in value @@ -500,13 +500,21 @@ def add_or_refresh_credentials( api_url: str, username: str = "", password: str = "", token: str = "" ) -> (str, str, str): - # this may be called in "open source scenario" so in this case (not iguazio endpoint) simply do nothing - if not is_iguazio_endpoint(api_url) or is_iguazio_control_session(password): + if is_iguazio_session(password): return username, password, token username = username or os.environ.get("V3IO_USERNAME") password = password or os.environ.get("V3IO_PASSWORD") token = token or os.environ.get("V3IO_ACCESS_KEY") + + # When it's not iguazio endpoint it's one of two options: + # Enterprise, but we're in the cluster (and not from remote), e.g. url will be something like http://mlrun-api:8080 + # In which we enforce to have access key which is needed for the API auth + # Open source in which auth is not enabled so no creds needed + # We don't really have an easy/nice way to differentiate between the two so we're just sending creds anyways + # (ideally if we could identify we're in enterprise we would have verify here that token and username have value) + if not is_iguazio_endpoint(api_url): + return "", "", token iguazio_dashboard_url = "https://dashboard" + api_url[api_url.find(".") :] # in 2.8 mlrun api is protected with control session, from 2.10 it's protected with access key diff --git a/mlrun/runtimes/pod.py b/mlrun/runtimes/pod.py index b2b94ba1163..d6a6593aac1 100644 --- a/mlrun/runtimes/pod.py +++ b/mlrun/runtimes/pod.py @@ -220,6 +220,13 @@ def set_env(self, name, value): """set pod environment var from value""" return self._set_env(name, value=str(value)) + def is_env_exists(self, name): + """Check whether there is an environment variable define for the given key""" + for env_var in self.spec.env: + if get_item_name(env_var) == name: + return True + return False + def _set_env(self, name, value=None, value_from=None): new_var = client.V1EnvVar(name=name, value=value, value_from=value_from) i = 0 diff --git a/tests/api/api/test_utils.py b/tests/api/api/test_utils.py index cf0e806891e..e05d6d5b588 100644 --- a/tests/api/api/test_utils.py +++ b/tests/api/api/test_utils.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session import mlrun +import mlrun.api.schemas from mlrun.api.api.utils import _parse_submit_run_body @@ -140,7 +141,9 @@ def test_parse_submit_job_body_override_values(db: Session, client: TestClient): } }, } - parsed_function_object, task = _parse_submit_run_body(db, submit_job_body) + parsed_function_object, task = _parse_submit_run_body( + db, mlrun.api.schemas.AuthInfo(), submit_job_body + ) assert parsed_function_object.metadata.name == function_name assert parsed_function_object.metadata.project == project assert parsed_function_object.metadata.tag == function_tag @@ -200,7 +203,9 @@ def test_parse_submit_job_body_keep_resources(db: Session, client: TestClient): }, "function": {"spec": {"resources": {"limits": {}, "requests": {}}}}, } - parsed_function_object, task = _parse_submit_run_body(db, submit_job_body) + parsed_function_object, task = _parse_submit_run_body( + db, mlrun.api.schemas.AuthInfo(), submit_job_body + ) assert parsed_function_object.metadata.name == function_name assert parsed_function_object.metadata.project == project assert parsed_function_object.metadata.tag == function_tag @@ -235,7 +240,9 @@ def test_parse_submit_job_imported_function_project_assignment( }, "function": {"spec": {"resources": {"limits": {}, "requests": {}}}}, } - parsed_function_object, task = _parse_submit_run_body(db, submit_job_body) + parsed_function_object, task = _parse_submit_run_body( + db, mlrun.api.schemas.AuthInfo(), submit_job_body + ) assert parsed_function_object.metadata.project == task_project diff --git a/tests/api/utils/test_scheduler.py b/tests/api/utils/test_scheduler.py index 91ea850377f..2d18110d7ae 100644 --- a/tests/api/utils/test_scheduler.py +++ b/tests/api/utils/test_scheduler.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session import mlrun +import mlrun.api.api.deps import mlrun.api.utils.singletons.project_member import mlrun.errors from mlrun.api import schemas @@ -25,7 +26,7 @@ async def scheduler(db: Session) -> Generator: logger.info("Creating scheduler") config.httpdb.scheduling.min_allowed_interval = "0" scheduler = Scheduler() - await scheduler.start(db) + await scheduler.start(db, mlrun.api.schemas.AuthInfo()) mlrun.api.utils.singletons.project_member.initialize_project_member() yield scheduler logger.info("Stopping scheduler") @@ -66,6 +67,7 @@ async def test_not_skipping_delayed_schedules(db: Session, scheduler: Scheduler) project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -95,6 +97,7 @@ async def test_create_schedule(db: Session, scheduler: Scheduler): project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -115,6 +118,7 @@ async def test_invoke_schedule(db: Session, scheduler: Scheduler): assert len(runs) == 0 scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.job, @@ -123,10 +127,14 @@ async def test_invoke_schedule(db: Session, scheduler: Scheduler): ) runs = get_db().list_runs(db, project=project) assert len(runs) == 0 - response_1 = await scheduler.invoke_schedule(db, project, schedule_name) + response_1 = await scheduler.invoke_schedule( + db, mlrun.api.schemas.AuthInfo(), project, schedule_name + ) runs = get_db().list_runs(db, project=project) assert len(runs) == 1 - response_2 = await scheduler.invoke_schedule(db, project, schedule_name) + response_2 = await scheduler.invoke_schedule( + db, mlrun.api.schemas.AuthInfo(), project, schedule_name + ) runs = get_db().list_runs(db, project=project) assert len(runs) == 2 for run in runs: @@ -159,6 +167,7 @@ async def test_create_schedule_mlrun_function(db: Session, scheduler: Scheduler) assert len(runs) == 0 scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.job, @@ -194,6 +203,7 @@ async def test_create_schedule_success_cron_trigger_validation( cron_trigger = schemas.ScheduleCronTrigger(**case) scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), "project", f"schedule-name-{index}", schemas.ScheduleKinds.local_function, @@ -226,6 +236,7 @@ async def test_create_schedule_failure_too_frequent_cron_trigger( with pytest.raises(ValueError) as excinfo: scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), "project", "schedule-name", schemas.ScheduleKinds.local_function, @@ -244,6 +255,7 @@ async def test_create_schedule_failure_already_exists( project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -254,6 +266,7 @@ async def test_create_schedule_failure_already_exists( with pytest.raises(mlrun.errors.MLRunConflictError) as excinfo: scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -297,6 +310,7 @@ async def test_get_schedule_datetime_fields_timezone(db: Session, scheduler: Sch project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -324,6 +338,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -353,6 +368,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedule_name_2 = "schedule-name-2" scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name_2, schemas.ScheduleKinds.local_function, @@ -437,6 +453,7 @@ async def test_list_schedules_name_filter(db: Session, scheduler: Scheduler): should_find = case["should_find"] scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, name, schemas.ScheduleKinds.local_function, @@ -460,6 +477,7 @@ async def test_delete_schedule(db: Session, scheduler: Scheduler): project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -492,6 +510,7 @@ async def test_rescheduling(db: Session, scheduler: Scheduler): project = config.default_project scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.local_function, @@ -507,7 +526,7 @@ async def test_rescheduling(db: Session, scheduler: Scheduler): assert call_counter == 1 # start the scheduler and and assert another run - await scheduler.start(db) + await scheduler.start(db, mlrun.api.schemas.AuthInfo()) await asyncio.sleep(1) assert call_counter == 2 @@ -530,6 +549,7 @@ async def test_update_schedule(db: Session, scheduler: Scheduler): assert len(runs) == 0 scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schemas.ScheduleKinds.job, @@ -552,7 +572,7 @@ async def test_update_schedule(db: Session, scheduler: Scheduler): # update labels scheduler.update_schedule( - db, project, schedule_name, labels=labels_2, + db, mlrun.api.schemas.AuthInfo(), project, schedule_name, labels=labels_2, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -568,7 +588,7 @@ async def test_update_schedule(db: Session, scheduler: Scheduler): # update nothing scheduler.update_schedule( - db, project, schedule_name, + db, mlrun.api.schemas.AuthInfo(), project, schedule_name, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -584,7 +604,7 @@ async def test_update_schedule(db: Session, scheduler: Scheduler): # update labels to empty dict scheduler.update_schedule( - db, project, schedule_name, labels={}, + db, mlrun.api.schemas.AuthInfo(), project, schedule_name, labels={}, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -607,7 +627,11 @@ async def test_update_schedule(db: Session, scheduler: Scheduler): second="*/1", start_date=now_plus_1_second, end_date=now_plus_2_second, ) scheduler.update_schedule( - db, project, schedule_name, cron_trigger=cron_trigger, + db, + mlrun.api.schemas.AuthInfo(), + project, + schedule_name, + cron_trigger=cron_trigger, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -642,7 +666,9 @@ async def test_update_schedule_failure_not_found(db: Session, scheduler: Schedul schedule_name = "schedule-name" project = config.default_project with pytest.raises(mlrun.errors.MLRunNotFoundError) as excinfo: - scheduler.update_schedule(db, project, schedule_name) + scheduler.update_schedule( + db, mlrun.api.schemas.AuthInfo(), project, schedule_name + ) assert "Schedule not found" in str(excinfo.value) @@ -689,6 +715,7 @@ async def test_schedule_job_concurrency_limit( scheduler.create_schedule( db, + mlrun.api.schemas.AuthInfo(), project, schedule_name, schedule_kind, diff --git a/tests/platforms/test_iguazio.py b/tests/platforms/test_iguazio.py index b313120f8c8..7c69ce3de6e 100644 --- a/tests/platforms/test_iguazio.py +++ b/tests/platforms/test_iguazio.py @@ -69,6 +69,16 @@ def mock_get(*args, **kwargs): assert access_key == result_access_key +def test_add_or_refresh_credentials_kubernetes_svc_url_success(monkeypatch): + access_key = "access_key" + api_url = "http://mlrun-api:8080" + env = os.environ + env["V3IO_ACCESS_KEY"] = access_key + + _, _, result_access_key = add_or_refresh_credentials(api_url) + assert access_key == result_access_key + + def test_mount_v3io_legacy(): username = "username" access_key = "access-key"