Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/main/java/org/numenta/nupic/Connections.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public class Connections implements Persistable {
private double localAreaDensity = -1.0;
private double numActiveColumnsPerInhArea;
private double stimulusThreshold = 0;
private double synPermInactiveDec = 0.01;
private double synPermActiveInc = 0.10;
private double synPermInactiveDec = 0.008;
private double synPermActiveInc = 0.05;
private double synPermConnected = 0.10;
private double synPermBelowStimulusInc = synPermConnected / 10.0;
private double minPctOverlapDutyCycles = 0.001;
Expand Down Expand Up @@ -242,6 +242,13 @@ public class Connections implements Persistable {
*/
public Connections() {}

/**
* Sets the derived values of the {@link SpatialPooler}'s initialization.
*/
public void doSpatialPoolerPostInit() {
synPermBelowStimulusInc = synPermConnected / 10.0;
synPermTrimThreshold = synPermActiveInc / 2.0;
}

/////////////////////////////////////////
// General Methods //
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/numenta/nupic/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ public void apply(Object cn) {
Set<KEY> presentKeys = paramMap.keySet();
synchronized (paramMap) {
for (KEY key : presentKeys) {
if((cn instanceof Connections) &&
(key == KEY.SYN_PERM_BELOW_STIMULUS_INC || key == KEY.SYN_PERM_TRIM_THRESHOLD)) {
continue;
}
if(key == KEY.RANDOM) {
((Random)get(key)).setSeed(Long.valueOf(((int)get(KEY.SEED))));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ public boolean isValidEstimatorParams(NamedTuple params) {
public static class AnomalyParams extends NamedTuple {
private static final long serialVersionUID = 1L;

/** Cached Json formatting. Possible because Objects of this class is immutable */
/** Cached Json formatting. Possible because Objects of this class are immutable */
private ObjectNode cachedNode;

private final Statistic distribution;
Expand Down Expand Up @@ -748,6 +748,49 @@ public String toJson(boolean doPrettyPrint) {
public String toJson() {
return toJson(false);
}

/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = super.hashCode();
result = prime * result + ((distribution == null) ? 0 : distribution.hashCode());
result = prime * result + Arrays.hashCode(historicalLikelihoods);
result = prime * result + ((movingAverage == null) ? 0 : movingAverage.hashCode());
result = prime * result + windowSize;
return result;
}

/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(!super.equals(obj))
return false;
if(getClass() != obj.getClass())
return false;
AnomalyParams other = (AnomalyParams)obj;
if(distribution == null) {
if(other.distribution != null)
return false;
} else if(!distribution.equals(other.distribution))
return false;
if(!Arrays.equals(historicalLikelihoods, other.historicalLikelihoods))
return false;
if(movingAverage == null) {
if(other.movingAverage != null)
return false;
} else if(!movingAverage.equals(other.movingAverage))
return false;
if(windowSize != other.windowSize)
return false;
return true;
}
}

// Table lookup for Q function, from wikipedia
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public void init(Connections c) {
throw new InvalidSPParamValueException("Inhibition parameters are invalid");
}

c.doSpatialPoolerPostInit();
initMatrices(c);
connectAndConfigureInputs(c);
}
Expand Down
53 changes: 52 additions & 1 deletion src/test/java/org/numenta/nupic/ConnectionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.numenta.nupic.Connections.Activity;
import org.numenta.nupic.Connections.SegmentOverlap;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
Expand Down Expand Up @@ -573,7 +574,7 @@ public void testGetPrintString() {
TemporalMemory.init(con);

String output = con.getPrintString();
assertEquals(1369, output.length());
assertEquals(1370, output.length());

Set<String> fieldSet = Parameters.getEncoderDefaultParameters().keys().stream().
map(k -> k.getFieldName()).collect(Collectors.toCollection(LinkedHashSet::new));
Expand All @@ -591,6 +592,56 @@ public void testGetPrintString() {
}
}

@Test
public void testDoSpatialPoolerPostInit() {
Parameters p = getParameters();
p.set(KEY.SYN_PERM_CONNECTED, 0.2);
p.set(KEY.SYN_PERM_ACTIVE_INC, 0.003);

///////////////////// First without Post Init /////////////////////
SpatialPooler sp = new SpatialPooler();
@SuppressWarnings("serial")
Connections conn = new Connections() {
@Override
public void doSpatialPoolerPostInit() {
// Override to do nothing
}
};
p.apply(conn);
sp.init(conn);

double synPermConnected = conn.getSynPermConnected();
double synPermActiveInc = conn.getSynPermActiveInc();
double synPermBelowStimulusInc = conn.getSynPermBelowStimulusInc();
double synPermTrimThreshold = conn.getSynPermTrimThreshold();

// Assert that static values (synPermConnected & synPermActiveInc) don't change,
// and that synPermBelowStimulusInc & synPermTrimThreshold are the defaults
assertEquals(0.2, synPermConnected, 0.001);
assertEquals(0.003, synPermActiveInc, 0.001);
assertEquals(0.01, synPermBelowStimulusInc, 0.001);
assertEquals(0.025, synPermTrimThreshold, 0.0001);


///////////////////// Now with Post Init /////////////////////
sp = new SpatialPooler();
conn = new Connections();
p.apply(conn);
sp.init(conn);

synPermConnected = conn.getSynPermConnected();
synPermActiveInc = conn.getSynPermActiveInc();
synPermBelowStimulusInc = conn.getSynPermBelowStimulusInc();
synPermTrimThreshold = conn.getSynPermTrimThreshold();

// Assert that static values (synPermConnected & synPermActiveInc) don't change,
// and that synPermBelowStimulusInc & synPermTrimThreshold change due to postInit()
assertEquals(0.2, synPermConnected, 0.001);
assertEquals(0.003, synPermActiveInc, 0.001);
assertEquals(0.02, synPermBelowStimulusInc, 0.001); // affected by postInit()
assertEquals(0.0015, synPermTrimThreshold, 0.0001); // affected by postInit()
}

public static Parameters getParameters() {
Parameters parameters = Parameters.getAllDefaultParameters();
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 8 });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public double compute(int[] activeColumns, int[] predictedColumns, double inputV
public void testHashCodeAndEquals() {
double[] likelihoods = new double[] { 0.2, 0.3 };

Sample s = new Sample(new DateTime(), 0.1, 0.1);
DateTime metricTime = new DateTime();
Sample s = new Sample(metricTime, 0.1, 0.1);
List<Sample> samples = new ArrayList<>();
samples.add(s);
TDoubleList d = new TDoubleArrayList();
Expand Down Expand Up @@ -217,6 +218,33 @@ public double compute(int[] activeColumns, int[] predictedColumns, double inputV
AnomalyLikelihoodMetrics metrics6 = new AnomalyLikelihoodMetrics(likelihoods, avges, params5);

assertNotEquals(metrics, metrics6);

//////////////////////////
// Test same Samples / Different Params
likelihoods = new double[] { 0.2, 0.3 };

s = new Sample(metricTime, 0.1, 0.1);
samples = new ArrayList<>();
samples.add(s);
d = new TDoubleArrayList();
d.add(0.5);
total = 0.4;
avges = (
new Anomaly() {
private static final long serialVersionUID = 1L;
@Override
public double compute(int[] activeColumns, int[] predictedColumns, double inputValue, long timestamp) {
return 0;
}
}
).new AveragedAnomalyRecordList(samples, d, total);

Statistic stat6 = new Statistic(0.1, 0.1, 0.1);
MovingAverage ma6 = new MovingAverage(new TDoubleArrayList(), 1);
AnomalyParams params6 = new AnomalyParams(new String[] { Anomaly.KEY_DIST, Anomaly.KEY_MVG_AVG, Anomaly.KEY_HIST_LIKE}, stat6, ma6, likelihoods);
AnomalyLikelihoodMetrics metrics7 = new AnomalyLikelihoodMetrics(likelihoods, avges, params6);

assertNotEquals(metrics, metrics7);
}

}
37 changes: 37 additions & 0 deletions src/test/java/org/numenta/nupic/util/NearestNeighborTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.numenta.nupic.util;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

import org.junit.Test;


public class NearestNeighborTest {

@Test
public void testVecLpDist() {
NearestNeighbor nn = new NearestNeighbor(5, 10);
assertNull(nn.vecLpDist(0.0, null, false));
}

@Test
public void testRightVecSumAtNZ() {
int[][] connectedSynapses = new int[][]{
{1, 0, 0, 0, 0, 1, 0, 0, 0, 0},
{0, 1, 0, 0, 0, 0, 1, 0, 0, 0},
{0, 0, 1, 0, 0, 0, 0, 1, 0, 0},
{0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 0, 0, 1, 0, 0, 0, 0, 1}};

int[] inputVector = new int[]{1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
int[] trueResults = new int[]{1, 1, 1, 1, 1};

NearestNeighbor nn = new NearestNeighbor(5, 10);
int[] result = nn.rightVecSumAtNZ(inputVector, connectedSynapses);

for (int i = 0; i < result.length; i++) {
assertEquals(trueResults[i], result[i]);
}
}

}