Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cereal for TMRegion #421

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/nupic/engine/Link.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,24 +478,22 @@ class Link : public Serializable
void save_ar(Archive& ar) const {
ar(cereal::make_nvp("srcRegionName", srcRegionName_),
cereal::make_nvp("srcOutputName", srcOutputName_),
cereal::make_nvp("destRegionName",destRegionName_),
cereal::make_nvp("destRegionName", destRegionName_),
cereal::make_nvp("destInputName", destInputName_),
cereal::make_nvp("destOffset", destOffset_),
cereal::make_nvp("is_FanIn", is_FanIn_),
cereal::make_nvp("propagationDelayBuffer", propagationDelayBuffer_)
);
cereal::make_nvp("destOffset", destOffset_),
cereal::make_nvp("is_FanIn", is_FanIn_));
ar(cereal::make_nvp("propagationDelayBuffer", propagationDelayBuffer_));
}
// FOR Cereal Deserialization
template<class Archive>
void load_ar(Archive& ar) {
ar(cereal::make_nvp("srcRegionName", srcRegionName_),
cereal::make_nvp("srcOutputName", srcOutputName_),
cereal::make_nvp("destRegionName",destRegionName_),
cereal::make_nvp("destRegionName", destRegionName_),
cereal::make_nvp("destInputName", destInputName_),
cereal::make_nvp("destOffset", destOffset_),
cereal::make_nvp("is_FanIn", is_FanIn_),
cereal::make_nvp("propagationDelayBuffer", propagationDelayBuffer_)
);
cereal::make_nvp("destOffset", destOffset_),
cereal::make_nvp("is_FanIn", is_FanIn_));
ar(cereal::make_nvp("propagationDelayBuffer", propagationDelayBuffer_));
propagationDelay_ = propagationDelayBuffer_.size();
initialized_ = false;
}
Expand Down
2 changes: 1 addition & 1 deletion src/nupic/engine/Region.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ bool Region::isParameter(const std::string &name) const {
return (spec_->parameters.contains(name));
}

// Some functions used to prevent symbles from being in Region.hpp
// Some functions used to prevent symbols from being in Region.hpp
void Region::getDims_(std::map<std::string,Dimensions>& outDims,
std::map<std::string,Dimensions>& inDims) const {
for(auto out: outputs_) {
Expand Down
64 changes: 64 additions & 0 deletions src/nupic/regions/TMRegion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,70 @@ class TMRegion : public RegionImpl, Serializable {
void serialize(BundleIO &bundle) override;
void deserialize(BundleIO &bundle) override;

CerealAdapter; // see Serializable.hpp
// FOR Cereal Serialization
template<class Archive>
void save_ar(Archive& ar) const {
bool init = ((tm_) ? true : false);
ar(cereal::make_nvp("numberOfCols", args_.numberOfCols));
ar(cereal::make_nvp("cellsPerColumn", args_.cellsPerColumn));
ar(cereal::make_nvp("activationThreshold", args_.activationThreshold));
ar(cereal::make_nvp("initialPermanence", args_.initialPermanence));
ar(cereal::make_nvp("connectedPermanence", args_.connectedPermanence));
ar(cereal::make_nvp("maxNewSynapseCount", args_.maxNewSynapseCount));
ar(cereal::make_nvp("permanenceIncrement", args_.permanenceIncrement));
ar(cereal::make_nvp("permanenceDecrement", args_.permanenceDecrement));
ar(cereal::make_nvp("predictedSegmentDecrement", args_.predictedSegmentDecrement));
ar(cereal::make_nvp("seed", args_.seed));
ar(cereal::make_nvp("maxSegmentsPerCell", args_.maxSegmentsPerCell));
ar(cereal::make_nvp("maxSynapsesPerSegment", args_.maxSynapsesPerSegment));
ar(cereal::make_nvp("extra", args_.extra));
ar(cereal::make_nvp("checkInputs", args_.checkInputs));
ar(cereal::make_nvp("learningMode", args_.learningMode));
ar(cereal::make_nvp("sequencePos", args_.sequencePos));
ar(cereal::make_nvp("iter", args_.iter));
ar(cereal::make_nvp("orColumnOutputs", args_.orColumnOutputs));
ar(cereal::make_nvp("dim", dim_)); // from RegionImpl
ar(cereal::make_nvp("init", init));
if (init) {
// Save the algorithm state
ar(cereal::make_nvp("TM", tm_));
}
}

// FOR Cereal Deserialization
template<class Archive>
void load_ar(Archive& ar) {
ar(cereal::make_nvp("numberOfCols", args_.numberOfCols));
ar(cereal::make_nvp("cellsPerColumn", args_.cellsPerColumn));
ar(cereal::make_nvp("activationThreshold", args_.activationThreshold));
ar(cereal::make_nvp("initialPermanence", args_.initialPermanence));
ar(cereal::make_nvp("connectedPermanence", args_.connectedPermanence));
ar(cereal::make_nvp("maxNewSynapseCount", args_.maxNewSynapseCount));
ar(cereal::make_nvp("permanenceIncrement", args_.permanenceIncrement));
ar(cereal::make_nvp("permanenceDecrement", args_.permanenceDecrement));
ar(cereal::make_nvp("predictedSegmentDecrement", args_.predictedSegmentDecrement));
ar(cereal::make_nvp("seed", args_.seed));
ar(cereal::make_nvp("maxSegmentsPerCell", args_.maxSegmentsPerCell));
ar(cereal::make_nvp("maxSynapsesPerSegment", args_.maxSynapsesPerSegment));
ar(cereal::make_nvp("extra", args_.extra));
ar(cereal::make_nvp("checkInputs", args_.checkInputs));
ar(cereal::make_nvp("learningMode", args_.learningMode));
ar(cereal::make_nvp("sequencePos", args_.sequencePos));
ar(cereal::make_nvp("iter", args_.iter));
ar(cereal::make_nvp("orColumnOutputs", args_.orColumnOutputs));
args_.outputWidth = (args_.orColumnOutputs)?args_.numberOfCols
: (args_.numberOfCols * args_.cellsPerColumn);
ar(cereal::make_nvp("dim", dim_)); // from RegionImpl
ar(cereal::make_nvp("init", args_.init));
if (args_.init) {
// Restore algorithm state
nupic::algorithms::temporal_memory::TemporalMemory* tm = new nupic::algorithms::temporal_memory::TemporalMemory();
tm_.reset(tm);
ar(cereal::make_nvp("TM", tm_));
}
}

// Per-node size (in elements) of the given output.
// For per-region outputs, it is the total element count.
// This method is called only for outputs whose size is not
Expand Down
9 changes: 0 additions & 9 deletions src/test/unit/regions/SPRegionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ namespace testing
<< "actual type is " << BasicType::getName(r1OutputArray.getType());

Real32 *buffer1 = (Real32*) r1OutputArray.getBuffer();
//for (size_t i = 0; i < r1OutputArray.getCount(); i++)
//{
//VERBOSE << " [" << i << "]= " << buffer1[i] << "" << std::endl;
//}

VERBOSE << " SPRegion input" << std::endl;
Array r2InputArray = region2->getInputData("bottomUpIn");
Expand Down Expand Up @@ -324,11 +320,6 @@ TEST(SPRegionTest, testSerialization)
std::map<std::string, std::string> parameterMap;
EXPECT_TRUE(captureParameters(n1region2, parameterMap)) << "Capturing parameters before save.";

// TODO: JSON serialization does not work.
// returns 3 (not really a crash)
// It fails returning from SpatialPooler, in rapidjson::PrettyWriter.h line 128
// It is apparently checking that it is not in array mode.

Directory::removeTree("TestOutputDir", true);
VERBOSE << "Writing stream to " << Path::makeAbsolute("TestOutputDir/spRegionTest.stream") << "\n";
net1.saveToFile_ar("TestOutputDir/spRegionTest.stream", SerializableFormat::JSON);
Expand Down
179 changes: 77 additions & 102 deletions src/test/unit/regions/TMRegionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,110 +335,85 @@ TEST(TMRegionTest, testLinking) {
}

TEST(TMRegionTest, testSerialization) {
// use default parameters the first time
Network *net1 = new Network();
Network *net2 = nullptr;
Network *net3 = nullptr;

try {

VERBOSE << "Setup first network and save it" << std::endl;
std::shared_ptr<Region> n1region1 = net1->addRegion( "region1", "ScalarSensor",
"{n: 48,w: 10,minValue: 0.05,maxValue: 10}");
n1region1->setParameterReal64("sensedValue", 5.0);

std::shared_ptr<Region> n1region2 = net1->addRegion("region2", "TMRegion", "{numberOfCols: 48}");

net1->link("region1", "region2", "", "", "encoded", "bottomUpIn");
VERBOSE << "Initialize" << std::endl;
net1->initialize();

VERBOSE << "compute region1" << std::endl;
n1region1->prepareInputs();
n1region1->compute();

VERBOSE << "compute region2" << std::endl;
n1region2->prepareInputs();
n1region2->compute();

// take a snapshot of everything in TMRegion at this point
// save to a bundle.
std::map<std::string, std::string> parameterMap;
EXPECT_TRUE(captureParameters(n1region2, parameterMap))
<< "Capturing parameters before save.";

VERBOSE << "saveToFile" << std::endl;
Directory::removeTree("TestOutputDir", true);
net1->saveToFile("TestOutputDir/tmRegionTest.stream");

VERBOSE << "Restore from bundle into a second network and compare." << std::endl;
net2 = new Network();
net2->loadFromFile("TestOutputDir/tmRegionTest.stream");


VERBOSE << "checked restored network" << std::endl;
std::shared_ptr<Region> n2region2 = net2->getRegion("region2");
ASSERT_TRUE(n2region2->getType() == "TMRegion")
<< " Restored TMRegion region does not have the right type. Expected "
"TMRegion, found "
<< n2region2->getType();

EXPECT_FLOAT_EQ(n2region2->getParameterReal32("anomaly"), 1.0f);
EXPECT_TRUE(compareParameters(n2region2, parameterMap))
<< "Conflict when comparing TMRegion parameters after restore with "
"before save.";

VERBOSE << "continue with execution." << std::endl;
// can we continue with execution? See if we get any exceptions.
n1region1->setParameterReal64("sensedValue", 0.12);
n1region1->prepareInputs();
n1region1->compute();

n2region2->prepareInputs();
VERBOSE << "continue 4." << std::endl;
n2region2->compute();
VERBOSE << "continue 5." << std::endl;

// Change a parameters and see if it is retained after a restore.
n2region2->setParameterReal32("permanenceDecrement", 0.099f);
n2region2->compute();

parameterMap.clear();
EXPECT_TRUE(captureParameters(n2region2, parameterMap))
<< "Capturing parameters before second save.";
// serialize using a stream to a single file
VERBOSE << "save second network." << std::endl;
net2->saveToFile("TestOutputDir/tmRegionTest.stream");

VERBOSE << "Restore into a third network and compare changed parameters." << std::endl;
net3 = new Network();
net3->loadFromFile("TestOutputDir/tmRegionTest.stream");
std::shared_ptr<Region> n3region2 = net3->getRegion("region2");
EXPECT_TRUE(n3region2->getType() == "TMRegion")
<< "Failure: Restored region does not have the right type. "
" Expected \"TMRegion\", found \""
<< n3region2->getType() << "\".";

EXPECT_TRUE(compareParameters(n3region2, parameterMap))
<< "Comparing parameters after second restore with before save.";
} catch (nupic::Exception &ex) {
FAIL() << "Failure: Exception: " << ex.getFilename() << "("
<< ex.getLineNumber() << ") " << ex.getMessage() << "" << std::endl;
} catch (std::exception &e) {
FAIL() << "Failure: Exception: " << e.what() << "" << std::endl;
}
Network net1;
Network net2;
Network net3;

VERBOSE << "Setup first network and save it" << std::endl;
std::shared_ptr<Region> n1region1 = net1.addRegion( "region1", "ScalarSensor",
"{n: 48,w: 10,minValue: 0.05,maxValue: 10}");
n1region1->setParameterReal64("sensedValue", 5.0);

std::shared_ptr<Region> n1region2 = net1.addRegion("region2", "TMRegion", "{numberOfCols: 48}");

net1.link("region1", "region2", "", "", "encoded", "bottomUpIn");
VERBOSE << "Initialize" << std::endl;
net1.initialize();

VERBOSE << "compute region1" << std::endl;
n1region1->prepareInputs();
n1region1->compute();

VERBOSE << "compute region2" << std::endl;
n1region2->prepareInputs();
n1region2->compute();

// take a snapshot of everything in TMRegion at this point
// save to a bundle.
std::map<std::string, std::string> parameterMap;
EXPECT_TRUE(captureParameters(n1region2, parameterMap))
<< "Capturing parameters before save.";

VERBOSE << "saveToFile" << std::endl;
Directory::removeTree("TestOutputDir", true);
net1.saveToFile_ar("TestOutputDir/tmRegionTest.stream");

VERBOSE << "Restore from bundle into a second network and compare." << std::endl;
net2.loadFromFile_ar("TestOutputDir/tmRegionTest.stream");

VERBOSE << "checked restored network" << std::endl;
std::shared_ptr<Region> n2region2 = net2.getRegions().getByName("region2");
ASSERT_TRUE(n2region2->getType() == "TMRegion")
<< " Restored TMRegion region does not have the right type. Expected "
"TMRegion, found "
<< n2region2->getType();

EXPECT_TRUE(compareParameters(n2region2, parameterMap))
<< "Conflict when comparing TMRegion parameters after restore with "
"before save.";

VERBOSE << "continue with execution." << std::endl;
// can we continue with execution? See if we get any exceptions.
n1region1->setParameterReal64("sensedValue", 0.12);
n1region1->prepareInputs();
n1region1->compute();

n2region2->prepareInputs();
n2region2->compute();

// Change a parameters and see if it is retained after a restore.
n2region2->setParameterReal32("permanenceDecrement", 0.099f);
n2region2->compute();

parameterMap.clear();
EXPECT_TRUE(captureParameters(n2region2, parameterMap))
<< "Capturing parameters before second save.";
// serialize using a stream to a single file
VERBOSE << "save second network." << std::endl;
net2.saveToFile_ar("TestOutputDir/tmRegionTest.stream");

VERBOSE << "Restore into a third network and compare changed parameters." << std::endl;
net3.loadFromFile_ar("TestOutputDir/tmRegionTest.stream");
std::shared_ptr<Region> n3region2 = net3.getRegions().getByName("region2");
EXPECT_TRUE(n3region2->getType() == "TMRegion")
<< "Failure: Restored region does not have the right type. "
" Expected \"TMRegion\", found \""
<< n3region2->getType() << "\".";

EXPECT_TRUE(compareParameters(n3region2, parameterMap))
<< "Comparing parameters after second restore with before save.";

VERBOSE << "Cleanup" << std::endl;
// cleanup
if (net1 != nullptr) {
delete net1;
}
if (net2 != nullptr) {
delete net2;
}
if (net3 != nullptr) {
delete net3;
}
Directory::removeTree("TestOutputDir", true);
}

Expand Down