Skip to content

Commit

Permalink
[SDK] Send credentials in request also if inside the cluster (#1047)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Jun 30, 2021
1 parent 6612bb4 commit 897aa1d
Show file tree
Hide file tree
Showing 21 changed files with 247 additions and 148 deletions.
34 changes: 13 additions & 21 deletions mlrun/api/api/deps.py
@@ -1,10 +1,11 @@
import typing
from base64 import b64decode
from http import HTTPStatus
from typing import Generator

from fastapi import Request
from sqlalchemy.orm import Session

import mlrun.api.schemas
import mlrun.api.utils.authorizers.authorizer
import mlrun.api.utils.authorizers.nop
import mlrun.api.utils.authorizers.opa
Expand All @@ -14,7 +15,7 @@
from mlrun.config import config


def get_db_session() -> Generator[Session, None, None]:
def get_db_session() -> typing.Generator[Session, None, None]:
try:
db_session = create_session()
yield db_session
Expand All @@ -27,16 +28,7 @@ class AuthVerifier:
_bearer_prefix = "Bearer "

def __init__(self, request: Request):
# Basic auth
self.username = None
self.password = None
# Bearer auth
self.token = None
# Iguazio auth
self.session = None
self.data_session = None
self.uid = None
self.gids = None
self.auth_info = mlrun.api.schemas.AuthInfo()

self._authenticate_request(request)
self._authorize_request(request)
Expand Down Expand Up @@ -68,8 +60,8 @@ def _authenticate_request(self, request: Request):
HTTPStatus.UNAUTHORIZED.value,
reason="Username or password did not match",
)
self.username = username
self.password = password
self.auth_info.username = username
self.auth_info.password = password
elif self._bearer_auth_required():
if not header.startswith(self._bearer_prefix):
log_and_raise(
Expand All @@ -80,20 +72,20 @@ def _authenticate_request(self, request: Request):
log_and_raise(
HTTPStatus.UNAUTHORIZED.value, reason="Token did not match"
)
self.token = token
self.auth_info.token = token
elif self._iguazio_auth_required():
iguazio_client = mlrun.api.utils.clients.iguazio.Client()
(
self.username,
self.session,
self.uid,
self.gids,
self.auth_info.username,
self.auth_info.session,
self.auth_info.user_id,
self.auth_info.user_group_ids,
planes,
) = iguazio_client.verify_request_session(request)
if "x-data-session-override" in request.headers:
self.data_session = request.headers["x-data-session-override"]
self.auth_info.data_session = request.headers["x-data-session-override"]
elif "data" in planes:
self.data_session = self.session
self.auth_info.data_session = self.auth_info.session

@staticmethod
def _basic_auth_required():
Expand Down
8 changes: 4 additions & 4 deletions mlrun/api/api/endpoints/artifacts.py
@@ -1,7 +1,7 @@
from http import HTTPStatus
from typing import List, Optional
from typing import List

from fastapi import APIRouter, Cookie, Depends, Query, Request
from fastapi import APIRouter, Depends, Query, Request
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

Expand All @@ -24,7 +24,7 @@ async def store_artifact(
key: str,
tag: str = "",
iter: int = 0,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
data = None
Expand All @@ -43,7 +43,7 @@ async def store_artifact(
iter=iter,
tag=tag,
project=project,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)
return {}

Expand Down
30 changes: 15 additions & 15 deletions mlrun/api/api/endpoints/feature_store.py
@@ -1,7 +1,7 @@
from http import HTTPStatus
from typing import List, Optional

from fastapi import APIRouter, Cookie, Depends, Header, Query, Request, Response
from fastapi import APIRouter, Depends, Header, Query, Request, Response
from sqlalchemy.orm import Session

import mlrun.feature_store
Expand All @@ -23,11 +23,11 @@ def create_feature_set(
project: str,
feature_set: schemas.FeatureSet,
versioned: bool = True,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
feature_set_uid = get_db().create_feature_set(
db_session, project, feature_set, versioned, iguazio_session
db_session, project, feature_set, versioned, auth_verifier.auth_info.session
)

return get_db().get_feature_set(
Expand All @@ -49,7 +49,7 @@ def store_feature_set(
reference: str,
feature_set: schemas.FeatureSet,
versioned: bool = True,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
tag, uid = parse_reference(reference)
Expand All @@ -61,7 +61,7 @@ def store_feature_set(
tag,
uid,
versioned,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)

return get_db().get_feature_set(db_session, project, name, tag=tag, uid=uid)
Expand All @@ -76,7 +76,7 @@ def patch_feature_set(
patch_mode: schemas.PatchMode = Header(
schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode
),
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
tag, uid = parse_reference(reference)
Expand All @@ -88,7 +88,7 @@ def patch_feature_set(
tag,
uid,
patch_mode,
iguazio_session,
auth_verifier.auth_info.session,
)
return Response(status_code=HTTPStatus.OK.value)

Expand Down Expand Up @@ -199,15 +199,15 @@ def ingest_feature_set(
schemas.FeatureSetIngestInput
] = schemas.FeatureSetIngestInput(),
username: str = Header(None, alias="x-remote-user"),
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
tag, uid = parse_reference(reference)
feature_set_record = get_db().get_feature_set(db_session, project, name, tag, uid)

feature_set = mlrun.feature_store.FeatureSet.from_dict(feature_set_record.dict())
# Need to override the default rundb since we're in the server.
feature_set._override_run_db(db_session, iguazio_session)
feature_set._override_run_db(db_session, auth_verifier.auth_info.session)

data_source = data_targets = None
if ingest_parameters.source:
Expand Down Expand Up @@ -283,11 +283,11 @@ def create_feature_vector(
project: str,
feature_vector: schemas.FeatureVector,
versioned: bool = True,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
feature_vector_uid = get_db().create_feature_vector(
db_session, project, feature_vector, versioned, iguazio_session
db_session, project, feature_vector, versioned, auth_verifier.auth_info.session
)

return get_db().get_feature_vector(
Expand Down Expand Up @@ -356,7 +356,7 @@ def store_feature_vector(
reference: str,
feature_vector: schemas.FeatureVector,
versioned: bool = True,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
tag, uid = parse_reference(reference)
Expand All @@ -368,7 +368,7 @@ def store_feature_vector(
tag,
uid,
versioned,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)

return get_db().get_feature_vector(db_session, project, name, uid=uid, tag=tag)
Expand All @@ -383,7 +383,7 @@ def patch_feature_vector(
patch_mode: schemas.PatchMode = Header(
schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode
),
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
tag, uid = parse_reference(reference)
Expand All @@ -395,7 +395,7 @@ def patch_feature_vector(
tag,
uid,
patch_mode,
iguazio_session,
auth_verifier.auth_info.session,
)
return Response(status_code=HTTPStatus.OK.value)

Expand Down
44 changes: 19 additions & 25 deletions mlrun/api/api/endpoints/functions.py
@@ -1,17 +1,9 @@
import traceback
from distutils.util import strtobool
from http import HTTPStatus
from typing import List, Optional

from fastapi import (
APIRouter,
BackgroundTasks,
Cookie,
Depends,
Query,
Request,
Response,
)
from typing import List

from fastapi import APIRouter, BackgroundTasks, Depends, Query, Request, Response
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -40,7 +32,7 @@ async def store_function(
name: str,
tag: str = "",
versioned: bool = False,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
data = None
Expand All @@ -59,7 +51,7 @@ async def store_function(
project,
tag=tag,
versioned=versioned,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)
return {
"hash_key": hash_key,
Expand Down Expand Up @@ -111,7 +103,7 @@ def list_functions(
@router.post("/build/function/")
async def build_function(
request: Request,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
data = None
Expand All @@ -128,11 +120,11 @@ async def build_function(
fn, ready = await run_in_threadpool(
_build_function,
db_session,
auth_verifier.auth_info,
function,
with_mlrun,
skip_deployed,
mlrun_version_specifier,
iguazio_session,
)
return {
"data": fn.to_dict(),
Expand All @@ -146,7 +138,7 @@ async def build_function(
async def start_function(
request: Request,
background_tasks: BackgroundTasks,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
data = None
Expand All @@ -162,12 +154,12 @@ async def start_function(
background_task = await run_in_threadpool(
mlrun.api.utils.background_tasks.Handler().create_background_task,
db_session,
iguazio_session,
auth_verifier.auth_info.session,
function.metadata.project,
background_tasks,
_start_function,
function,
iguazio_session,
auth_verifier.auth_info,
)

return background_task
Expand Down Expand Up @@ -200,7 +192,7 @@ def build_status(
logs: bool = True,
last_log_timestamp: float = 0.0,
verbose: bool = False,
iguazio_session: Optional[str] = Cookie(None, alias="session"),
auth_verifier: deps.AuthVerifier = Depends(deps.AuthVerifier),
db_session: Session = Depends(deps.get_db_session),
):
fn = get_db().get_function(db_session, name, project, tag)
Expand Down Expand Up @@ -238,7 +230,7 @@ def build_status(
project,
tag,
versioned=versioned,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)
return Response(
content=text,
Expand Down Expand Up @@ -299,7 +291,7 @@ def build_status(
project,
tag,
versioned=versioned,
leader_session=iguazio_session,
leader_session=auth_verifier.auth_info.session,
)

return Response(
Expand All @@ -316,21 +308,22 @@ def build_status(

def _build_function(
db_session,
auth_info: mlrun.api.schemas.AuthInfo,
function,
with_mlrun,
skip_deployed,
mlrun_version_specifier,
leader_session,
):
fn = None
ready = None
try:
fn = new_function(runtime=function)

run_db = get_run_db_instance(db_session, leader_session)
run_db = get_run_db_instance(db_session, auth_info.session)
fn.set_db_connection(run_db)
fn.save(versioned=False)
if fn.kind in RuntimeKinds.nuclio_runtimes():
mlrun.api.api.utils.ensure_function_has_auth_set(fn, auth_info)
deploy_nuclio_function(fn)
# deploy only start the process, the get status API is used to check readiness
ready = False
Expand Down Expand Up @@ -365,7 +358,7 @@ def _parse_start_function_body(db_session, data):
return new_function(runtime=runtime)


def _start_function(function, leader_session: Optional[str] = None):
def _start_function(function, auth_info: mlrun.api.schemas.AuthInfo):
db_session = mlrun.api.db.session.create_session()
try:
resource = runtime_resources_map.get(function.kind)
Expand All @@ -375,8 +368,9 @@ def _start_function(function, leader_session: Optional[str] = None):
reason="runtime error: 'start' not supported by this runtime",
)
try:
run_db = get_run_db_instance(db_session, leader_session)
run_db = get_run_db_instance(db_session, auth_info.session)
function.set_db_connection(run_db)
mlrun.api.api.utils.ensure_function_has_auth_set(function, auth_info)
# resp = resource["start"](fn) # TODO: handle resp?
resource["start"](function)
function.save(versioned=False)
Expand Down

0 comments on commit 897aa1d

Please sign in to comment.