[GSoC] Implementing Hierarchical Memory Unit #1048

Closed
wants to merge 69 commits
Changes from 1 commit
Commits
69 commits
b02fcab
Added mlpack::methods::ann::augmented namespace + some validation/tes…
sidorov-ks May 21, 2017
5e590f3
Separated scoring from generating data
sidorov-ks Jun 5, 2017
659678f
Style issues fix
sidorov-ks Jun 6, 2017
5e132a2
Fixing style issues from the GitHub checker
sidorov-ks Jun 6, 2017
0665cf7
Added newline at the end of score.hpp
sidorov-ks Jun 6, 2017
c983446
Added SortTask + unit test for it
sidorov-ks Jun 6, 2017
26c66b0
Added AddTask class + unit test for it
sidorov-ks Jun 7, 2017
11df502
Fixing Jenkins issue with if/else braces
sidorov-ks Jun 7, 2017
46e0b68
Fixing Jenkins issue with >80 characters long lines
sidorov-ks Jun 7, 2017
e92c200
Fixed Jenkins issue with closing brace
sidorov-ks Jun 7, 2017
7d9e4e0
Fixed various minor Jenkins warnings
sidorov-ks Jun 7, 2017
a5e225e
Removed \n before public: keyword
sidorov-ks Jun 7, 2017
9213e4d
Yet another fix of Jenkins issues
sidorov-ks Jun 7, 2017
27b1f66
Documentation of AddTask and SortTask
sidorov-ks Jun 8, 2017
f857621
Renaming of GenerateData to Generate
sidorov-ks Jun 8, 2017
f357a0e
Fixes from code review
sidorov-ks Jun 10, 2017
53e29ea
Fixed Cppcheck issue with closing brace
sidorov-ks Jun 10, 2017
873192c
Code refactoring + some rewrite towards arma's native methods
sidorov-ks Jun 12, 2017
29b760b
Fixed cppcheck blankline issue
sidorov-ks Jun 12, 2017
f302885
Added (completely minimal) LSTM baseline solution as a unit test
sidorov-ks Jun 13, 2017
1f0b4bb
More cppcheck issue fixes
sidorov-ks Jun 13, 2017
4371fab
Fixed LSTM baseline bug
sidorov-ks Jun 14, 2017
69d5bcc
Adding a (not so good) AddTask test for LSTM baseline solution
sidorov-ks Jun 14, 2017
df08bf0
Changed input representation for CopyTask baseline
sidorov-ks Jun 15, 2017
08fe104
Changing SortTask baseline for multiple evaluations
sidorov-ks Jun 16, 2017
2dbe00f
Adding baseline solution with unary representation for nRepeats
sidorov-ks Jun 18, 2017
41db9a3
Added CopyTask input representation with unary representation of repe…
sidorov-ks Jun 20, 2017
497b23f
Implemented repeat count representation as part of the CopyTask inter…
sidorov-ks Jun 21, 2017
c100a88
Major cleanup of the code, featuring:
sidorov-ks Jun 22, 2017
de238a9
Cppcheck style fixes + some assertions
sidorov-ks Jun 22, 2017
bab90c0
Refactored AddTask definition for LSTM training
sidorov-ks Jun 23, 2017
fba6576
Fixed AddTask generator for the case num_A = num_B = 0
sidorov-ks Jun 27, 2017
13cd1de
Fixing cppcheck issues
sidorov-ks Jun 27, 2017
718e741
Adding some kind of code for HAM interface and memory structure for it
sidorov-ks Jun 30, 2017
5fc7ed3
Updated AddTask definition
sidorov-ks Jul 3, 2017
c30ecbc
Fixed issues from @rcurtin's review
sidorov-ks Jul 3, 2017
382d73d
Put TreeMemory into working condition + tests for it
sidorov-ks Jul 4, 2017
28aa20a
Trying to understand the HAM paper
sidorov-ks Aug 4, 2017
a028623
More stubs
sidorov-ks Aug 5, 2017
8ccecc6
Refactoring TreeMemory for storing multidimensional vectors + some tests
sidorov-ks Aug 7, 2017
e696682
Transferring HAM tests to a separate file
sidorov-ks Aug 8, 2017
62ab95a
Merged from upstream/master
sidorov-ks Aug 8, 2017
477c57e
Trying to add blind HAM test
sidorov-ks Aug 8, 2017
714c973
Changed includes - at least it compiles
sidorov-ks Aug 8, 2017
655d7f2
Implemented identity as FFN
sidorov-ks Aug 8, 2017
ef1e79a
Trying to implement JOIN - not so good so far
sidorov-ks Aug 9, 2017
4ec08f7
Fixed issue with JOIN operation
sidorov-ks Aug 9, 2017
cb62fdc
Implemented all primitive FFN models for testing the forward pass of …
sidorov-ks Aug 9, 2017
fc30187
Trying to make HAM forward pass + native FFN support for TreeMemory -…
sidorov-ks Aug 9, 2017
41ba7a7
Minor fixes in HAMUnit + update for CMakeList
sidorov-ks Aug 9, 2017
10aea5b
More includes - doesn't work anyway
sidorov-ks Aug 9, 2017
872c221
Finally something compilable
sidorov-ks Aug 11, 2017
fd581ae
Trying to get HAMUnit compile errors - almost there
sidorov-ks Aug 11, 2017
766b0d9
Fixed nDim issue
sidorov-ks Aug 12, 2017
d939fe6
Finally got the Attention and Forward methods right
sidorov-ks Aug 12, 2017
eae8e9d
Added FFN controller to HAMUnit + finalizing blind HAM test
sidorov-ks Aug 13, 2017
abdea99
Fixed issues from @zoq's review
sidorov-ks Aug 14, 2017
aea67a9
Trying to add Parameters() - not so good so far
sidorov-ks Aug 17, 2017
989731a
Added Parameters() method + unit test for it
sidorov-ks Aug 17, 2017
326534b
Trying to refactor Parameters() - not so good so far
sidorov-ks Aug 19, 2017
9173a24
Resolving issue with strict parameter
sidorov-ks Aug 20, 2017
9fe7d1d
Successfully refactored HAM<>.Parameters()
sidorov-ks Aug 22, 2017
5c82e4b
Added docs to the TreeMemory and HAMUnit
sidorov-ks Aug 23, 2017
198e32d
Trying to fix the Travis issue
sidorov-ks Aug 25, 2017
142ad4f
Fixing cppcheck issues
sidorov-ks Aug 25, 2017
aea121e
Fixing newline issue
sidorov-ks Aug 25, 2017
27059a9
Trying to merge const_init
sidorov-ks Aug 25, 2017
837aae9
Replaced zero_init with const_init
sidorov-ks Aug 25, 2017
38f2af5
Merge branch 'master' into ham
zoq Aug 25, 2017
8 changes: 5 additions & 3 deletions src/mlpack/methods/ann/augmented/ham_unit.hpp
@@ -14,7 +14,7 @@
#define MLPACK_METHODS_ANN_AUGMENTED_HAM_UNIT_HPP

#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>

#include "tree_memory.hpp"

@@ -56,9 +56,9 @@ class HAMUnit
arma::mat&& g);

//! Return the initial point for the optimization.
const arma::mat& Parameters() const { RebuildParameters(); return parameters; }
const arma::mat& Parameters() const { return parameters; }
//! Modify the initial point for the optimization.
arma::mat& Parameters() { RebuildParameters(); return parameters; }
arma::mat& Parameters() { return parameters; }

void ResetParameters();
private:
@@ -78,6 +78,8 @@ class HAMUnit
E embed;
C controller;

bool reset;

// Currently processed sequence.
arma::mat sequence;
size_t t;
36 changes: 32 additions & 4 deletions src/mlpack/methods/ann/augmented/ham_unit_impl.hpp
@@ -29,7 +29,7 @@ HAMUnit<E, J, S, W, C>::HAMUnit(size_t memorySize,
: memorySize(memorySize), memoryDim(memoryDim),
search(search), embed(embed), controller(controller),
memory(TreeMemory<double, J, S>(memorySize, memoryDim, join, write)),
t(0)
t(0), reset(false)
{
// Nothing to do here
}
@@ -92,11 +92,11 @@ void HAMUnit<E, J, S, W, C>::Forward(arma::mat&& input, arma::mat&& output) {

template<typename E, typename J, typename S, typename W, typename C>
void HAMUnit<E, J, S, W, C>::RebuildParameters() {
arma::mat embedParams = embed.Parameters();
arma::mat embedParams = embed.Parameters();
arma::mat searchParams = search.Parameters();
arma::mat controllerParams = controller.Parameters();
arma::mat joinParams = memory.JoinObject().Parameters();
arma::mat writeParams = memory.WriteObject().Parameters();
arma::mat writeParams = memory.WriteObject().Parameters();
size_t embedCount = embedParams.n_elem,
searchCount = searchParams.n_elem,
controllerCount = controllerParams.n_elem,
@@ -115,6 +115,7 @@
parameters.rows(
embedCount + searchCount + controllerCount + joinCount,
parameters.n_elem - 1) = writeParams;
reset = true;
}

template<typename E, typename J, typename S, typename W, typename C>
@@ -123,7 +124,34 @@ void HAMUnit<E, J, S, W, C>::ResetParameters() {
embed.ResetParameters();
controller.ResetParameters();
memory.ResetParameters();
RebuildParameters();

arma::mat embedParams = embed.Parameters();
arma::mat searchParams = search.Parameters();
arma::mat controllerParams = controller.Parameters();
arma::mat joinParams = memory.JoinObject().Parameters();
arma::mat writeParams = memory.WriteObject().Parameters();

size_t embedCount = embedParams.n_elem,
searchCount = searchParams.n_elem,
controllerCount = controllerParams.n_elem,
joinCount = joinParams.n_elem,
writeCount = writeParams.n_elem;

parameters = arma::mat(embedCount + searchCount + controllerCount + joinCount + writeCount, 1);

embed.Parameters() = arma::mat(parameters.memptr(), embedCount, 1, false, true);
Implemented the Parameters() method as rcurtin described (making the individual functions' parameters memory pointers into one contiguous memory block), but there are still only zeros in the HAMUnit parameters. Can you take a look at the issue?

It should work if you use embed.Parameters() = arma::mat(parameters.memptr(), embedCount, 1, false, false); instead of embed.Parameters() = arma::mat(parameters.memptr(), embedCount, 1, false, true);. Let us know if this works for you.
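For reference, a minimal standalone sketch (not part of this PR; the variable names are made up) of the Armadillo advanced constructor being discussed, arma::mat(aux_mem, n_rows, n_cols, copy_aux_mem, strict): with copy_aux_mem = false the matrix aliases the external buffer instead of copying it, and strict = false (the suggested setting) lets the matrix use that buffer only until a size change, whereas strict = true binds it to the buffer for its lifetime.

```cpp
// Standalone sketch of arma::mat(aux_mem, n_rows, n_cols, copy_aux_mem, strict).
#include <armadillo>
#include <iostream>

int main()
{
  // One contiguous parameter block, analogous to HAMUnit's `parameters`.
  arma::mat parameters(6, 1, arma::fill::zeros);

  // copy_aux_mem = false: `head` aliases the first three entries of
  // `parameters` instead of copying them. strict = false (the setting
  // suggested above) means `head` uses this memory only until a size change;
  // strict = true would bind it to the buffer for its whole lifetime.
  arma::mat head(parameters.memptr(), 3, 1, false, false);

  // Writing through the alias is visible in the parent parameter block.
  head.fill(1.0);
  std::cout << parameters.t();  // prints 1 1 1 0 0 0

  return 0;
}
```

If this sketch matches the intended behaviour, the sub-networks should write directly into the shared parameters block rather than into detached copies.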

search.Parameters() = arma::mat(parameters.memptr() + embedCount, searchCount, 1, false, true);
controller.Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount, controllerCount, 1, false, true);
memory.JoinObject().Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount + controllerCount, joinCount, 1, false, true);
memory.WriteObject().Parameters() = arma::mat(parameters.memptr() + embedCount + searchCount + controllerCount + joinCount, writeCount, 1, false, true);

std::cerr << "E:\n" << embed.Parameters().t();
std::cerr << "S:\n" << search.Parameters().t();
std::cerr << "C:\n" << controller.Parameters().t();
std::cerr << "J:\n" << memory.JoinObject().Parameters().t();
std::cerr << "W:\n" << memory.WriteObject().Parameters().t();

reset = true;
}

} // namespace augmented
50 changes: 32 additions & 18 deletions src/mlpack/tests/ham_test.cpp
@@ -259,50 +259,64 @@ BOOST_AUTO_TEST_CASE(BlindHAMUnitTest) {
// Embed model is just an identity function.
FFN<MeanSquaredError<> > embedModel;
embedModel.Add<Linear<> >(nDim, nDim);
embedModel.ResetParameters();
// Identity = apply identity linear transformation + add zero bias.
embedModel.Parameters().rows(0, nDim * nDim - 1) =
arma::mat embedParams = arma::zeros(nDim * nDim + nDim, 1);
embedParams.rows(0, nDim * nDim - 1) =
arma::vectorise(arma::eye(nDim, nDim));
embedModel.Parameters().rows(nDim * nDim, nDim * nDim + nDim - 1) =
arma::zeros(nDim);
// Join function is sum of its two vector inputs.
FFN<MeanSquaredError<> > joinModel;
joinModel.Add<Linear<> >(2 * nDim, nDim);
joinModel.ResetParameters();
joinModel.Parameters().rows(0, nDim * nDim - 1) = arma::vectorise(arma::eye(nDim, nDim));
joinModel.Parameters().rows(nDim * nDim, 2 * nDim * nDim - 1) = arma::vectorise(arma::eye(nDim, nDim));
joinModel.Parameters().rows(2 * nDim * nDim, 2 * nDim * nDim + nDim - 1) = arma::zeros(nDim);
arma::mat joinParams = arma::zeros(2 * nDim * nDim + nDim, 1);
joinParams.rows(0, nDim * nDim - 1) = arma::vectorise(arma::eye(nDim, nDim));
joinParams.rows(nDim * nDim, 2 * nDim * nDim - 1) = arma::vectorise(arma::eye(nDim, nDim));
// Write function is replacing its old input with its new input.
FFN<MeanSquaredError<> > writeModel;
writeModel.Add<Linear<> >(2 * nDim, nDim);
writeModel.ResetParameters();
writeModel.Parameters().rows(0, nDim * nDim - 1) = arma::zeros(nDim * nDim);
writeModel.Parameters().rows(nDim * nDim, 2 * nDim * nDim - 1) =
arma::mat writeParams = arma::zeros(2 * nDim * nDim + nDim, 1);
writeParams.rows(nDim * nDim, 2 * nDim * nDim - 1) =
arma::vectorise(arma::eye(nDim, nDim));
writeModel.Parameters().rows(2 * nDim * nDim, 2 * nDim * nDim + nDim - 1) = arma::zeros(nDim);
// Search model is a constant model that ignores its input and returns 1 / 3.
FFN<MeanSquaredError<> > searchModel;
searchModel.Add<Linear<> >(2 * nDim, 1);
searchModel.ResetParameters();
searchModel.Parameters().rows(0, 2 * nDim - 1) = arma::zeros(2 * nDim);
searchModel.Parameters().at(2 * nDim) = -log(2);
searchModel.Add<SigmoidLayer<> >();
arma::mat searchParams = arma::zeros(2 * nDim + 1, 1);
searchParams.at(2 * nDim) = -log(2);
// Controller is a feedforward model: sigmoid(5x1 + x2 - x3 - 2x4).
FFN<CrossEntropyError<> > controller;
controller.Add<Linear<> >(nDim, 1);
controller.ResetParameters();
controller.Parameters().rows(0, nDim - 1) = arma::vec("5 1 -1 -2");
controller.Parameters().at(nDim) = 0;
controller.Add<SigmoidLayer<> >();
arma::mat controllerParams = arma::zeros(nDim + 1, 1);
controllerParams.rows(0, nDim - 1) = arma::vec("5 1 -1 -2");

// Pack all the parameters into a single vector.
arma::mat allParams(
embedParams.n_elem + searchParams.n_elem + controllerParams.n_elem +
joinParams.n_elem + writeParams.n_elem, 1);
size_t ptr = 0;
std::vector<arma::mat*> ordering{
&embedParams,
&searchParams,
&controllerParams,
&joinParams,
&writeParams
};
for (arma::mat* el : ordering)
{
allParams.rows(ptr, ptr + el->n_elem - 1) = *el;
ptr += el->n_elem;
}

// Now run the HAM unit (the initial sequence is:
// [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1])
HAMUnit<> hamUnit(seqLen, nDim, embedModel, joinModel, searchModel, writeModel, controller);
hamUnit.ResetParameters();
hamUnit.Parameters() = allParams;

arma::mat input("1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1;");
input = input.t();
arma::mat output;
hamUnit.Forward(std::move(input), std::move(output));
std::cerr << output;
arma::mat targetOutput("0.4174; 0.4743; 0.5167; 0.5485;");
BOOST_REQUIRE_SMALL(arma::abs(output - targetOutput).max(), 1e-4);
