diff --git a/docs/source/devguide/mixins.rst b/docs/source/devguide/mixins.rst index 1a548b26..695f177a 100644 --- a/docs/source/devguide/mixins.rst +++ b/docs/source/devguide/mixins.rst @@ -79,7 +79,7 @@ convenience. class CrossValidationMixin(object): def cross_validate(modelName, Xname, Yname, *args, **kwargs): # get the cross validation task - task = self.runtime.task('custom.tasks.cross_validate') + task = self.task('custom.tasks.cross_validate') return task.delay(modelName, Xname, Yname, *args, **kwargs) diff --git a/omegaml/backends/keras.py b/omegaml/backends/keras.py index 9127e62a..f235da01 100644 --- a/omegaml/backends/keras.py +++ b/omegaml/backends/keras.py @@ -1,6 +1,4 @@ import os -from keras import Sequential, Model -from keras.engine.saving import load_model from mongoengine import GridFSProxy from omegaml.backends import BaseModelBackend @@ -11,6 +9,7 @@ class KerasBackend(BaseModelBackend): @classmethod def supports(self, obj, name, **kwargs): + from keras import Sequential, Model return isinstance(obj, (Sequential, Model)) def put_model(self, obj, name, attributes=None): @@ -31,6 +30,8 @@ def put_model(self, obj, name, attributes=None): gridfile=gridfile).save() def get_model(self, name, version=-1): + from keras.engine.saving import load_model + filename = self.model_store._get_obj_store_key(name, 'h5') packagefname = os.path.join(self.model_store.tmppath, name) dirname = os.path.dirname(packagefname) @@ -61,7 +62,7 @@ def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs): import tensorflow as tf # adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/ # This address identifies the TPU we'll use when configuring TensorFlow. - TPU_WORKER = 'grpc://' + os.environ.get('COLAB_TPU_ADDR','') + TPU_WORKER = 'grpc://' + os.environ.get('COLAB_TPU_ADDR', '') tf.logging.set_verbosity(tf.logging.INFO) model = self.get_model(modelname) tpu_model = tf.contrib.tpu.keras_to_tpu_model( diff --git a/omegaml/defaults.py b/omegaml/defaults.py index 8ca649b1..36b91aa3 100644 --- a/omegaml/defaults.py +++ b/omegaml/defaults.py @@ -44,6 +44,7 @@ #: storage backends OMEGA_STORE_BACKENDS = { 'sklearn.joblib': 'omegaml.backends.ScikitLearnBackend', + 'keras.h5': 'omegaml.backends.keras.KerasBackend', } #: storage mixins OMEGA_STORE_MIXINS = [ diff --git a/omegaml/runtimes/daskruntime.py b/omegaml/runtimes/daskruntime.py index d8959321..85a74345 100644 --- a/omegaml/runtimes/daskruntime.py +++ b/omegaml/runtimes/daskruntime.py @@ -12,7 +12,7 @@ class DaskTask(object): A dask remote function wrapper mimicking a Celery task """ - def __init__(self, fn, client, pure=True): + def __init__(self, fn, client, pure=True, **kwargs): """ :param fn: (function) the function to be called :param client: (dask client) the dask client to use @@ -22,12 +22,15 @@ def __init__(self, fn, client, pure=True): self.client = client self.fn = fn self.pure = pure + self.kwargs = kwargs def delay(self, *args, **kwargs): """ submit the function and execute on cluster. """ kwargs['pure'] = kwargs.get('pure', self.pure) + if self.kwargs: + kwargs.update(self.kwargs) return DaskAsyncResult(self.client.submit(self.fn, *args, **kwargs)) @@ -90,7 +93,7 @@ def job(self, jobname): """ return OmegaJobProxy(jobname, runtime=self) - def task(self, name): + def task(self, name, **kwargs): """ retrieve the task function from the task module @@ -107,7 +110,7 @@ def task(self, name): func = getattr(mod, funcname) # we pass pure=False to force dask to reevaluate the task # http://distributed.readthedocs.io/en/latest/client.html?highlight=pure#pure-functions-by-default - return DaskTask(func, self.client, pure=False) + return DaskTask(func, self.client, pure=False, **kwargs) def settings(self): """ diff --git a/omegaml/runtimes/mixins/gridsearch.py b/omegaml/runtimes/mixins/gridsearch.py index 5c7f3492..0d7092a0 100644 --- a/omegaml/runtimes/mixins/gridsearch.py +++ b/omegaml/runtimes/mixins/gridsearch.py @@ -4,7 +4,7 @@ def _common_kwargs(self): return dict(pure_python=self.pure_python) def gridsearch(self, Xname, Yname, parameters=None, pure_python=False, **kwargs): - gs_task = self.runtime.task('omegaml.tasks.omega_gridsearch') + gs_task = self.task('omegaml.tasks.omega_gridsearch') Xname = self._ensure_data_is_stored(Xname, prefix='_fitX') if Yname is not None: Yname = self._ensure_data_is_stored(Yname, prefix='_fitY') diff --git a/omegaml/runtimes/mixins/modelmixin.py b/omegaml/runtimes/mixins/modelmixin.py index dec969c7..24d7588b 100644 --- a/omegaml/runtimes/mixins/modelmixin.py +++ b/omegaml/runtimes/mixins/modelmixin.py @@ -31,7 +31,7 @@ def fit(self, Xname, Yname=None, **kwargs): :param Yname: name of Y dataset or data :return: the model (self) or the string representation (python clients) """ - omega_fit = self.runtime.task('omegaml.tasks.omega_fit') + omega_fit = self.task('omegaml.tasks.omega_fit') Xname = self._ensure_data_is_stored(Xname, prefix='_fitX') if Yname is not None: Yname = self._ensure_data_is_stored(Yname, prefix='_fitY') @@ -54,7 +54,7 @@ def partial_fit(self, Xname, Yname=None, **kwargs): :param Yname: name of Y dataset or data :return: the model (self) or the string representation (python clients) """ - omega_fit = self.runtime.task('omegaml.tasks.omega_partial_fit') + omega_fit = self.task('omegaml.tasks.omega_partial_fit') Xname = self._ensure_data_is_stored(Xname, prefix='_fitX') if Yname is not None: Yname = self._ensure_data_is_stored(Yname, prefix='_fitY') @@ -73,7 +73,7 @@ def transform(self, Xname, rName=None, **kwargs): :return: the data returned by .transform, or the metadata of the rName dataset if rName was given """ - omega_transform = self.runtime.task('omegaml.tasks.omega_transform') + omega_transform = self.task('omegaml.tasks.omega_transform') Xname = self._ensure_data_is_stored(Xname) return omega_transform.delay(self.modelname, Xname, rName=rName, @@ -93,7 +93,7 @@ def fit_transform(self, Xname, Yname=None, rName=None, **kwargs): dataset if rName was given """ - omega_fit_transform = self.runtime.task( + omega_fit_transform = self.task( 'omegaml.tasks.omega_fit_transform') Xname = self._ensure_data_is_stored(Xname) if Yname is not None: @@ -114,7 +114,7 @@ def predict(self, Xpath_or_data, rName=None, **kwargs): :return: the data returned by .predict, or the metadata of the rName dataset if rName was given """ - omega_predict = self.runtime.task('omegaml.tasks.omega_predict') + omega_predict = self.task('omegaml.tasks.omega_predict') Xname = self._ensure_data_is_stored(Xpath_or_data) return omega_predict.delay(self.modelname, Xname, rName=rName, **self._common_kwargs, **kwargs) @@ -131,7 +131,7 @@ def predict_proba(self, Xpath_or_data, rName=None, **kwargs): :return: the data returned by .predict_proba, or the metadata of the rName dataset if rName was given """ - omega_predict_proba = self.runtime.task( + omega_predict_proba = self.task( 'omegaml.tasks.omega_predict_proba') Xname = self._ensure_data_is_stored(Xpath_or_data) return omega_predict_proba.delay(self.modelname, Xname, rName=rName, @@ -150,7 +150,7 @@ def score(self, Xname, yName, rName=None, **kwargs): :return: the data returned by .score, or the metadata of the rName dataset if rName was given """ - omega_score = self.runtime.task('omegaml.tasks.omega_score') + omega_score = self.task('omegaml.tasks.omega_score') Xname = self._ensure_data_is_stored(Xname) yName = self._ensure_data_is_stored(yName) return omega_score.delay(self.modelname, Xname, yName, rName=rName, @@ -168,7 +168,7 @@ def decision_function(self, Xname, rName=None, **kwargs): :return: the data returned by .score, or the metadata of the rName dataset if rName was given """ - omega_decision_function = self.runtime.task('omegaml.tasks.omega_decision_function') + omega_decision_function = self.task('omegaml.tasks.omega_decision_function') Xname = self._ensure_data_is_stored(Xname) return omega_decision_function.delay(self.modelname, Xname, rName=rName, **self._common_kwargs, **kwargs) diff --git a/omegaml/runtimes/modelproxy.py b/omegaml/runtimes/modelproxy.py index ef71590d..11a5faa0 100644 --- a/omegaml/runtimes/modelproxy.py +++ b/omegaml/runtimes/modelproxy.py @@ -52,6 +52,7 @@ def __init__(self, modelname, runtime=None): False) self.pure_python = self.pure_python or self._client_is_pure_python() self.apply_mixins() + self.task_kwargs = {} def apply_mixins(self): """ @@ -61,6 +62,36 @@ def apply_mixins(self): for mixin in defaults.OMEGA_RUNTIME_MIXINS: extend_instance(self, mixin) + def task(self, name): + """ + return the task from the runtime with requirements applied + """ + kwargs = self.task_kwargs + return self.runtime.task(name, **kwargs) + + def require(self, **kwargs): + """ + specify requirements for the task execution + + Use this to specify resource or routing requirements on the task call + sent to the runtime. + + Args: + **kwargs: requirements specification that the runtime understands + + Usage: + # celery runtime + om.runtime.model('foo').require(queue='gpu').fit(...) + + # dask distributed runtime + om.runtime.model('foo').require(resources={...}).fit(...) + + Returns: + self + """ + self.task_kwargs.update(kwargs) + return self + def _client_is_pure_python(self): try: import pandas as pd diff --git a/omegaml/runtimes/runtime.py b/omegaml/runtimes/runtime.py index 8c52d33e..6822fb65 100644 --- a/omegaml/runtimes/runtime.py +++ b/omegaml/runtimes/runtime.py @@ -6,6 +6,33 @@ from omegaml.util import settings +class CeleryTask(object): + """ + A thin wrapper for a Celery.Task object + + This is so that we can collect common delay arguments on the + .task() call + """ + + def __init__(self, task, **kwargs): + """ + + Args: + task (Celery.Task): the celery task object + **kwargs (dict): optional, the kwargs to pass to apply_async + """ + self.task = task + self.kwargs = kwargs + + def delay(self, *args, **kwargs): + """ + submit the task with args and kwargs to pass on + + This calls task.apply_async and passes on the self.kwargs. + """ + return self.task.apply_async(args=args, kwargs=kwargs, **self.kwargs) + + class OmegaRuntime(object): """ omegaml compute cluster gateway @@ -48,7 +75,7 @@ def job(self, jobname): """ return OmegaJobProxy(jobname, runtime=self) - def task(self, name): + def task(self, name, **kwargs): """ retrieve the task function from the celery instance @@ -57,7 +84,7 @@ def task(self, name): import, which seems to confuse celery) """ # import omegapkg.tasks - return self.celeryapp.tasks.get(name) + return CeleryTask(self.celeryapp.tasks.get(name), **kwargs) def settings(self): """