Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of Normal Distribution to ANN module #2382

Merged
merged 13 commits into from May 14, 2020

Conversation

nishantkr18
Copy link
Member

@nishantkr18 nishantkr18 commented Apr 22, 2020

Hey!
The implementation of Normal Distribution in #1912 had a small bug in the LogProbablity() function. Also, the backward pass and tests for the class was not present.
This is an attempt to complete the work, as normal distribution is of use in many places, especially for continuous action space environments in RL module.
For the tests, I have taken 4 random values of mean, sigma and x to produce corresponding LogProbabilites and gradients, and compared it with pytorch's implementation.
I use this snippet to get the prob, dmu and dsigma.

mu = torch.tensor(1.1, requires_grad = True)
std = torch.tensor(0.1, requires_grad = True)
n = Normal(mu, std)
action_tensor = torch.tensor(1.05, requires_grad = False)
log_prob = n.log_prob(action_tensor)
ratio = torch.exp(log_prob)

print(log_prob.data.numpy())
ratio.backward()
print('gradients of mu and sigma: ', mu.grad.data.numpy(), std.grad.data.numpy())

Please have a look and share your thoughts :)

src/mlpack/tests/ann_dist_test.cpp Outdated Show resolved Hide resolved
src/mlpack/tests/ann_dist_test.cpp Outdated Show resolved Hide resolved
src/mlpack/methods/ann/dists/normal_distribution.hpp Outdated Show resolved Hide resolved
src/mlpack/methods/ann/dists/normal_distribution.hpp Outdated Show resolved Hide resolved
src/mlpack/methods/ann/dists/normal_distribution.hpp Outdated Show resolved Hide resolved
src/mlpack/methods/ann/dists/normal_distribution.hpp Outdated Show resolved Hide resolved
@nishantkr18 nishantkr18 requested a review from zoq April 25, 2020 19:00
@@ -127,4 +128,40 @@ BOOST_AUTO_TEST_CASE(JacobianBernoulliDistributionLogisticTest)
}
}

/**
* Normal Distribution module test.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we run the Jacobian tests for the normal distribution as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes.. definitely! I've made the changes.. pls have a look

@nishantkr18
Copy link
Member Author

closed and reopened PR to re-trigger azure pipelines

Copy link
Member

@kartikdutt18 kartikdutt18 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @nishantkr18, Thanks for working on this. I have left some minor style comments.
Other than that, I have more of a question which is related to this comment here in #1730.

Do you think that maybe we should update this PR and Bernoulli distribution in another PR and remove that section. Maybe @zoq could provide some input on this one. The only problem that I think that would be there is that a user would have to change the include statement. Maybe I missed something.

Kindly Let me know what you think.
Thanks for the contribution!

* Normal distribution is a function which accepts a mean and a standard deviation
* term and creates a probablity distribution out of it.
*/
template <typename DataType = arma::mat>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @nishantkr18, Could also please add a template parameter description. Thanks.

Comment on lines 107 to 109
/**
* Return the mean.
*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @nishantkr18, For single line comments, do you mind using // to comment. Thanks a lot!

const DataType& Mean() const { return mean; }

/**
* Return a modifiable copy of the mean.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Return a modifiable copy of the mean.
//! Modify the mean.

Comment on lines 117 to 119
/**
* Return the standard deviation.
*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/**
* Return the standard deviation.
*/
//! Get the standard deviation.

Comment on lines 135 to 139
{
// We just need to serialize each of the members.
ar & BOOST_SERIALIZATION_NVP(mean);
ar & BOOST_SERIALIZATION_NVP(sigma);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry, it's a picky comment, Do you mind shifting this to the implementation file. Most of the code base follows that. Thanks a lot!

@nishantkr18
Copy link
Member Author

I have left some minor style comments.

Thanks for those, I'll get them sorted. 👍

Do you think that maybe we should update this PR and Bernoulli distribution in another PR and remove that section.

Yeah, that sounds like a plan. Again I'm waiting for @zoq or @rcurtin 's reviews.
I think GaussianDistribution is used in a lot of places, and has API differences to NormalDistribution in ann/dist, similar is the case with Bernoulli dist. It'd be grt if we work on this in a separate PR since getting it merged might take some time. And I'd love to have a working normal-distribution just like this for now, so that we can get the #1912 PR moving. what do u say?

@kartikdutt18
Copy link
Member

And I'd love to have a working normal-distribution just like this for now, so that we can get the #1912 PR moving. what do u say?

Sure, I am fine with that. Maybe after this gets merged, we can open an issue to discuss the same, provided it needs discussing.

@nishantkr18
Copy link
Member Author

nishantkr18 commented May 8, 2020

Maybe after this gets merged, we can open an issue to discuss the same, provided it needs discussing.

Yeah sure 👍 , sorry for the two commits, forgot to add description in the first one

@nishantkr18 nishantkr18 requested a review from zoq May 10, 2020 19:27
Copy link
Member

@zoq zoq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No further comments from my side, thanks for putting this together.

Copy link

@mlpack-bot mlpack-bot bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second approval provided automatically after 24 hours. 👍

Copy link
Member

@favre49 favre49 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't have any comments, could you add to HISTORY.MD as well before we merge?

@nishantkr18
Copy link
Member Author

Don't have any comments, could you add to HISTORY.MD as well before we merge?

Done 👍

@zoq zoq merged commit 84c08fe into mlpack:master May 14, 2020
@zoq
Copy link
Member

zoq commented May 14, 2020

Thanks again for the contribution!

@nishantkr18 nishantkr18 deleted the normal-dist branch December 26, 2020 06:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants