add task routing for celery and dask distributed runtimes
miraculixx committed May 18, 2019
1 parent 5753b2e commit fec34a8
Showing 8 changed files with 81 additions and 18 deletions.
docs/source/devguide/mixins.rst (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ convenience.
class CrossValidationMixin(object):
    def cross_validate(self, modelName, Xname, Yname, *args, **kwargs):
# get the cross validation task
-        task = self.runtime.task('custom.tasks.cross_validate')
+        task = self.task('custom.tasks.cross_validate')
return task.delay(modelName, Xname, Yname, *args, **kwargs)
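Because the mixin now resolves its task through self.task(), routing requirements set on the proxy apply to the custom task as well. A hypothetical usage sketch, assuming CrossValidationMixin is listed in OMEGA_RUNTIME_MIXINS and custom.tasks.cross_validate is a task known to the runtime; the queue name is illustrative:

    # hypothetical usage; queue name and task are examples, not omegaml builtins
    import omegaml as om

    result = (om.runtime.model('mymodel')
                .require(queue='gpu')  # collected into task_kwargs
                .cross_validate('mymodel', 'X', 'Y'))
    print(result.get())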
omegaml/backends/keras.py (7 changes: 4 additions & 3 deletions)
@@ -1,6 +1,4 @@
import os
-from keras import Sequential, Model
-from keras.engine.saving import load_model
from mongoengine import GridFSProxy

from omegaml.backends import BaseModelBackend
@@ -11,6 +9,7 @@ class KerasBackend(BaseModelBackend):

@classmethod
def supports(self, obj, name, **kwargs):
+        from keras import Sequential, Model
return isinstance(obj, (Sequential, Model))

def put_model(self, obj, name, attributes=None):
@@ -31,6 +30,8 @@ def put_model(self, obj, name, attributes=None):
gridfile=gridfile).save()

def get_model(self, name, version=-1):
+        from keras.engine.saving import load_model
+
filename = self.model_store._get_obj_store_key(name, 'h5')
packagefname = os.path.join(self.model_store.tmppath, name)
dirname = os.path.dirname(packagefname)
@@ -61,7 +62,7 @@ def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs):
import tensorflow as tf
# adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/
# This address identifies the TPU we'll use when configuring TensorFlow.
-        TPU_WORKER = 'grpc://' + os.environ.get('COLAB_TPU_ADDR','')
+        TPU_WORKER = 'grpc://' + os.environ.get('COLAB_TPU_ADDR', '')
tf.logging.set_verbosity(tf.logging.INFO)
model = self.get_model(modelname)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
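Moving the keras imports from module level into supports() and get_model() means the backend module can be imported even when keras is not installed. A minimal sketch of the same deferred-import pattern, using a hypothetical LazyBackend class:

    # sketch: the module imports cleanly without keras; the dependency is
    # only needed when the backend is asked whether it applies to an object
    class LazyBackend:
        @classmethod
        def supports(cls, obj, name, **kwargs):
            try:
                from keras import Sequential, Model  # deferred import
            except ImportError:
                return False  # keras not installed: backend does not apply
            return isinstance(obj, (Sequential, Model))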
omegaml/defaults.py (1 change: 1 addition & 0 deletions)
@@ -44,6 +44,7 @@
#: storage backends
OMEGA_STORE_BACKENDS = {
'sklearn.joblib': 'omegaml.backends.ScikitLearnBackend',
+    'keras.h5': 'omegaml.backends.keras.KerasBackend',
}
#: storage mixins
OMEGA_STORE_MIXINS = [
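OMEGA_STORE_BACKENDS maps a storage kind key such as 'keras.h5' to the dotted path of the backend class. One plausible way such a path gets resolved at runtime, shown with a hypothetical load_class helper (not omegaml's actual loader):

    from importlib import import_module

    def load_class(dotted_path):
        # 'omegaml.backends.keras.KerasBackend' -> module path and class name
        module_path, _, class_name = dotted_path.rpartition('.')
        return getattr(import_module(module_path), class_name)

    backend_cls = load_class('omegaml.backends.keras.KerasBackend')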
omegaml/runtimes/daskruntime.py (9 changes: 6 additions & 3 deletions)
@@ -12,7 +12,7 @@ class DaskTask(object):
A dask remote function wrapper mimicking a Celery task
"""

-    def __init__(self, fn, client, pure=True):
+    def __init__(self, fn, client, pure=True, **kwargs):
"""
:param fn: (function) the function to be called
:param client: (dask client) the dask client to use
@@ -22,12 +22,15 @@ def __init__(self, fn, client, pure=True):
self.client = client
self.fn = fn
self.pure = pure
+        self.kwargs = kwargs

def delay(self, *args, **kwargs):
"""
submit the function and execute on cluster.
"""
kwargs['pure'] = kwargs.get('pure', self.pure)
+        if self.kwargs:
+            kwargs.update(self.kwargs)
return DaskAsyncResult(self.client.submit(self.fn, *args, **kwargs))


@@ -90,7 +93,7 @@ def job(self, jobname):
"""
return OmegaJobProxy(jobname, runtime=self)

-    def task(self, name):
+    def task(self, name, **kwargs):
"""
retrieve the task function from the task module
@@ -107,7 +110,7 @@ def task(self, name):
func = getattr(mod, funcname)
# we pass pure=False to force dask to reevaluate the task
# http://distributed.readthedocs.io/en/latest/client.html?highlight=pure#pure-functions-by-default
-        return DaskTask(func, self.client, pure=False)
+        return DaskTask(func, self.client, pure=False, **kwargs)

def settings(self):
"""
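The kwargs collected by task() are merged in delay() and forwarded to Client.submit(), whose resources= and workers= options are standard dask.distributed routing hooks. A sketch under stated assumptions: my_func is a hypothetical picklable function, the scheduler address is an example, and DaskAsyncResult is assumed to mirror celery's AsyncResult.get():

    from distributed import Client

    client = Client('tcp://scheduler:8786')  # assumed scheduler address
    task = DaskTask(my_func, client, pure=False, resources={'GPU': 1})
    result = task.delay(10, 20)  # submits my_func with resources={'GPU': 1}
    print(result.get())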
omegaml/runtimes/mixins/gridsearch.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ def _common_kwargs(self):
return dict(pure_python=self.pure_python)

def gridsearch(self, Xname, Yname, parameters=None, pure_python=False, **kwargs):
-        gs_task = self.runtime.task('omegaml.tasks.omega_gridsearch')
+        gs_task = self.task('omegaml.tasks.omega_gridsearch')
Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
if Yname is not None:
Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
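Since gridsearch() now goes through self.task(), resource requirements apply to grid search runs as well. An illustrative call, assuming datasets 'X' and 'Y' exist in om.datasets and the parameter grid is passed through to the search task:

    om.runtime.model('mymodel') \
        .require(queue='gpu') \
        .gridsearch('X', 'Y', parameters={'C': [0.1, 1, 10]})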
omegaml/runtimes/mixins/modelmixin.py (16 changes: 8 additions & 8 deletions)
@@ -31,7 +31,7 @@ def fit(self, Xname, Yname=None, **kwargs):
:param Yname: name of Y dataset or data
:return: the model (self) or the string representation (python clients)
"""
-        omega_fit = self.runtime.task('omegaml.tasks.omega_fit')
+        omega_fit = self.task('omegaml.tasks.omega_fit')
Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
if Yname is not None:
Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
@@ -54,7 +54,7 @@ def partial_fit(self, Xname, Yname=None, **kwargs):
:param Yname: name of Y dataset or data
:return: the model (self) or the string representation (python clients)
"""
-        omega_fit = self.runtime.task('omegaml.tasks.omega_partial_fit')
+        omega_fit = self.task('omegaml.tasks.omega_partial_fit')
Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
if Yname is not None:
Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
@@ -73,7 +73,7 @@ def transform(self, Xname, rName=None, **kwargs):
:return: the data returned by .transform, or the metadata of the rName
dataset if rName was given
"""
-        omega_transform = self.runtime.task('omegaml.tasks.omega_transform')
+        omega_transform = self.task('omegaml.tasks.omega_transform')
Xname = self._ensure_data_is_stored(Xname)
return omega_transform.delay(self.modelname, Xname,
rName=rName,
@@ -93,7 +93,7 @@ def fit_transform(self, Xname, Yname=None, rName=None, **kwargs):
dataset if rName was given
"""

-        omega_fit_transform = self.runtime.task(
+        omega_fit_transform = self.task(
'omegaml.tasks.omega_fit_transform')
Xname = self._ensure_data_is_stored(Xname)
if Yname is not None:
@@ -114,7 +114,7 @@ def predict(self, Xpath_or_data, rName=None, **kwargs):
:return: the data returned by .predict, or the metadata of the rName
dataset if rName was given
"""
-        omega_predict = self.runtime.task('omegaml.tasks.omega_predict')
+        omega_predict = self.task('omegaml.tasks.omega_predict')
Xname = self._ensure_data_is_stored(Xpath_or_data)
return omega_predict.delay(self.modelname, Xname, rName=rName,
**self._common_kwargs, **kwargs)
@@ -131,7 +131,7 @@ def predict_proba(self, Xpath_or_data, rName=None, **kwargs):
:return: the data returned by .predict_proba, or the metadata of the rName
dataset if rName was given
"""
-        omega_predict_proba = self.runtime.task(
+        omega_predict_proba = self.task(
'omegaml.tasks.omega_predict_proba')
Xname = self._ensure_data_is_stored(Xpath_or_data)
return omega_predict_proba.delay(self.modelname, Xname, rName=rName,
@@ -150,7 +150,7 @@ def score(self, Xname, yName, rName=None, **kwargs):
:return: the data returned by .score, or the metadata of the rName
dataset if rName was given
"""
-        omega_score = self.runtime.task('omegaml.tasks.omega_score')
+        omega_score = self.task('omegaml.tasks.omega_score')
Xname = self._ensure_data_is_stored(Xname)
yName = self._ensure_data_is_stored(yName)
return omega_score.delay(self.modelname, Xname, yName, rName=rName,
@@ -168,7 +168,7 @@ def decision_function(self, Xname, rName=None, **kwargs):
:return: the data returned by .score, or the metadata of the rName
dataset if rName was given
"""
-        omega_decision_function = self.runtime.task('omegaml.tasks.omega_decision_function')
+        omega_decision_function = self.task('omegaml.tasks.omega_decision_function')
Xname = self._ensure_data_is_stored(Xname)
return omega_decision_function.delay(self.modelname, Xname, rName=rName,
**self._common_kwargs, **kwargs)
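With every model method resolving its task through self.task(), a single require() call routes all subsequent calls on that proxy. A sketch, assuming datasets 'X' and 'Y' are stored in om.datasets; the queue name is illustrative:

    import omegaml as om

    model = om.runtime.model('mymodel').require(queue='gpu')
    model.fit('X', 'Y')   # routed to the 'gpu' queue
    model.predict('X')    # same routing applies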
omegaml/runtimes/modelproxy.py (31 changes: 31 additions & 0 deletions)
@@ -52,6 +52,7 @@ def __init__(self, modelname, runtime=None):
False)
self.pure_python = self.pure_python or self._client_is_pure_python()
self.apply_mixins()
+        self.task_kwargs = {}

def apply_mixins(self):
"""
@@ -61,6 +62,36 @@ def apply_mixins(self):
for mixin in defaults.OMEGA_RUNTIME_MIXINS:
extend_instance(self, mixin)

+    def task(self, name):
+        """
+        return the task from the runtime with requirements applied
+        """
+        kwargs = self.task_kwargs
+        return self.runtime.task(name, **kwargs)
+
+    def require(self, **kwargs):
+        """
+        specify requirements for the task execution
+
+        Use this to specify resource or routing requirements on the
+        task call sent to the runtime.
+
+        Args:
+            **kwargs: requirements specification that the runtime understands
+
+        Usage:
+            # celery runtime
+            om.runtime.model('foo').require(queue='gpu').fit(...)
+
+            # dask distributed runtime
+            om.runtime.model('foo').require(resources={...}).fit(...)
+
+        Returns:
+            self
+        """
+        self.task_kwargs.update(kwargs)
+        return self
+

def _client_is_pure_python(self):
try:
import pandas as pd
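Note that require() updates task_kwargs and nothing resets it, so requirements persist for the lifetime of the proxy and can be layered across calls. A sketch; queue= and priority= are standard celery apply_async options, used here for illustration:

    proxy = om.runtime.model('mymodel')
    proxy.require(queue='gpu')
    proxy.require(priority=9)  # task_kwargs is now {'queue': 'gpu', 'priority': 9}
    proxy.fit('X', 'Y')        # both options reach apply_async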
omegaml/runtimes/runtime.py (31 changes: 29 additions & 2 deletions)
@@ -6,6 +6,33 @@
from omegaml.util import settings


+class CeleryTask(object):
+    """
+    A thin wrapper for a Celery.Task object
+
+    This is so that we can collect common delay arguments on the
+    .task() call
+    """
+
+    def __init__(self, task, **kwargs):
+        """
+        Args:
+            task (Celery.Task): the celery task object
+            **kwargs (dict): optional, the kwargs to pass to apply_async
+        """
+        self.task = task
+        self.kwargs = kwargs
+
+    def delay(self, *args, **kwargs):
+        """
+        submit the task with args and kwargs to pass on
+
+        This calls task.apply_async and passes on the self.kwargs.
+        """
+        return self.task.apply_async(args=args, kwargs=kwargs, **self.kwargs)


class OmegaRuntime(object):
"""
omegaml compute cluster gateway
@@ -48,7 +75,7 @@ def job(self, jobname):
"""
return OmegaJobProxy(jobname, runtime=self)

-    def task(self, name):
+    def task(self, name, **kwargs):
"""
retrieve the task function from the celery instance
@@ -57,7 +84,7 @@
import, which seems to confuse celery)
"""
# import omegapkg.tasks
-        return self.celeryapp.tasks.get(name)
+        return CeleryTask(self.celeryapp.tasks.get(name), **kwargs)

def settings(self):
"""
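delay() on the wrapper hands the collected kwargs to apply_async, where options such as queue= control routing. A sketch, assuming a worker consumes the 'gpu' queue and passing arguments the way the model proxy does:

    # route a fit task to the 'gpu' queue; apply_async receives queue='gpu'
    task = om.runtime.task('omegaml.tasks.omega_fit', queue='gpu')
    result = task.delay('mymodel', 'X', 'Y')
    print(result.get())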
