From e6b063b8e3b51cc2d7ccb8a639e5f966448c08a9 Mon Sep 17 00:00:00 2001 From: "Bradley M. Froehle" Date: Wed, 22 Aug 2012 13:23:20 -0700 Subject: [PATCH] Parallel: Support get/set of nested objects in view (e.g. dv['a.b']) --- IPython/parallel/tests/test_view.py | 18 ++++++++++++++++++ IPython/parallel/util.py | 21 ++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py index 3458f730534..3858d2c90a4 100644 --- a/IPython/parallel/tests/test_view.py +++ b/IPython/parallel/tests/test_view.py @@ -675,4 +675,22 @@ def test_map_ref(self): drank = amr.get(5) self.assertEqual(drank, [ r*2 for r in ranks ]) + def test_nested_getitem_setitem(self): + """get and set with view['a.b']""" + view = self.client[-1] + view.execute('\n'.join([ + 'class A(object): pass', + 'a = A()', + 'a.b = 128', + ]), block=True) + ra = pmod.Reference('a') + + r = view.apply_sync(lambda x: x.b, ra) + self.assertEqual(r, 128) + self.assertEqual(view['a.b'], 128) + + view['a.b'] = 0 + r = view.apply_sync(lambda x: x.b, ra) + self.assertEqual(r, 0) + self.assertEqual(view['a.b'], 0) diff --git a/IPython/parallel/util.py b/IPython/parallel/util.py index 0e02fed79f1..352dc819a81 100644 --- a/IPython/parallel/util.py +++ b/IPython/parallel/util.py @@ -229,21 +229,24 @@ def interactive(f): @interactive def _push(**ns): """helper method for implementing `client.push` via `client.apply`""" - globals().update(ns) + user_ns = globals() + tmp = '_IP_PUSH_TMP_' + while tmp in user_ns: + tmp = tmp + '_' + try: + for name, value in ns.iteritems(): + user_ns[tmp] = value + exec "%s = %s" % (name, tmp) in user_ns + finally: + user_ns.pop(tmp, None) @interactive def _pull(keys): """helper method for implementing `client.pull` via `client.apply`""" - user_ns = globals() if isinstance(keys, (list,tuple, set)): - for key in keys: - if key not in user_ns: - raise NameError("name '%s' is not defined"%key) - return map(user_ns.get, keys) + return map(lambda key: eval(key, globals()), keys) else: - if keys not in user_ns: - raise NameError("name '%s' is not defined"%keys) - return user_ns.get(keys) + return eval(keys, globals()) @interactive def _execute(code):