expose group2ctx to module (apache#8539)
* expose group2ctx to module

* Update test_module.py

* address comments

* update
ZiyueHuang authored and eric-haibin-lin committed Dec 3, 2017
1 parent 6720736 commit eef13cf
Showing 4 changed files with 79 additions and 7 deletions.
9 changes: 6 additions & 3 deletions python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
state_names : list of str
States are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by set_states()
group2ctxs : list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None):
fixed_param_names=None, state_names=None, group2ctxs=None):
super(BucketingModule, self).__init__(logger=logger)

assert default_bucket_key is not None
@@ -77,6 +79,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
self._state_names = state_names
self._context = context
self._work_load_list = work_load_list
self._group2ctxs = group2ctxs

self._buckets = {}
self._curr_module = None
@@ -319,7 +322,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
self._curr_module = module
@@ -349,7 +352,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
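
As a clarification of the new `group2ctxs` parameter documented above, here is a minimal sketch (not part of the diff): `group2ctxs` is a list with one dict per device in `context`, and each dict maps a `ctx_group` name, as tagged with `mx.AttrScope`, to the context that group should be placed on for that device. The symbol, shapes, and CPU contexts below are illustrative only.

    import mxnet as mx

    # Tag parts of the graph with ctx_group attributes.
    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
        c = a + b

    # One dict per entry in `context` (two data-parallel devices here); each dict
    # maps a ctx_group name to the context that group is placed on for that device.
    mod = mx.mod.Module(c, data_names=['a', 'b'], label_names=None,
                        context=[mx.cpu(0), mx.cpu(1)],
                        group2ctxs=[{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)},
                                    {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}])
    mod.bind(data_shapes=[('a', (4, 5)), ('b', (4, 5))])
    mod.init_params()

The same list is accepted by `BucketingModule` and simply forwarded to each per-bucket `Module`, as the change above shows.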
12 changes: 10 additions & 2 deletions python/mxnet/module/executor_group.py
@@ -139,17 +139,23 @@ class DataParallelExecutorGroup(object):
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
group2ctxs : list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
for_training, inputs_need_grad, shared_group=None, logger=logging,
fixed_param_names=None, grad_req='write', state_names=None):
fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
self.param_names = param_names
self.arg_names = symbol.list_arguments()
self.aux_names = symbol.list_auxiliary_states()

self.symbol = symbol
self.contexts = contexts
self.workload = workload
if group2ctxs is None:
group2ctxs = [None] * len(self.contexts)
assert len(group2ctxs) == len(self.contexts)
self.group2ctxs = group2ctxs

self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})

group2ctx = self.group2ctxs[i]

executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, shared_arg_names=self.param_names,
shared_exec=shared_exec,
shared_exec=shared_exec, group2ctx=group2ctx,
shared_buffer=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor
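
A minimal standalone sketch (not part of the diff) of the mechanism the executor group now delegates to: `Symbol.simple_bind` accepts a `group2ctx` dict and places operators tagged with a matching `ctx_group` attribute on the mapped context, with everything else staying on `ctx`. The symbol, shapes, and CPU contexts are illustrative.

    import mxnet as mx

    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
    c = a + b

    # cpu(0) is the default context; operators tagged 'dev1' are placed on cpu(1)
    # and those tagged 'dev2' on cpu(2).
    exe = c.simple_bind(ctx=mx.cpu(0),
                        group2ctx={'dev1': mx.cpu(1), 'dev2': mx.cpu(2)},
                        a=(2, 5), b=(2, 5))
    exe.forward(is_train=False, a=mx.nd.ones((2, 5)), b=mx.nd.ones((2, 5)))
    print(exe.outputs[0].asnumpy())

In the change above, executor `i` (one per device in `contexts`) receives `group2ctxs[i]` as its `group2ctx`, which is why `group2ctxs` must have the same length as `contexts`.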
8 changes: 6 additions & 2 deletions python/mxnet/module/module.py
@@ -59,10 +59,12 @@ class Module(BaseModule):
state_names : list of str
states are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by `set_states()`.
group2ctxs : list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None):
fixed_param_names=None, state_names=None, group2ctxs=None):
super(Module, self).__init__(logger=logger)

if isinstance(context, ctx.Context):
@@ -73,6 +75,8 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
assert len(work_load_list) == len(self._context)
self._work_load_list = work_load_list

self._group2ctxs = group2ctxs

self._symbol = symbol

data_names = list(data_names) if data_names is not None else []
@@ -413,7 +417,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
for_training, inputs_need_grad,
shared_group, logger=self.logger,
fixed_param_names=self._fixed_param_names,
grad_req=grad_req,
grad_req=grad_req, group2ctxs=self._group2ctxs,
state_names=self._state_names)
self._total_exec_bytes = self._exec_group._total_exec_bytes
if shared_module is not None:
57 changes: 57 additions & 0 deletions tests/python/unittest/test_module.py
@@ -70,6 +70,63 @@ def test_module_input_grads():
assert np.all(c_grad == 3), c_grad


def test_module_ctx_group():
with mx.AttrScope(ctx_group='dev1'):
a = mx.symbol.Variable('a')
a = a * 2
with mx.AttrScope(ctx_group='dev2'):
b = mx.symbol.Variable('b')
c = a + b
shape = (2, 5)
mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod1.init_params()
mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod1.backward([mx.nd.ones(shape)])
mod1_input_grads = mod1.get_input_grads()

mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod2.init_params()
mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod2.backward([mx.nd.ones(shape)])
mod2_input_grads = mod2.get_input_grads()

assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())


def test_bucket_module_ctx_group():
num_hidden = 10
batch_size = 5
def sym_gen(seq_len):
with mx.AttrScope(ctx_group='dev1'):
data = mx.symbol.Variable('data')
weight = mx.symbol.Variable('dev1_weight')
bias = mx.symbol.Variable('dev1_bias')
fc = data
for i in range(seq_len):
fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
name='dev1_fc_%d' % i, num_hidden=num_hidden)
with mx.AttrScope(ctx_group='dev2'):
label = mx.symbol.Variable('label')
weight = mx.symbol.Variable('dev2_weight')
bias = mx.symbol.Variable('dev2_bias')
for i in range(seq_len):
fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
name='dev2_fc_%d' % i, num_hidden=num_hidden)
sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')

return sym, ('data',), ('label',)

mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
label_shapes=[['label', (batch_size,)]],
for_training=True, inputs_need_grad=True)
assert(mod.binded)

def test_module_layout():
sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
