register sparse sgd under Optim.SGD
eric-haibin-lin committed May 22, 2017
1 parent a28274a commit 838efc0
Showing 6 changed files with 33 additions and 99 deletions.
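In user code, this commit means row-sparse SGD is reached through the regular `mx.optimizer.SGD` class rather than a separate `SparseSGD` class; the storage type of the gradient decides which kernel runs. A minimal usage sketch (dense gradient shown; constructing a `row_sparse` gradient is assumed to follow the sparse NDArray API of this branch and is not spelled out here):

    import mxnet as mx

    opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9)
    weight = mx.nd.ones((3, 4))
    grad = mx.nd.ones((3, 4)) * 0.5        # a row_sparse gradient would take the sparse path
    state = opt.create_state(0, weight)    # momentum buffer
    opt.update(0, weight, grad, state)     # dispatches to sgd_mom_update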
30 changes: 0 additions & 30 deletions python/mxnet/optimizer.py
@@ -4,7 +4,6 @@
import logging
from .ndarray import NDArray, zeros, clip, sqrt, sign
from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update
from .ndarray import sparse_sgd_update, sparse_sgd_mom_update
from .random import normal


@@ -356,35 +355,6 @@ def update(self, index, weight, grad, state):
sgd_update(weight, grad, out=weight,
lr=lr, wd=wd, **kwargs)


@register
class SparseSGD(SGD):
    """SGD for non-zero rows
    """
    def __init__(self, **kwargs):
        super(SparseSGD, self).__init__(**kwargs)

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        self._update_count(index)

        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        if state is not None:
            sparse_sgd_mom_update(weight, grad, state, out=weight,
                                  lr=lr, wd=wd, **kwargs)
        else:
            sparse_sgd_update(weight, grad, out=weight,
                              lr=lr, wd=wd, **kwargs)


@register
class DCASGD(Optimizer):
"""The DCASGD optimizer
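With `SparseSGD` gone from this file, the retained `SGD.update` (partially visible in the context lines above) is the single entry point; it builds the same keyword arguments and calls `sgd_update` / `sgd_mom_update`, which after this commit also accept row-sparse gradients. A rough standalone sketch of that surviving dispatch, reconstructed from the context rather than copied from the file:

    from mxnet.ndarray import sgd_update, sgd_mom_update

    def sgd_step(weight, grad, state, lr, wd, rescale_grad=1.0,
                 momentum=0.0, clip_gradient=None):
        # Mirrors the retained SGD.update logic (sketch, not the real method).
        kwargs = {'rescale_grad': rescale_grad}
        if momentum > 0:
            kwargs['momentum'] = momentum
        if clip_gradient:
            kwargs['clip_gradient'] = clip_gradient
        if state is not None:
            sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
        else:
            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)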
1 change: 0 additions & 1 deletion src/operator/operator_common.h
@@ -330,7 +330,6 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs,
std::vector<NDArray> tmps;
common::GetInputBlobs<xpu>(inputs, &in_blobs, &tmps, ctx);
common::GetOutputBlobs<xpu>(outputs, &out_blobs);
LOG(INFO) << "Warning: fallback to default storage for " << fname;
fcompute(attrs, ctx, in_blobs, req, out_blobs);
}

39 changes: 20 additions & 19 deletions src/operator/optimizer_op-inl.h
@@ -87,7 +87,7 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs,
/*! \brief kernel for sparse sgd
*/
template<int req>
struct SparseSGDDnsRspKernel {
struct SGDDnsRspKernel {
// DType is the output data type
// IType is row sparse idx type
// i is the ith row in row sparse gradient
@@ -110,9 +110,8 @@ struct SparseSGDDnsRspKernel {
}
};

// Impl implies a different interface than FComputeEx
template<typename xpu>
inline void SparseSGDUpdateDnsRspImpl(const SGDParam& param,
inline void SGDUpdateDnsRspImpl(const SGDParam& param,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -137,7 +136,7 @@ inline void SparseSGDUpdateDnsRspImpl(const SGDParam& param,
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
mxnet_op::Kernel<SparseSGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
mxnet_op::Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
@@ -148,7 +147,7 @@ inline void SparseSGDUpdateDnsRspImpl(const SGDParam& param,
}

template<typename xpu>
inline void SparseSGDUpdateEx(const nnvm::NodeAttrs& attrs,
inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -160,9 +159,9 @@ inline void SparseSGDUpdateEx(const nnvm::NodeAttrs& attrs,
auto weight_stype = inputs[0].storage_type();
auto grad_stype = inputs[1].storage_type();
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) {
SparseSGDUpdateDnsRspImpl<xpu>(param, ctx, inputs, req, outputs);
} else {
LOG(FATAL) << "Not implemented";
SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs, req, outputs);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
}
}
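The kernel body is collapsed in the hunk above, so the exact arithmetic is not shown here; the sketch below assumes `SGDDnsRspKernel` applies the standard MXNet SGD formula to only the rows listed in the gradient's index array:

    import numpy as np

    def row_sparse_sgd(weight, grad_idx, grad_val, lr, wd,
                       rescale_grad=1.0, clip_gradient=None):
        # Update only the rows present in the row_sparse gradient (assumed semantics).
        for i, row in enumerate(grad_idx):
            g = rescale_grad * grad_val[i]
            if clip_gradient is not None:
                g = np.clip(g, -clip_gradient, clip_gradient)
            weight[row] = (1.0 - lr * wd) * weight[row] - lr * g
        return weight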

@@ -236,7 +235,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
}

template<int req>
struct SparseSGDMomDnsRspDnsKernel {
struct SGDMomDnsRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data,
DType* mom_data, const DType* weight_data, const IType* grad_idx,
@@ -262,7 +261,7 @@ struct SparseSGDMomDnsRspDnsKernel {
};

template<typename xpu>
inline void SparseSGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -285,7 +284,7 @@ inline void SparseSGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
Kernel<SparseSGDMomDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, width,
Kernel<SGDMomDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, width,
out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
@@ -296,11 +295,11 @@ inline void SparseSGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
}

template<typename xpu>
inline void SparseSGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mxnet_op;
const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
auto weight_stype = inputs[0].storage_type();
@@ -309,9 +308,11 @@ inline void SparseSGDMomUpdateEx(const nnvm::NodeAttrs& attrs,

if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
mom_stype == kDefaultStorage) {
SparseSGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs, req, outputs);
} else {
LOG(FATAL) << "Not implemented";
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs, req, outputs);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
mom_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,
SGDMomUpdate<xpu>, "SGDMomUpdate");
}
}
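As with the plain update, the momentum kernel body is elided above; the sketch below assumes the usual MXNet momentum rule (mom = momentum*mom - lr*(grad + wd*weight); weight += mom) restricted to rows present in the row-sparse gradient:

    import numpy as np

    def row_sparse_sgd_mom(weight, mom, grad_idx, grad_val, lr, wd, momentum,
                           rescale_grad=1.0, clip_gradient=None):
        # Weight and momentum rows are touched only where the gradient has entries.
        for i, row in enumerate(grad_idx):
            g = rescale_grad * grad_val[i]
            if clip_gradient is not None:
                g = np.clip(g, -clip_gradient, clip_gradient)
            mom[row] = momentum * mom[row] - lr * (g + wd * weight[row])
            weight[row] += mom[row]
        return weight, mom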

48 changes: 8 additions & 40 deletions src/operator/optimizer_op.cc
@@ -22,33 +22,17 @@ It updates the weights using::
weight = weight - learning_rate * gradient
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SGDParam::__FIELDS__());

NNVM_REGISTER_OP(sparse_sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer.
It updates the weights using::
weight = weight - learning_rate * gradient for non-zero rows
If gradients are stored with `row_sparse` storage,
where update is applied only to rows whose gradient has non-zero entries.
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// TODO(haibin) implement FCompute for sparse sgd
// .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SparseSGDUpdateEx<cpu>)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SGDParam::__FIELDS__());
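Because `sgd_update` now carries both an `FCompute` and an `FComputeEx` registration, the same imperative call covers dense and row-sparse gradients; a small sketch of driving the registered operator directly (dense case shown, the sparse case differs only in the gradient's storage type):

    import mxnet as mx

    weight = mx.nd.ones((3, 4))
    grad = mx.nd.ones((3, 4)) * 0.5
    # In-place SGD step through the registered operator.
    mx.nd.sgd_update(weight, grad, out=weight, lr=0.1, wd=0.0, rescale_grad=1.0)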
@@ -72,24 +56,9 @@ It updates the weights using::
Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(SGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(sparse_sgd_mom_update)
.describe(R"code(Momentum update function for SGD for non-zero gradients.
If gradients are stored with `row_sparse` storage,
only rows whose gradients contain non-zero entries are updated (for both weight and momentum).
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
@@ -100,9 +69,8 @@ NNVM_REGISTER_OP(sparse_sgd_mom_update)
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
// TODO(haibin) implement FCompute
// .set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SparseSGDMomUpdateEx<cpu>)
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SGDMomUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
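The momentum operator gets the same dual registration, so a direct call looks like the plain case with an extra momentum buffer that is mutated in place (sketch under the same assumptions as above):

    import mxnet as mx

    weight = mx.nd.ones((3, 4))
    grad = mx.nd.ones((3, 4)) * 0.5
    mom = mx.nd.zeros((3, 4))              # momentum state, updated in place
    mx.nd.sgd_mom_update(weight, grad, mom, out=weight,
                         lr=0.1, momentum=0.9, wd=0.0, rescale_grad=1.0)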
12 changes: 4 additions & 8 deletions src/operator/optimizer_op.cu
@@ -10,16 +10,12 @@ namespace mxnet {
namespace op {

NNVM_REGISTER_OP(sgd_update)
.set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>);
.set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SGDUpdateEx<gpu>);

NNVM_REGISTER_OP(sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", SGDMomUpdate<gpu>);

NNVM_REGISTER_OP(sparse_sgd_update)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SparseSGDUpdateEx<gpu>);

NNVM_REGISTER_OP(sparse_sgd_mom_update)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SparseSGDMomUpdateEx<gpu>);
.set_attr<FCompute>("FCompute<gpu>", SGDMomUpdate<gpu>)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SGDMomUpdateEx<gpu>);

NNVM_REGISTER_OP(adam_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>);
2 changes: 1 addition & 1 deletion tests/python/unittest/test_optimizer.py
@@ -213,7 +213,7 @@ def update(self, index, weight, grad, state):
def test_sparse_sgd():
    mx.random.seed(0)
    opt1 = PySparseSGD
    opt2 = mx.optimizer.SparseSGD
    opt2 = mx.optimizer.SGD
    shape = (3, 4)
    kwargs = [{},
              {'momentum': 0.9},
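The test now compares the pure-Python reference optimizer against `mx.optimizer.SGD` directly. The comparison harness is not shown in this hunk; a rough sketch of the pattern it presumably follows (the real helper in test_optimizer.py may differ in name and signature):

    import numpy as np
    import mxnet as mx

    def compare_optimizers(opt_a, opt_b, shape, seed=0):
        # Run both optimizers on identical inputs and check the updated weights agree.
        np.random.seed(seed)
        w = np.random.rand(*shape).astype('float32')
        g = np.random.rand(*shape).astype('float32')
        w1, w2 = mx.nd.array(w), mx.nd.array(w)
        g1, g2 = mx.nd.array(g), mx.nd.array(g)
        s1, s2 = opt_a.create_state(0, w1), opt_b.create_state(0, w2)
        opt_a.update(0, w1, g1, s1)
        opt_b.update(0, w2, g2, s2)
        assert np.allclose(w1.asnumpy(), w2.asnumpy(), atol=1e-5)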
