
[GSoC] Implementing Hierarchical Memory Unit #1048

Open: wants to merge 69 commits into base: master

Conversation (3 participants)
17minutes (Contributor) commented Jul 4, 2017

This PR is part of my GSoC project "Augmented RNNs".

It implements the HAM unit API and the tree memory structure for it, as described in the arXiv paper.

As far as I understand, the entire HAM-related discussion moves here, as opposed to the task-specific discussion.

17minutes (Contributor) commented Aug 4, 2017

I've written up a provisional code stub plus comments reflecting my understanding of the paper. Could you glance over it and point out any logical mistakes?

zoq (Member) reviewed Aug 4, 2017 (edited)

Looks good to me; I just made some minor comments. Also, I think we are going to use the FFN class for Join/Write/Search. Currently the FFN class isn't part of the LayerTypes variant; Sumedh already made some adjustments in that direction, so it might be a good idea to take a look at #1072.

17minutes (Contributor) commented Aug 6, 2017

As a next step, how can I create a hard-coded FFN model? More precisely, I want to hard-code some parts of HAM to see whether the Forward pass works correctly.

zoq (Member) commented Aug 6, 2017

What do you mean by a hard-coded FFN model? Something like

FFN<NegativeLogLikelihood<> > model;
model.Add<Linear<> >(dataset.n_rows, 50);
model.Add<SigmoidLayer<> >();
model.Add<Linear<> >(50, 10);
model.Add<LogSoftMax<> >();

is probably not what you want? If you set a static seed, you work with fixed parameter settings.

@rcurtin

Looks like there is still a lot of work to do, so I hope my comments aren't overwhelming; they are meant to be helpful. :) If I can clarify anything I've written, just let me know.

If you have a chance to remove the unrelated code from this PR, it would make it nicer to scroll around and review. I can help you with that if you like, just let me know.

17minutes (Contributor) commented Aug 7, 2017

@zoq: No, I mean feeding my own weights to all the MLPs in the HAM unit. For example, if SEARCH is sigmoid(W * x + b), then I want to create a SEARCH function with W = [0 0 ... 0] and b = -ln(2), to get a function that turns left 'blindly' with probability 1/3.

@rcurtin: Thanks, it's not overwhelming - I was starting to drown in the complexity anyway ^_^ Implementing.
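(A quick standalone check of the arithmetic behind that bias choice, in plain C++ with no mlpack dependency: with W = 0 the pre-activation is just b, and sigmoid(-ln 2) = 1 / (1 + e^{ln 2}) = 1/3 regardless of the input.)

#include <cmath>
#include <iostream>

int main()
{
  const double b = -std::log(2.0);              // b = -ln(2)
  const double p = 1.0 / (1.0 + std::exp(-b));  // sigmoid(b)
  std::cout << p << std::endl;                  // prints 0.333333... = 1/3
}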

17minutes (Contributor) commented Aug 7, 2017

> If you have a chance to remove the unrelated code from this PR, it would make it nicer to scroll around and review. I can help you with that if you like, just let me know.

I know; I'm just waiting for the benchmarking PR to be merged - I'll rebase as soon as it is, because there is a lot of useful code there :)

zoq (Member) commented Aug 7, 2017

> @zoq: No, I mean feeding my own weights to all the MLPs in the HAM unit. For example, if SEARCH is sigmoid(W * x + b), then I want to create a SEARCH function with W = [0 0 ... 0] and b = -ln(2), to get a function that turns left 'blindly' with probability 1/3.

What you could do is set the parameters manually, e.g.

model.Parameters().subvec(0, 10).zeros();
model.Parameters()(10) = 3;

Does this help? I could also implement some function that enables you to use the layer to do the same, e.g.

layer.Parameters().zeros();
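(As an example, a rough sketch of how this suggestion could be used to hard-code the 'blind' SEARCH unit discussed above. The network layout, the MeanSquaredError<> output layer, and the names searchModel / inSize are illustrative only, and the bias index assumes the Linear layer stores its weights first and its bias last.)

// Hypothetical SEARCH network: sigmoid(W * x + b) with a single output.
FFN<MeanSquaredError<> > searchModel;
searchModel.Add<Linear<> >(inSize, 1);
searchModel.Add<SigmoidLayer<> >();
searchModel.ResetParameters();

// Zero every weight, then set the bias (assumed to be the last parameter)
// to -ln(2); the unit then outputs sigmoid(-ln 2) = 1/3 for any input.
searchModel.Parameters().zeros();
searchModel.Parameters()(inSize) = -std::log(2.0);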
@zoq

Just some minor comments; also, the attention part looks good - I couldn't see any issues in the results or the logic.

zoq (Member) commented Aug 13, 2017

08:58 < partobs-mdp> By the way, do we have some type that would be inherited by both FFN and RNN? It would be nice since then we could test both MLP and LSTM controllers. (But if there is no such type, I think we'd better bite the bullet and do it immediately with RNN controller, just as described in the paper)
09:00 < partobs-mdp> (Oops, sorry, forgot about templates - we could do typename C = FFN<CrossEntropyError<>> and run whatever methods we need)
09:47 < partobs-mdp> zoq: rcurtin: Finished working on controller - the complete blind HAM is also there. If I didn't mess up with the numbers, then it should be okay :)

You could also use LayerTypes; Sumedh also uses the FFN class as a controller, so he made the FFN class part of LayerTypes: https://github.com/mlpack/mlpack/pull/1072/files#diff-1ad44b5de9eab9004eb0a4ccc45c88dbR112
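(For context, the template idea from the log could look roughly like the sketch below; the class shape and member names are illustrative and not the actual HAMUnit interface.)

// Sketch: make the controller type a template parameter, so either an
// FFN or an RNN (anything exposing the methods HAM needs) can be plugged in.
template<typename ControllerType = FFN<CrossEntropyError<> > >
class HAMUnit
{
 public:
  HAMUnit(ControllerType controller) : controller(controller) { }

 private:
  ControllerType controller;
};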

17minutes (Contributor) commented Aug 14, 2017

@zoq: About #1072 - what is the status of that PR? I hope it will be merged soon, since (as far as I can see) it already contains a tested integration of FFN with LayerTypes.

zoq (Member) commented Aug 14, 2017

#1072 isn't done yet; I would recommend copying over the changes from that PR that you think are helpful.

@rcurtin

Just added some comments, hope they are helpful. I am not sure whether the tests are passing, because Travis can't build the PR. If you need any help debugging anything, feel free to ask. :)

(Outdated review comments on src/mlpack/methods/ann/augmented/ham_unit.hpp, ham_unit_impl.hpp, and tree_memory_impl.hpp. The inline discussion below refers to this snippet from the tree memory implementation:)
size_t start = actualMemorySize - 1 + pos;
arma::Mat<T> newCellValue;
writeFunction.Predict(Stack(Cell(start), el), newCellValue);
Cell(start) = newCellValue;

rcurtin (Member) commented Aug 16, 2017

Can writeFunction.Predict() be modified to output directly into Cell(start) and thus obviate the need for newCellValue?

17minutes (Contributor) commented Aug 17, 2017

I tried it, but it doesn't really work, since Cell(node) doesn't actually return a reference to the corresponding column - it returns an arma::subview_col<T>, and there goes the C++ type system :)

rcurtin (Member) commented Aug 18, 2017

I'd suggest avoiding the functions Cell() and Leaf() entirely, and instead accessing the correct column of your memory directly. After all, Cell() and Leaf() are just calling .col() and then forcing the type of the result; I think that should be avoided.

Instead try the following:

writeFunction.Predict(Stack(memory.unsafe_col(start), el), memory.unsafe_col(start));

The unsafe_col() function returns a new column vector that is an alias of the same memory, so no copy happens.

(Further outdated review comments on src/mlpack/methods/ann/augmented/tree_memory_impl.hpp and src/mlpack/tests/ham_test.cpp.)
rcurtin (Member) commented Aug 18, 2017

It looks like the RebuildParameters() method is necessary because the various parts of the HAMUnit aren't stored in contiguous memory. I think it should be possible, when you create the HAMUnit object, to ensure that all the memory used is contiguous. Consider this code:

// Assume we know the size of everything...
extern size_t embedCount, searchCount, controllerCount, joinCount, writeCount;

// Now create the parameters matrix.
parameters = arma::mat(embedCount + searchCount + controllerCount + joinCount + writeCount, 1);

// Next, set up the parameters matrices of the other components to be aliases
// of the right part of the parameters matrix.
embed.Parameters() = arma::mat(parameters.memptr(), embedCount, 1, false /* no copy */, true /* no resize */);
search.Parameters() = arma::mat(parameters.memptr() + embedCount, searchCount, 1, false, true);
controller.Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount, controllerCount, 1, false, true);
memory.JoinObject().Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount + controllerCount, joinCount, 1, false, true);
memory.WriteObject().Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount + controllerCount + joinCount, writeCount, 1, false, true);

That uses one of the Armadillo advanced constructors that reuses other memory. So after that code, the memory of each component of HAMUnit is set up as an alias of HAMUnit::parameters, and no RebuildParameters() function would be necessary.

Also, as a side note, be careful with statements like this:

arma::mat embedParams = embed.Parameters();

That incurs a full copy of embed.Parameters(). To avoid the copy, either don't make the temporary variable embedParams, or use a const arma::mat& as its type.

Let me know if I can clarify anything; I think the idea I just proposed would help make the design simpler (and help make debugging the rest of it simpler).
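(As a self-contained illustration of the advanced-constructor aliasing described above, using plain Armadillo and made-up sizes:)

#include <armadillo>

int main()
{
  // One contiguous parameter vector of length 5 + 3.
  arma::mat parameters(8, 1, arma::fill::randu);

  // Aliases over disjoint parts of 'parameters':
  // copy_aux_mem = false (no copy), strict = true (size can't change).
  arma::mat embedParams(parameters.memptr(), 5, 1, false, true);
  arma::mat searchParams(parameters.memptr() + 5, 3, 1, false, true);

  // Writing through an alias is visible in the shared parameter vector.
  embedParams.zeros();
  parameters.print("parameters after zeroing the embed block:");
}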

zoq (Member) commented Aug 24, 2017

Thanks for the comments, they make it easier to understand the code. Can you take a look at the build issues? It would be great if the code could be built automatically; it looks like some files are missing.

17minutes and others added some commits Aug 25, 2017

rcurtin (Member) commented Aug 25, 2017

Looks like this compiles and passes the tests correctly, which is great. :) Have you had a chance to benchmark the task against the LSTM model?
