
CPP implementation of L2Norm and LRN ops #1157

Merged
merged 12 commits into apache:master Jun 22, 2018

Conversation

PariksheetPinjari909 (Contributor)

This PR adds a C++ implementation of LRN and L2Norm. It also redirects the LRN and L2Norm Python ops to the C++ ops.
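For readers unfamiliar with the op, across-channel LRN can be summarized with a small numpy sketch of the semantics (this mirrors the usual parameterization with size, bias, alpha, beta; note some frameworks fold the 1/size factor into alpha):

import numpy as np

def lrn_ref(a, size, bias, alpha, beta, axis=1):
    # out = a / (bias + alpha/size * windowed sum of squares)^beta;
    # the window of width `size` slides along `axis`, clamped at the edges.
    radius = size // 2
    n = a.shape[axis]
    sqr_sum = np.zeros_like(a)
    for c in range(n):
        lo, hi = max(0, c - radius), min(n, c + radius + 1)
        idx = [slice(None)] * a.ndim
        win = [slice(None)] * a.ndim
        idx[axis] = c
        win[axis] = slice(lo, hi)
        sqr_sum[tuple(idx)] = np.sum(a[tuple(win)] ** 2, axis=axis)
    return a / np.power(bias + alpha * sqr_sum / size, beta)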

@tqchen (Member) commented May 11, 2018

@sxjscience can you please do a round of code review?

@PariksheetPinjari909 (Contributor Author)

I have added nnvm frontend support for the lrn and l2norm ops. Please review.

@PariksheetPinjari909 (Contributor Author)

@tqchen, I am facing a build error only in the i386 environment, in source code unrelated to this PR:

make: *** No rule to make target 'topi/include/topi/nn/scale.h', needed by 'build/topi/topi.o'. Stop.

(scale.h is not touched by this pull request.) Could you please trigger a clean build to check whether it is an environment issue?

@tqchen (Member) commented Jun 4, 2018

@kazum @merrymercy can you help review this PR?

DMLC_DECLARE_FIELD(axis)
.describe("input data layout channel axis");
DMLC_DECLARE_FIELD(alpha)
.describe("alpha constant.");
Contributor:

The scaling parameter.

Contributor Author:

ok

DMLC_DECLARE_FIELD(alpha)
.describe("alpha constant.");
DMLC_DECLARE_FIELD(beta)
.describe("beta constant.");
Contributor:

The exponent number.

Contributor Author:

ok

DMLC_DECLARE_FIELD(beta)
.describe("beta constant.");
DMLC_DECLARE_FIELD(bias)
.describe("bias constant.");
Contributor:

The offset parameter.

Contributor Author:

ok

import tvm
import topi
Contributor:

Do we need this change?

Contributor Author:

Lint is suggesting this change.

offset to avoid dividing by 0. constant value

alpha : float
contant valie
Contributor:

constant value

auto lrn = outs[0];
auto sqr_sum_up = lrn->op->InputTensors()[1];
auto sqr_sum = sqr_sum_up->op->InputTensors()[0];
auto set_pad = sqr_sum->op->InputTensors()[0];

@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2017 by Contributors
* \brief NN op constructions
* \file topi/nn.h
* \file
Contributor:

* \file nn.h?

using namespace tvm;

/*!
* \brief L2 normalization inference operator
Contributor:

Remove a trailing space.

std::string tag = kBroadcast) {
CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
assert(size % 2 == 1);
assert(axis == 1 || axis == 3);
Contributor:

I think we should use CHECK_* macros here because assert() can be compiled out when we define NDEBUG.

@PariksheetPinjari909 (Contributor Author)

@tqchen I am facing the same error again:
make: *** No rule to make target 'topi/include/topi/nn/scale.h', needed by 'build/topi/topi.o'. Stop.

@@ -313,6 +313,41 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
}
};

struct LrnParam : public dmlc::Parameter<LrnParam> {
Member:

Lrn -> LRN

DMLC_REGISTER_PARAMETER(L2normParam);

inline bool L2normInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
Member:

alignment

return true;
}

NNVM_REGISTER_OP(l2norm)
Member:

l2norm is not a typical operator; the typical version is numpy.linalg.norm, so I would recommend we do not add this registration and instead support a proper norm in a separate PR.

Contributor Author:

Okay, it is a good idea to have a generalized norm operation. I will remove L2Norm from this PR and raise another PR for the generalized norm operation.

@tqchen (Member) Jun 8, 2018:

Sorry, I was confused by the name of the API, the current API is l2_normalize which performs the normalization, instead of calculating the norm.

Let us make the name clear in both TOPI and nnvm.
c.f. related tensorflow API https://www.tensorflow.org/api_docs/python/tf/nn/l2_normalize

Contributor Author:

The L2 norm operation can be used to perform L2 normalization. If we are planning to add a generalized norm op, we can use it to compute L2 normalization as well.
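To make the distinction being discussed concrete, here is a small numpy sketch (the eps value is arbitrary):

import numpy as np

x = np.random.uniform(size=(1, 3, 4, 4)).astype("float32")

# "norm" reduces the tensor to a number (what numpy.linalg.norm computes):
n = np.sqrt(np.sum(x * x))

# "normalize" keeps the input shape and rescales each element
# (what tf.nn.l2_normalize computes):
y = x / np.sqrt(np.maximum(np.sum(x * x), 1e-10))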

sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
a_np[i, j, k, sum_start:sum_end])

for i in range(axis0):
Member:

use broadcasting semantics

Contributor Author:

@tqchen, I have tried using broadcast operations for this change, but here the sum is computed over a window of the given size, and the window moves across the axis. Could you please elaborate on your suggestion?

Member:

OK, never mind, it is fine to keep this one as a for loop.
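For reference, most of the computation could still be vectorized by looping only over the window offsets rather than over every element. A sketch (not what this PR does), assuming a clamped window along `axis`:

import numpy as np

def windowed_sqr_sum(a_np, size, axis=1):
    # accumulate shifted slices: one full-tensor add per window offset
    radius = size // 2
    n = a_np.shape[axis]
    sqr = a_np * a_np
    out = np.zeros_like(a_np)
    for shift in range(-radius, radius + 1):
        src = [slice(None)] * a_np.ndim
        dst = [slice(None)] * a_np.ndim
        src[axis] = slice(max(0, shift), min(n, n + shift))
        dst[axis] = slice(max(0, -shift), min(n, n - shift))
        out[tuple(dst)] += sqr[tuple(src)]
    return out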

*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2norm(const Target &target, const Array<Tensor>& outs) {
Member:

Remove L2 norm for now; it can be done in a separate PR, with norm support later.

@@ -365,6 +365,129 @@ def forward(x):
inputs = [('x', (1, 3, 28, 28), x)]
helper(y, inputs, dtype, forward)

def verify_lrn(n, c, h, w, size, axis, bias, alpha, beta):
Contributor:

def verify_lrn(ishape, size, axis, bias, alpha, beta) and using ishape instead of dshape looks simpler.

offset to avoid dividing by 0. constant value

alpha : float
contant value
Contributor:

constant

radius = size // 2
sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
Contributor:

This line can be removed.

out_np = lrn_python(x_np, size, axis, bias, alpha, beta)
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def verify_l2norm(batch, channel, height, width, eps, axis):
Contributor:

def verify_l2norm(ishape, eps, axis) looks simpler.

std::string tag = kBroadcast) {
CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
CHECK_EQ(size % 2, 1) << "size should be odd number";
CHECK_EQ((axis - 1) && (axis - 3), 0) << "axis should be 1 or 3 for NCHW and NHWC";
Contributor:

CHECK(axis == 1 || axis == 3)

l2norm_out : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, axis1, axis2, axis3 = a_np.shape
Contributor:

batch = a_np.shape[0] looks simpler? We don't need axis[1-3].

batch, axis1, axis2, axis3 = a_np.shape
sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
Contributor:

Can be removed.

sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
return np.divide(a_np, sqrt_sum)

def verify_l2norm(n, c, h, w, eps, axis=None):
Contributor:

I'd suggest def verify_l2norm(shape, eps, axis=None).
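As an aside on the quoted reference above: the np.maximum(..., eps) clamp guards against division by zero. A quick illustration (values are arbitrary):

import numpy as np

x = np.zeros((1, 3, 2, 2), dtype="float32")
eps = 1e-10
sqr_sum = np.sum(x * x)                    # 0.0 for an all-zero input
y = x / np.sqrt(np.maximum(sqr_sum, eps))  # stays finite instead of 0/0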

@@ -0,0 +1,101 @@
"""Test code for LRN"""
import os
Contributor:

Can be removed.

radius = size // 2
sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
Contributor:

Remove.


@tqchen (Member) commented Jun 11, 2018

OK, please fix the comments, remove l2norm or add l2_normalize to this PR, and let us prioritize bringing this in. This code review has been hanging for quite a while.

@tqchen added the "status: need update" label Jun 12, 2018
@kevinthesun mentioned this pull request Jun 14, 2018
@tqchen (Member) commented Jun 15, 2018

@PariksheetPinjari909 can you act on the comments? A rebase is needed.

@PariksheetPinjari909 (Contributor Author)

@tqchen Thanks, I was about to commit when I saw that a rebase is needed. I have adopted the l2normalize naming to avoid future confusion. Please review.

@PariksheetPinjari909 (Contributor Author)

@kazum Please review.

@kevinthesun mentioned this pull request Jun 16, 2018
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", L2normalizeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);
Contributor:

We need the FCorrectLayout attribute for the correct layout pass.


reg.register_pattern("lrn", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("l2normalize")
Member:

use l2_normalize (with underscore), to be consistent with the tensorflow API

with tvm.target.create(target):
return topi.generic.schedule_lrn(outs)

reg.register_pattern("lrn", OpPattern.OUT_ELEMWISE_FUSABLE)
Member:

Can we confirm that lrn is OUT_ELEMWISE_FUSABLE? We need to add a test case, lrn followed by relu, and confirm that the test passes on GPU.
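A minimal numpy reference for such a fusion test might look like this (a sketch; it assumes the lrn_python helper ends up in topi.testing, as suggested further down, and the real test would compare this against the compiled lrn+relu graph on a GPU target):

import numpy as np
import topi.testing  # assumption: lrn_python has been moved here

def lrn_relu_ref(a_np, size, axis, bias, alpha, beta):
    # reference output for LRN followed by ReLU; the fused GPU output
    # should match this within tolerance
    return np.maximum(topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta), 0.0)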

*
* \return A Tensor whose op member is the l2 normalization operation
*/
inline Tensor l2normalize_instance(const Tensor& data,
Member:

l2_normalize

Member:

what does "instance" mean here?

Contributor Author:

The instance name was chosen to match mxnet's l2_normalize function, but now we support l2_normalize over all axes, so there is no need to keep the instance name. I will remove it. Thanks for pointing it out.

with tvm.target.create(target):
return topi.generic.schedule_l2normalize(outs)

reg.register_pattern("l2normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
Member:

If we want to mark it as OUT_ELEMWISE_FUSABLE, confirm with a test case of the op followed by an elementwise operator, so that a test case is generated for the ops used.

l2normalize_out : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch = a_np.shape[0]
Contributor:

This line can be removed.

sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
batch = a_np.shape[0]
Contributor:

Same here.

l2normalize_out : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch = a_np.shape[0]
Contributor:

Same here.


NNVM_REGISTER_OP(lrn)
.describe(R"code(LRN layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tesndor", "Input data.")
Contributor:

"4D Tensor"


NNVM_REGISTER_OP(l2_normalize)
.describe(R"code(L2NORMALIZE layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tesndor", "Input data.")
Contributor:

"4D Tensor"

axis0, axis1, axis2, axis3 = a_np.shape
radius = size // 2
sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
def sum_dot_values(i, j, k, l):
Contributor:

from itertools import product

for i, j, k, l in product(*[range(_axis) for _axis in a_np.shape]):

and we can remove the nested loop below. I think this cleanup is a matter of taste, though.
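Applied to the snippet above, the pattern looks roughly like this (a self-contained sketch; the loop body is a placeholder standing in for the sum_dot_values() logic in the quoted diff):

import numpy as np
from itertools import product

a_np = np.random.uniform(size=(1, 3, 4, 4)).astype("float32")
out = np.zeros_like(a_np)

# one flattened loop instead of four nested ones
for i, j, k, l in product(*[range(ax) for ax in a_np.shape]):
    out[i, j, k, l] = a_np[i, j, k, l] ** 2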

Contributor Author:

Yes, it looks nicer now. Thanks

dtype = "float32"
x_np = np.random.uniform(size=ishape).astype(dtype)

def l2_normalize_python(a_np, eps, axis=None):
Member:

move this function to topi.testing.

dtype = "float32"
x_np = np.random.uniform(size=ishape).astype(dtype)

def lrn_python(a_np, size, axis, bias, alpha, beta):
Member:

move this function to topi.testing
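After the move, the test files would import the references instead of defining them inline; roughly like this (assuming the function names and signatures from the quoted diffs are kept):

import numpy as np
import topi.testing

a_np = np.random.uniform(size=(1, 3, 5, 5)).astype("float32")
b_np = topi.testing.lrn_python(a_np, size=5, axis=1, bias=2.0, alpha=1e-4, beta=0.75)
c_np = topi.testing.l2_normalize_python(a_np, eps=1e-10, axis=(1,))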

@tqchen (Member) commented Jun 21, 2018

Some final comments, and it should be approved from my side; awaiting @kazum's comment.

@kazum (Contributor) left a comment:

Looks good from my side, thanks!

@tqchen (Member) commented Jun 21, 2018

@PariksheetPinjari909 please act on my final comments

@PariksheetPinjari909 (Contributor Author)

@tqchen all reviews are handled now. Please check.
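For reference, with the final naming, the new frontend ops can be constructed roughly like this (a sketch; parameter names follow the Param structs and test helpers quoted in this PR):

import nnvm.symbol as sym

x = sym.Variable("x")  # 4-D input, e.g. NCHW
y = sym.lrn(x, size=5, axis=1, alpha=1e-4, beta=0.75, bias=2.0)
z = sym.l2_normalize(x, eps=1e-10, axis=(1,))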

@tqchen tqchen merged commit e0e0a23 into apache:master Jun 22, 2018
@tqchen (Member) commented Jun 22, 2018

Thanks! This is merged!

tqchen pushed a commit to tqchen/tvm that referenced this pull request Jul 6, 2018
mnuyens pushed a commit to mnuyens/tvm that referenced this pull request Jul 10, 2018
grwlf pushed a commit to grwlf/tvm that referenced this pull request Aug 8, 2018