Commit
add fm test. fix lint
eric-haibin-lin committed May 22, 2017
1 parent 5f19131 commit a28274a
Showing 6 changed files with 73 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/model.py
@@ -121,7 +121,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
         # cast storage type if stype doesn't match
         if param_names is not None and param_names[i] in stype_dict:
             for i, grad in enumerate(grad_list):
-                stype = stype_dict[name]
+                stype = stype_dict[param_names[i]]
                 if grad_list[i].storage_type != stype:
                     grad_list[i] = nd.cast_storage(grad, stype)
         index = i
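
The fix matters because `name` was undefined in this scope; keying the lookup by `param_names[i]` casts the gradients of the parameter actually being updated. As a standalone illustration, a minimal sketch of the casting step (hypothetical `grad_list` and `stype_dict`; it reuses only the `storage_type` attribute and `nd.cast_storage` call visible in the diff):

import mxnet.ndarray as nd

def cast_grads_for_param(grad_list, param_name, stype_dict):
    # Cast each per-device gradient copy of `param_name` to the
    # storage type requested in `stype_dict`, e.g. 'row_sparse'.
    stype = stype_dict.get(param_name)
    if stype is not None:
        for i, grad in enumerate(grad_list):
            if grad.storage_type != stype:
                grad_list[i] = nd.cast_storage(grad, stype)
    return grad_list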
4 changes: 2 additions & 2 deletions python/mxnet/module/bucketing_module.py
@@ -394,13 +394,13 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._curr_module.backward(out_grads=out_grads)

-    def update(self):
+    def update(self, storage_type_dict=None):
         """Update parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
         self._params_dirty = True
-        self._curr_module.update()
+        self._curr_module.update(storage_type_dict=storage_type_dict)

     def get_outputs(self, merge_multi_context=True):
         """Get outputs from a previous forward computation.
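
With the forwarded keyword, callers can request per-parameter gradient storage types at update time, and `BucketingModule` simply delegates to whichever sub-module is current. A hedged usage sketch (assumes `mod` is a bound, optimizer-initialized module and `batch` a data batch, as in the test added below):

# cast gradients of parameter 'v' to dense storage before the optimizer step
storage_type_dict = {'v': 'default_storage'}
mod.forward(batch, is_train=True)
mod.backward()
mod.update(storage_type_dict=storage_type_dict)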
2 changes: 2 additions & 0 deletions python/mxnet/ndarray.py
@@ -925,9 +925,11 @@ def as_in_context(self, context):
         return self.copyto(context)

     def to_csr(self):
+        # pylint: disable=undefined-variable
         return cast_storage(self, storage_type='csr')

     def to_rsp(self):
+        # pylint: disable=undefined-variable
         return cast_storage(self, storage_type='row_sparse')

 def onehot_encode(indices, out):
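
The disables are needed because `cast_storage` is an operator injected into the `mxnet.ndarray` namespace at import time, so pylint cannot resolve it statically. A small usage sketch of the two converters (storage-type names as used on this sparse development branch):

import mxnet as mx

dense = mx.nd.array([[0, 1, 0], [2, 0, 0]])
csr = dense.to_csr()   # copy with 'csr' storage
rsp = dense.to_rsp()   # copy with 'row_sparse' storage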
1 change: 0 additions & 1 deletion python/mxnet/sparse_ndarray.py
@@ -127,7 +127,6 @@ def __pow__(self, other):
     def __rpow__(self, other):
         raise Exception('Not implemented for SparseND yet!')

-
     def __getstate__(self):
         raise Exception('Not implemented for SparseND yet!')

1 change: 0 additions & 1 deletion src/operator/tensor/matrix_op-inl.h
@@ -1046,7 +1046,6 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
   // assume idx indptr share the same type
   MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), IType, {
     MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
-
       auto in_indptr = in.aux_data(kIndPtr).dptr<IType>();
       auto out_indptr = out.aux_data(kIndPtr).dptr<IType>();
       SliceCsrIndPtrImpl<cpu, IType>(begin, end, ctx.run_ctx, in_indptr, out_indptr);
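
For intuition, `SliceCsrIndPtrImpl` rebases the row-pointer segment of the selected rows so the output slice starts at zero; a NumPy sketch of the same arithmetic (an illustration, not the actual kernel):

import numpy as np

def slice_csr_indptr(in_indptr, begin, end):
    # indptr for rows [begin, end): shift the segment down so that
    # out[0] == 0 and out[-1] == number of nonzeros in the slice
    return in_indptr[begin:end + 1] - in_indptr[begin]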
68 changes: 68 additions & 0 deletions tests/python/unittest/test_module.py
@@ -1,7 +1,10 @@
import mxnet as mx
import mxnet.ndarray as nd
from mxnet.test_utils import *
import numpy as np
from functools import reduce
import numpy.random as rnd
import scipy

def test_module_dtype():
    dtype = np.float16
@@ -255,6 +258,70 @@ def mean_abs(x):
            break
    assert(mon_result_counts == [2, 2, 1, 6, 6, 4])

def test_fm_module():
    def fm_model(k, feature_dim, storage_type='default_storage'):
        initializer = mx.initializer.Normal(sigma=0.01)
        x = mx.symbol.Variable("data", storage_type=storage_type)
        v = mx.symbol.Variable("v", shape=(feature_dim, k), init=initializer)

        w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=initializer)
        w1 = mx.symbol.dot(x, w1_weight)

        v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
        x_s = mx.symbol.square(data=x)
        bd = 0.5 * mx.symbol.negative(data=mx.symbol.broadcast_mul(x_s, v_s))

        w2 = mx.symbol.dot(x, v)
        w2_squared = 0.5 * mx.symbol.square(data=w2)

        w_all = mx.symbol.Concat(w1, w2_squared, bd, dim=1)
        model = mx.symbol.sum(data=w_all, axis=1, keepdims=True)
        y = mx.symbol.Variable("out_label")
        model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
        return model

    ctx = default_context()
    k = 5
    feature_dim = 20
    model = fm_model(k, feature_dim, 'csr')

    num_batches = 8
    batch_size = 25
    scipy_data = scipy.sparse.rand(num_batches * batch_size, feature_dim,
                                   density=0.5, format='csr')
    dns_label = mx.nd.ones((num_batches * batch_size, 1))
    csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices,
                                (num_batches * batch_size, feature_dim))
    data = csr_data

    train_iter = mx.io.NDArrayIter(data=data,
                                   label={'out_label': dns_label},
                                   batch_size=batch_size)

    # create module
    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label'])
    # allocate memory given the input data and label shapes
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    # initialize parameters with uniform random numbers
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    # train with SGD
    mod.init_optimizer(optimizer='sgd')
    # use MSE as the metric
    metric = mx.metric.create('MSE')
    # train 10 epochs, i.e. make 10 passes over the data iterator
    # TODO(haibin) test with row_sparse instead
    storage_type_dict = {'v': 'default_storage'}

    for epoch in range(10):
        train_iter.reset()
        metric.reset()
        for batch in train_iter:
            mod.forward(batch, is_train=True)       # compute predictions
            mod.update_metric(metric, batch.label)  # accumulate the MSE metric
            mod.backward()                          # compute gradients
            mod.update(storage_type_dict)           # update parameters
        print('Epoch %d, Training %s' % (epoch, metric.get()))
if __name__ == '__main__':
    test_module_dtype()
    test_module_input_grads()
@@ -264,3 +331,4 @@ def mean_abs(x):
    test_module_layout()
    test_module_switch_bucket()
    test_monitor()
    test_fm_module()
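
For reference, the symbols `w1`, `w2_squared`, and `bd` in the test assemble a degree-2 factorization machine (without a global bias), using the standard identity that rewrites the pairwise interaction term as a difference of squares so it costs O(k * feature_dim) instead of O(feature_dim^2):

\hat{y}(x) = \sum_i w_i x_i + \frac{1}{2} \sum_{f=1}^{k} \Big[ \big( \textstyle\sum_i v_{i,f} x_i \big)^2 - \sum_i v_{i,f}^2 x_i^2 \Big]

In the code, `w2_squared` is the first bracketed term, `bd` the negated second, and `Concat` followed by `sum(axis=1, keepdims=True)` collapses everything into one prediction per example.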
