Skip to content

Commit

Permalink
adding an option to pass a callback function while initializing remot…
Browse files Browse the repository at this point in the history
…e_pyro_server in regression framework
  • Loading branch information
gesta81 committed Aug 29, 2014
1 parent 903e1bd commit 975ef58
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 14 deletions.
7 changes: 7 additions & 0 deletions bin/mdf_pyro_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
import logging
import pickle
import marshal
import types

_startup_data = None
if __name__ == "__main__":
Expand All @@ -18,6 +20,11 @@

for modulename in _startup_data.get("modules", []):
__import__(modulename)
init_func_s = _startup_data.get("init_func", None)
if init_func_s is not None:
init_func_code = marshal.loads(init_func_s)
init_func = types.FunctionType(init_func_code, globals(), "_mdf_pyro_server_custom_init_func")
init_func(_startup_data)

# these imports are deliberately after the --fork code as sys.path could be modified
import mdf.remote
Expand Down
51 changes: 37 additions & 14 deletions mdf/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import getpass
import logging
import atexit
import marshal
import pkg_resources
import tempfile
import types
Expand Down Expand Up @@ -83,7 +84,8 @@ def _fh_redirect(fh_in, fh_out, prefix):
break
fh_out.write(prefix + " : " + line.strip("\r\n") + "\n")

def _start_pyro_subprocess(python_exe, side, modulenames=[]):
def _start_pyro_subprocess(python_exe, side, modulenames=[],
init_func=None, startup_data={}):
"""
starts a mdf.remote pyro server and returns
a Pyro Proxy object.
Expand All @@ -93,16 +95,20 @@ def _start_pyro_subprocess(python_exe, side, modulenames=[]):
side is for information and is "LHS" or "RHS"
"""
# dictionary of settings to be passed to the child process
start_data = {
_start_data = {
"log_level" : logging.getLogger().getEffectiveLevel(),
"modules" : modulenames,
}
if init_func is not None:
_start_data["init_func"] = marshal.dumps(init_func.func_code)
_start_data.update(startup_data)


if python_exe is None:
# use the current interpreter and environment
python_exe = sys.executable
env = os.environ
start_data["pythonpath"] = sys.path
_start_data["pythonpath"] = sys.path

# get the script from the metadata and write it to a tempfile
dist = pkg_resources.require("mdf")[0]
Expand Down Expand Up @@ -156,7 +162,7 @@ def _start_pyro_subprocess(python_exe, side, modulenames=[]):
stderr_thread.start()

# send the data to the new process and get the result from the pipe
child_process.stdin.write(pickle.dumps(start_data))
child_process.stdin.write(pickle.dumps(_start_data))
child_process.stdin.close()

# read the URI from the child stdout
Expand All @@ -180,13 +186,15 @@ def _start_pyro_subprocess(python_exe, side, modulenames=[]):

return Pyro4.Proxy(uri)

def _get_context(virtualenv, ctx, side, executable=None, modulenames=[]):
def _get_context(virtualenv, ctx, side, executable=None, modulenames=[],
init_func=None, startup_data={}):
# get the executables from the virtualenvs
if executable is None and virtualenv is not None:
executable = _python_exes[sys.platform] % virtualenv

# use subprocess to start new processes using the virtualenv python.exe
server = _start_pyro_subprocess(executable, side, modulenames)
server = _start_pyro_subprocess(executable, side, modulenames,
init_func=init_func, startup_data=startup_data)

# create the remote context and return
ctx = server.get_remote_context(ctx)
Expand All @@ -196,7 +204,8 @@ def _get_context(virtualenv, ctx, side, executable=None, modulenames=[]):
def get_contexts(lhs_virtualenv, rhs_virtualenv,
lhs_executable=None, rhs_executable=None,
lhs_modulenames=[], rhs_modulenames=[],
ctx=None):
ctx=None,
init_func=None, startup_data={}):
"""
returns a tuple of remote contexts using different
virtualenvs or python executables.
Expand All @@ -212,7 +221,11 @@ def get_contexts(lhs_virtualenv, rhs_virtualenv,
instances and used as the base context. For that to work
it must only contain nodes that are available in both
environments.
"""
init_func is a function which will be called while starting mdf_pyro_server
startup_data - dict of data to be passed to init_func
Note: startup_data is extended with modules, pythonpath etc before passing to init_func
"""
# get the executables from the virtualenvs
if lhs_executable is None and lhs_virtualenv is not None:
lhs_executable = _python_exes[sys.platform] % lhs_virtualenv
Expand All @@ -226,11 +239,13 @@ def get_contexts(lhs_virtualenv, rhs_virtualenv,
# use subprocess to start two new processes using the virtualenv python.exe
lhs_promise = Pyro4.Future(_start_pyro_subprocess)(lhs_executable,
side="LHS",
modulenames=lhs_modulenames)
modulenames=lhs_modulenames,
init_func=init_func, startup_data=startup_data)

rhs_promise = Pyro4.Future(_start_pyro_subprocess)(rhs_executable,
side="RHS",
modulenames=rhs_modulenames)
modulenames=rhs_modulenames,
init_func=init_func, startup_data=startup_data)

lhs_server = lhs_promise.value
rhs_server = rhs_promise.value
Expand All @@ -245,7 +260,8 @@ def get_contexts(lhs_virtualenv, rhs_virtualenv,
return lhs_ctx, rhs_ctx

def run(date_range, differs, lhs, rhs, filter=None, ctx=None,
lhs_modulenames=[], rhs_modulenames=[]):
lhs_modulenames=[], rhs_modulenames=[],
init_func=None, startup_data={}):
"""
evaluates the 'differ' objects in two contexts, lhs and rhs.
Expand All @@ -266,6 +282,10 @@ def run(date_range, differs, lhs, rhs, filter=None, ctx=None,
lhs_modulenames and rhs_modulenames are list of modules that
need to be imported on the remote sides.
init_func is a function which will be called while starting mdf_pyro_server
startup_data - dict of data to be passed to init_func
Note: startup_data is extended with modules, pythonpath etc before passing to init_func
"""
if ctx is None:
ctx = MDFContext()
Expand All @@ -278,15 +298,18 @@ def run(date_range, differs, lhs, rhs, filter=None, ctx=None,
# get both remote contexts from the virtual env name
lhs, rhs = get_contexts(lhs, rhs, ctx=ctx,
lhs_modulenames=lhs_modulenames,
rhs_modulenames=rhs_modulenames)
rhs_modulenames=rhs_modulenames,
init_func=init_func, startup_data=startup_data)
shutdown_lhs = shutdown_rhs = True
elif isinstance(lhs, (basestring, NoneType)):
# only need the lhs, rhs must already be a remote ctx
lhs = _get_context(lhs, ctx, side="LHS", modulenames=lhs_modulenames)
lhs = _get_context(lhs, ctx, side="LHS", modulenames=lhs_modulenames,
init_func=init_func, startup_data=startup_data)
shutdown_lhs = True
elif isinstance(rhs, (basestring, NoneType)):
# only need the rhs, lhs must already be a remote ctx
rhs = _get_context(rhs, ctx, side="RHS", modulenames=rhs_modulenames)
rhs = _get_context(rhs, ctx, side="RHS", modulenames=rhs_modulenames,
init_func=init_func, startup_data=startup_data)
shutdown_rhs = True

lhs_ctx, rhs_ctx = lhs, rhs
Expand Down
20 changes: 20 additions & 0 deletions mdf/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
def pid_test():
return os.getpid()

# used in test_regression_remnote_server_init
startup_data = {"cfg":{"paramA":"A"}}
def remote_server_init_func(startup_data):
"""
startup_data is a dict constructed by _start_pyro_subprocess
which will be passed to this callback function on the remote process.
startup_data will contain additional startup_data passed to mdf.regression.[get_contexts|run]
"""
_cfg = startup_data["cfg"]
assert _cfg["paramA"], "A"

class RemoteTest(unittest.TestCase):

def test_regression_contexts(self):
Expand All @@ -27,6 +38,15 @@ def test_regression_contexts(self):

self.assertNotEqual(lhs_pid, rhs_pid)

def test_regression_remnote_server_init_func(self):
"""
simple test that creates two subprocesses and checks the
pids are different
"""
lhs, rhs = mdf.regression.get_contexts(None, None,
init_func=remote_server_init_func,
startup_data=startup_data)

def test_df_differ(self):
"""
tests the DataFrameDiffer
Expand Down

0 comments on commit 975ef58

Please sign in to comment.