Skip to content

Commit

Permalink
Fixing mpijob crd resolve from remote (#290)
Browse files Browse the repository at this point in the history
* Enable resolving mpijob crd version from remote

* resolve circular dependencies

* Fixing API not in k8s cluster bug

* fix

* fix
  • Loading branch information
Hedingber committed Jun 4, 2020
1 parent 13638b4 commit 164184a
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 53 deletions.
3 changes: 3 additions & 0 deletions mlrun/api/api/endpoints/healthz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from fastapi import APIRouter

from mlrun.config import config
from mlrun.runtimes.utils import resolve_mpijob_crd_version

router = APIRouter()


@router.get("/healthz")
def health():
    # Health endpoint: reports the server-side settings that clients read back
    # as defaults (HTTPRunDB.connect copies remote_host, mpijob_crd_version,
    # ui_url and artifact_path from this payload into its own config).
    #
    # NOTE: kept as comments rather than a docstring on purpose - FastAPI
    # publishes a handler docstring as the operation description.

    # api_context=True: we are the API server, so the resolver must not try
    # to reach a remote run DB to obtain the version.
    crd_version = resolve_mpijob_crd_version(api_context=True)

    status = {
        "version": config.version,
        "namespace": config.namespace,
        "docker_registry": environ.get('DEFAULT_DOCKER_REGISTRY', ''),
        "remote_host": config.remote_host,
        "mpijob_crd_version": crd_version,
        "ui_url": config.ui_url,
        "artifact_path": config.artifact_path,
    }
    return status
1 change: 1 addition & 0 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def connect(self, secrets=None):

# get defaults from remote server
config.remote_host = config.remote_host or server_cfg.get('remote_host')
config.mpijob_crd_version = config.mpijob_crd_version or server_cfg.get('mpijob_crd_version')
config.ui_url = config.ui_url or server_cfg.get('ui_url')
config.artifact_path = config.artifact_path or server_cfg.get('artifact_path')
if 'docker_registry' in server_cfg and 'DEFAULT_DOCKER_REGISTRY' not in environ:
Expand Down
36 changes: 5 additions & 31 deletions mlrun/runtimes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
from .function import RemoteRuntime, new_model_server # noqa
from .kubejob import KubejobRuntime, KubeRuntimeHandler # noqa
from .local import HandlerRuntime, LocalRuntime # noqa
from .mpijob import MpiRuntimeV1Alpha1, MpiRuntimeV1, MPIJobCRDVersions, MpiV1RuntimeHandler, \
from .mpijob import MpiRuntimeV1Alpha1, MpiRuntimeV1, MpiV1RuntimeHandler, \
MpiV1Alpha1RuntimeHandler # noqa
from .types import MPIJobCRDVersions
from .nuclio import nuclio_init_hook
from .serving import MLModelServer
from .sparkjob import SparkRuntime, SparkRuntimeHandler # noqa
from mlrun.runtimes.utils import resolve_mpijob_crd_version


class RuntimeKinds(object):
Expand Down Expand Up @@ -65,7 +67,7 @@ def runtime_with_handlers():
def get_runtime_handler(kind: str) -> BaseRuntimeHandler:
global runtime_handler_instances_cache
if kind == RuntimeKinds.mpijob:
mpijob_crd_version = _resolve_mpi_crd_version()
mpijob_crd_version = resolve_mpijob_crd_version()
crd_version_to_runtime_handler_class = {
MPIJobCRDVersions.v1alpha1: MpiV1Alpha1RuntimeHandler,
MPIJobCRDVersions.v1: MpiV1RuntimeHandler
Expand All @@ -88,7 +90,7 @@ def get_runtime_handler(kind: str) -> BaseRuntimeHandler:

def get_runtime_class(kind: str):
if kind == RuntimeKinds.mpijob:
mpijob_crd_version = _resolve_mpi_crd_version()
mpijob_crd_version = resolve_mpijob_crd_version()
crd_version_to_runtime = {
MPIJobCRDVersions.v1alpha1: MpiRuntimeV1Alpha1,
MPIJobCRDVersions.v1: MpiRuntimeV1
Expand All @@ -104,31 +106,3 @@ def get_runtime_class(kind: str):
}

return kind_runtime_map[kind]


# Decide which mpijob CRD version the mpijob runtime/handler classes are picked by.
# An explicit value in the mlrun config always wins; otherwise the version is read
# from the running mpi-operator pod, falling back to the default version.
def _resolve_mpi_crd_version():
    """Return the mpijob CRD version to use.

    Lookup order:
      1. ``config.mpijob_crd_version`` when it is set.
      2. The ``crd-version`` label of the first pod matching the selector
         ``component=mpi-operator`` in the namespace resolved by the k8s helper.
      3. ``MPIJobCRDVersions.default()`` when no such pod (or label) exists.

    :raises ValueError: if the resolved version is not one of
        ``MPIJobCRDVersions.all()``.
    """
    crd_version = config.mpijob_crd_version

    if not crd_version:
        helper = get_k8s_helper()
        target_namespace = helper.resolve_namespace()

        # start from the default; the operator pod label may override it below
        crd_version = MPIJobCRDVersions.default()

        operator_pods = helper.list_pods(namespace=target_namespace, selector='component=mpi-operator')
        if len(operator_pods) > 0:
            crd_version = operator_pods[0].metadata.labels.get('crd-version', crd_version)

    if crd_version in MPIJobCRDVersions.all():
        return crd_version

    raise ValueError(
        'unsupported mpijob crd version: {}. supported versions: {}'.format(
            crd_version, MPIJobCRDVersions.all()
        )
    )
1 change: 0 additions & 1 deletion mlrun/runtimes/mpijob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@

from .v1 import MpiRuntimeV1, MpiV1RuntimeHandler
from .v1alpha1 import MpiRuntimeV1Alpha1, MpiV1Alpha1RuntimeHandler
from .abstract import MPIJobCRDVersions
14 changes: 0 additions & 14 deletions mlrun/runtimes/mpijob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,6 @@
from mlrun.utils import logger, get_in


class MPIJobCRDVersions(object):
    """Names of the mpijob CRD versions known to this module."""

    v1 = 'v1'
    v1alpha1 = 'v1alpha1'

    @staticmethod
    def all():
        """Return a new list with every known CRD version."""
        known_versions = [MPIJobCRDVersions.v1]
        known_versions.append(MPIJobCRDVersions.v1alpha1)
        return known_versions

    @staticmethod
    def default():
        """Return the version used when none was configured or detected."""
        fallback = MPIJobCRDVersions.v1alpha1
        return fallback


class AbstractMPIJobRuntime(KubejobRuntime, abc.ABC):
kind = 'mpijob'
_is_nested = False
Expand Down
3 changes: 2 additions & 1 deletion mlrun/runtimes/mpijob/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from mlrun.execution import MLClientCtx
from mlrun.model import RunObject
from mlrun.runtimes.base import BaseRuntimeHandler
from mlrun.runtimes.mpijob.abstract import AbstractMPIJobRuntime, MPIJobCRDVersions
from mlrun.runtimes.types import MPIJobCRDVersions
from mlrun.runtimes.mpijob.abstract import AbstractMPIJobRuntime
from mlrun.utils import update_in, get_in


Expand Down
3 changes: 2 additions & 1 deletion mlrun/runtimes/mpijob/v1alpha1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from mlrun.execution import MLClientCtx
from mlrun.model import RunObject
from mlrun.runtimes.base import BaseRuntimeHandler
from mlrun.runtimes.mpijob.abstract import AbstractMPIJobRuntime, MPIJobCRDVersions
from mlrun.runtimes.types import MPIJobCRDVersions
from mlrun.runtimes.mpijob.abstract import AbstractMPIJobRuntime
from mlrun.utils import update_in, get_in


Expand Down
12 changes: 12 additions & 0 deletions mlrun/runtimes/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class MPIJobCRDVersions(object):
    """Closed set of mpijob CRD version identifiers.

    Kept in its own module so that both the mpijob runtimes and the runtime
    utilities can import it.
    """

    # version identifiers, compared against the configured / detected value
    v1 = 'v1'
    v1alpha1 = 'v1alpha1'

    @staticmethod
    def all():
        """Return a fresh list of all supported versions, ``v1`` first."""
        supported = [
            MPIJobCRDVersions.v1,
            MPIJobCRDVersions.v1alpha1,
        ]
        return supported

    @staticmethod
    def default():
        """Return the fallback version (``v1alpha1``)."""
        return MPIJobCRDVersions.v1alpha1
59 changes: 54 additions & 5 deletions mlrun/runtimes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,20 @@
import json
import os
from copy import deepcopy
from io import StringIO
from sys import stderr

import pandas as pd
from io import StringIO
from ..utils import logger
from ..config import config
from kubernetes import client

from mlrun.db import get_run_db
from mlrun.k8s_utils import get_k8s_helper
from mlrun.runtimes.types import MPIJobCRDVersions
from .generators import selector
from ..utils import get_in
from ..artifacts import TableArtifact
from kubernetes import client
from ..config import config
from ..utils import get_in
from ..utils import logger


class RunError(Exception):
Expand All @@ -47,6 +52,50 @@ def set(self, context):
global_context = _ContextStore()


cached_mpijob_crd_version = None


# Resolve which mpijob CRD version to work with.
# Resolving may send requests to k8s or to the API, and the CRD version is unlikely
# to change within a given context, so the first resolved value is cached in
# `cached_mpijob_crd_version` and reused by every later call.
def resolve_mpijob_crd_version(api_context=False):
    """Return the mpijob CRD version, resolving and caching it on first use.

    Resolution order (first match wins):
      1. The value cached by a previous call.
      2. ``config.mpijob_crd_version`` when it is set.
      3. Inside a kubernetes cluster: the ``crd-version`` label of the first pod
         matching the selector ``component=mpi-operator`` in the resolved
         namespace, or ``MPIJobCRDVersions.default()`` if there is no such
         pod/label.
      4. Outside a cluster, when ``api_context`` is false: connect to the run DB,
         which populates the config from the server config, and use the
         resulting ``config.mpijob_crd_version``.
      5. Outside a cluster, when ``api_context`` is true:
         ``MPIJobCRDVersions.default()``.

    :param api_context: set to True when called from the API server itself, so
        that no connection to a remote run DB is attempted.
    :raises Exception: if the server config consulted in step 4 has no mpijob
        crd version.
    :raises ValueError: if the resolved version is not one of
        ``MPIJobCRDVersions.all()``.
    """
    global cached_mpijob_crd_version

    if cached_mpijob_crd_version:
        return cached_mpijob_crd_version

    # an explicit config value takes precedence over any detection
    crd_version = config.mpijob_crd_version

    if not crd_version:
        # default, unless one of the detection paths below overrides it
        crd_version = MPIJobCRDVersions.default()

        if get_k8s_helper(init=False).is_running_inside_kubernetes_cluster():
            helper = get_k8s_helper()
            target_namespace = helper.resolve_namespace()

            # take the version advertised by the running mpi-operator pod, if any
            operator_pods = helper.list_pods(namespace=target_namespace, selector='component=mpi-operator')
            if len(operator_pods) > 0:
                crd_version = operator_pods[0].metadata.labels.get('crd-version', crd_version)
        elif not api_context:
            # connect will populate the config from the server config
            # TODO: something nicer
            get_run_db().connect()
            if not config.mpijob_crd_version:
                raise Exception('Server does not have configured mpijob crd version')
            crd_version = config.mpijob_crd_version

    if crd_version not in MPIJobCRDVersions.all():
        raise ValueError(
            'unsupported mpijob crd version: {}. supported versions: {}'.format(
                crd_version, MPIJobCRDVersions.all()
            )
        )

    cached_mpijob_crd_version = crd_version
    return cached_mpijob_crd_version


def calc_hash(func, tag=''):
# remove tag, hash, date from calculation
tag = tag or func.metadata.tag
Expand Down

0 comments on commit 164184a

Please sign in to comment.