Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
NUP-2506: Add operator '==' to classes used in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Mar 28, 2018
1 parent 42dc0ac commit 860459c
Show file tree
Hide file tree
Showing 23 changed files with 535 additions and 1 deletion.
78 changes: 78 additions & 0 deletions bindings/py/tests/algorithms/cells4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,53 @@

_RGEN = Random(43)

def createCells4(nCols=8,
nCellsPerCol=4,
activationThreshold=1,
minThreshold=1,
newSynapseCount=2,
segUpdateValidDuration=2,
permInitial=0.5,
permConnected=0.8,
permMax=1.0,
permDec=0.1,
permInc=0.2,
globalDecay=0.05,
doPooling=True,
pamLength=2,
maxAge=3,
seed=42,
initFromCpp=True,
checkSynapseConsistency=False):

cells = Cells4(nCols,
nCellsPerCol,
activationThreshold,
minThreshold,
newSynapseCount,
segUpdateValidDuration,
permInitial,
permConnected,
permMax,
permDec,
permInc,
globalDecay,
doPooling,
seed,
initFromCpp,
checkSynapseConsistency)


cells.setPamLength(pamLength)
cells.setMaxAge(maxAge)
cells.setMaxInfBacktrack(4)

for i in xrange(nCols):
for j in xrange(nCellsPerCol):
cells.addNewSegment(i, j, True if j % 2 == 0 else False,
[((i + 1) % nCols, (j + 1) % nCellsPerCol)])

return cells

class Cells4Test(unittest.TestCase):

Expand Down Expand Up @@ -205,3 +251,35 @@ def testLearn(self):
cells.compute(x, True, False)

self._testPersistence(cells)

def testEquals(self):
nCols = 10
c1 = createCells4(nCols)
c2 = createCells4(nCols)
self.assertEquals(c1, c2)

# learn
data = [numpy.random.choice(nCols, nCols/3, False) for _ in xrange(10)]
for idx in data:
x = numpy.zeros(nCols, dtype="float32")
x[idx] = 1.0
c1.compute(x, True, True)
c2.compute(x, True, True)
self.assertEquals(c1, c2)

self.assertEquals(c1, c2)

c1.rebuildOutSynapses()
c2.rebuildOutSynapses()
self.assertEquals(c1, c2)

# inference
data = [numpy.random.choice(nCols, nCols/3, False) for _ in xrange(100)]
for idx in data:
x = numpy.zeros(nCols, dtype="float32")
x[idx] = 1.0
c1.compute(x, True, False)
c2.compute(x, True, False)
self.assertEquals(c1, c2)

self.assertEquals(c1, c2)
7 changes: 7 additions & 0 deletions bindings/py/tests/nupic_random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ def testShuffleBadDtype(self):
self.assertRaises(ValueError, r.shuffle, arr)


def testEquals(self):
r1 = Random(42)
v1 = r1.getReal64()
r2 = Random(42)
v2 = r2.getReal64()
self.assertEquals(v1, v2)
self.assertEquals(r1, r2)

if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions src/nupic/algorithms/Cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,9 @@ void Cell::load(std::istream &inStream) {
_freeSegments.push_back(i);
}
}
bool Cell::operator==(const Cell &other) const {
if (_freeSegments != other._freeSegments) {
return false;
}
return _segments == other._segments;
}
4 changes: 4 additions & 0 deletions src/nupic/algorithms/Cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class Cell : Serializable<CellProto> {
return _segments[segIdx];
}

//----------------------------------------------------------------------
bool operator==(const Cell &other) const;
inline bool operator!=(const Cell &other) const { return !operator==(other); }

//--------------------------------------------------------------------------------
Segment &getSegment(UInt segIdx) {
NTA_ASSERT(segIdx < _segments.size());
Expand Down
52 changes: 52 additions & 0 deletions src/nupic/algorithms/Cells4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2801,6 +2801,58 @@ std::ostream &operator<<(std::ostream &outStream, const Cells4 &cells) {
return outStream;
}

bool Cells4::operator==(const Cells4 &other) const {

if (_activationThreshold != other._activationThreshold ||
_avgInputDensity != other._avgInputDensity ||
_avgLearnedSeqLength != other._avgLearnedSeqLength ||
_checkSynapseConsistency != other._checkSynapseConsistency ||
_doPooling != other._doPooling || _globalDecay != other._globalDecay ||
_initSegFreq != other._initSegFreq ||
_learnedSeqLength != other._learnedSeqLength ||
_maxAge != other._maxAge || _maxInfBacktrack != other._maxInfBacktrack ||
_maxLrnBacktrack != other._maxLrnBacktrack ||
_maxSegmentsPerCell != other._maxSegmentsPerCell ||
_maxSeqLength != other._maxSeqLength ||
_maxSynapsesPerSegment != other._maxSynapsesPerSegment ||
_minThreshold != other._minThreshold || _nCells != other._nCells ||
_nCellsPerCol != other._nCellsPerCol || _nColumns != other._nColumns ||
_newSynapseCount != other._newSynapseCount ||
_nIterations != other._nIterations ||
_nLrnIterations != other._nLrnIterations ||
_ownsMemory != other._ownsMemory || _pamCounter != other._pamCounter ||
_pamLength != other._pamLength ||
_permConnected != other._permConnected || _permDec != other._permDec ||
_permInc != other._permInc || _permInitial != other._permInitial ||
_permMax != other._permMax || _resetCalled != other._resetCalled ||
_segUpdateValidDuration != other._segUpdateValidDuration ||
_verbosity != other._verbosity || _version != other._version) {
return false;
}
if (_rng != other._rng) {
return false;
}
if (_cells != other._cells) {
return false;
}
if (_segmentUpdates != other._segmentUpdates) {
return false;
}
if (_learnActiveStateT != other._learnActiveStateT) {
return false;
}
if (_learnActiveStateT1 != other._learnActiveStateT1) {
return false;
}
if (_learnPredictedStateT != other._learnPredictedStateT) {
return false;
}
if (_learnPredictedStateT1 != other._learnPredictedStateT1) {
return false;
}
return true;
}

//----------------------------------------------------------------------
/**
* Compute cell and segment activities using forward propagation
Expand Down
6 changes: 6 additions & 0 deletions src/nupic/algorithms/Cells4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ class Cells4 : public Serializable<Cells4Proto> {
//----------------------------------------------------------------------
~Cells4();

//----------------------------------------------------------------------
bool operator==(const Cells4 &other) const;
inline bool operator!=(const Cells4 &other) const {
return !operator==(other);
}

//----------------------------------------------------------------------
UInt version() const { return _version; }

Expand Down
7 changes: 7 additions & 0 deletions src/nupic/algorithms/InSynapse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class InSynapse {
return *this;
}

inline bool operator==(const InSynapse &other) const {
return _srcCellIdx == other._srcCellIdx && _permanence == other._permanence;
}
inline bool operator!=(const InSynapse &other) const {
return !operator==(other);
}

inline UInt srcCellIdx() const { return _srcCellIdx; }
const inline Real &permanence() const { return _permanence; }
inline Real &permanence() { return _permanence; }
Expand Down
14 changes: 14 additions & 0 deletions src/nupic/algorithms/Segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ Segment &Segment::operator=(const Segment &o) {
return *this;
}

//--------------------------------------------------------------------------------
bool Segment::operator==(const Segment &other) const {
if (_totalActivations != other._totalActivations ||
_positiveActivations != other._positiveActivations ||
_lastActiveIteration != other._lastActiveIteration ||
_lastPosDutyCycle != other._lastPosDutyCycle ||
_lastPosDutyCycleIteration != other._lastPosDutyCycleIteration ||
_seqSegFlag != other._seqSegFlag || _frequency != other._frequency ||
_nConnected != other._nConnected) {
return false;
}
return _synapses == other._synapses;
}

//--------------------------------------------------------------------------------
Segment::Segment(const Segment &o)
: _totalActivations(o._totalActivations),
Expand Down
29 changes: 29 additions & 0 deletions src/nupic/algorithms/Segment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ class CState : Serializable<CStateProto> {
memcpy(_pData, o._pData, _nCells);
return *this;
}
bool operator==(const CState &other) const {
if (_version != other._version || _nCells != other._nCells ||
_fMemoryAllocatedByPython != other._fMemoryAllocatedByPython) {
return false;
}
if (_pData != nullptr && other._pData != nullptr) {
return ::memcmp(_pData, other._pData, _nCells) == 0;
}
return _pData == other._pData;
}
inline bool operator!=(const CState &other) const {
return !operator==(other);
}
bool initialize(const UInt nCells) {
if (_nCells != 0) // if already initialized
return false; // don't do it again
Expand Down Expand Up @@ -217,6 +230,19 @@ class CStateIndexed : public CState {
_isSorted = o._isSorted;
return *this;
}
bool operator==(const CStateIndexed &other) const {
if (_version != other._version || _countOn != other._countOn ||
_isSorted != other._isSorted) {
return false;
}
if (_cellsOn != other._cellsOn) {
return false;
}
return CState::operator==(other);
}
inline bool operator!=(const CStateIndexed &other) const {
return !operator==(other);
}
std::vector<UInt> cellsOn(bool fSorted = false) {
// It's better for the caller to ask us to sort, rather than
// to sort himself, since we can optimize out the sort when we
Expand Down Expand Up @@ -342,6 +368,9 @@ class Segment : Serializable<SegmentProto> {
Real _lastPosDutyCycle;
UInt _lastPosDutyCycleIteration;

bool operator==(const Segment &o) const;
inline bool operator!=(const Segment &o) const { return !operator==(o); }

private:
bool _seqSegFlag; // sequence segment flag
Real _frequency; // frequency [UNUSED IN LATEST IMPLEMENTATION]
Expand Down
11 changes: 11 additions & 0 deletions src/nupic/algorithms/SegmentUpdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ bool SegmentUpdate::invariants(Cells4 *cells) const {

return ok;
}

bool SegmentUpdate::operator==(const SegmentUpdate &o) const {

if (_cellIdx != o._cellIdx || _segIdx != o._segIdx ||
_sequenceSegment != o._sequenceSegment || _timeStamp != o._timeStamp ||
_phase1Flag != o._phase1Flag ||
_weaklyPredicting != o._weaklyPredicting) {
return false;
}
return _synapses == o._synapses;
}
5 changes: 5 additions & 0 deletions src/nupic/algorithms/SegmentUpdate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ class SegmentUpdate : Serializable<SegmentUpdateProto> {
NTA_ASSERT(invariants());
return *this;
}
//---------------------------------------------------------------------
bool operator==(const SegmentUpdate &other) const;
inline bool operator!=(const SegmentUpdate &other) const {
return !operator==(other);
}

//---------------------------------------------------------------------
bool isSequenceSegment() const { return _sequenceSegment; }
Expand Down
13 changes: 12 additions & 1 deletion src/nupic/engine/Link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,18 @@ void Link::read(LinkProto::Reader &proto) {
}
}
}

bool Link::operator==(const Link &o) const {
if (initialized_ != o.initialized_ ||
propagationDelay_ != o.propagationDelay_ || linkType_ != o.linkType_ ||
linkParams_ != o.linkParams_ || destOffset_ != o.destOffset_ ||
srcRegionName_ != o.srcRegionName_ ||
destRegionName_ != o.destRegionName_ ||
srcOutputName_ != o.srcOutputName_ ||
destInputName_ != o.destInputName_) {
return false;
}
return true;
}
namespace nupic {
std::ostream &operator<<(std::ostream &f, const Link &link) {
f << "<Link>\n";
Expand Down
3 changes: 3 additions & 0 deletions src/nupic/engine/Link.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ class Link : public Serializable<LinkProto> {
using Serializable::read;
void read(LinkProto::Reader &proto);

bool operator==(const Link &other) const;
inline bool operator!=(const Link &other) const { return !operator==(other); }

private:
// common initialization for the two constructors.
void commonConstructorInit_(const std::string &linkType,
Expand Down
19 changes: 19 additions & 0 deletions src/nupic/engine/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,4 +1061,23 @@ void Network::unregisterCPPRegion(const std::string name) {
Region::unregisterCPPRegion(name);
}

bool Network::operator==(const Network &o) const {

if (initialized_ != o.initialized_ || iteration_ != o.iteration_ ||
minEnabledPhase_ != o.minEnabledPhase_ ||
maxEnabledPhase_ != o.maxEnabledPhase_ ||
regions_.getCount() != o.regions_.getCount()) {
return false;
}

for (size_t i = 0; i < regions_.getCount(); i++) {
Region *r1 = regions_.getByIndex(i).second;
Region *r2 = o.regions_.getByIndex(i).second;
if (*r1 != *r2) {
return false;
}
}
return true;
}

} // namespace nupic
5 changes: 5 additions & 0 deletions src/nupic/engine/Network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,11 @@ class Network : public Serializable<NetworkProto> {
*/
static void unregisterCPPRegion(const std::string name);

bool operator==(const Network &other) const;
inline bool operator!=(const Network &other) const {
return !operator==(other);
}

private:
// Both constructors use this common initialization method
void commonInit();
Expand Down
Loading

0 comments on commit 860459c

Please sign in to comment.