Skip to content

Commit

Permalink
Merge 4cb2faf into 9077795
Browse files Browse the repository at this point in the history
  • Loading branch information
Hopding committed Mar 7, 2017
2 parents 9077795 + 4cb2faf commit 199210b
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 48 deletions.
13 changes: 8 additions & 5 deletions src/main/java/org/numenta/nupic/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@
package org.numenta.nupic;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.EnumMap;
import java.util.Arrays;

import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Segment;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.DistalDendrite;
Expand Down Expand Up @@ -417,8 +418,10 @@ public static enum KEY {

// Network Layer indicator for auto classifier generation
AUTO_CLASSIFY("hasClassifiers", Boolean.class),



/** Maps encoder input field name to type of classifier to be used for them */
INFERRED_FIELDS("inferredFields", Map.class), // Map<String, Classifier.class>

// How many bits to use if encoding the respective date fields.
// e.g. Tuple(bits to use:int, radius:double)
DATEFIELD_SEASON("season", Tuple.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
* @author David Ray
* @see BitHistory
*/
public class CLAClassifier implements Persistable {
public class CLAClassifier implements Persistable, Classifier {
private static final long serialVersionUID = 1L;

int verbosity = 0;
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/numenta/nupic/algorithms/Classifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.numenta.nupic.algorithms;

import java.util.Map;

/**
* Classifier is an interface for Classifier types used to predict future inputs
* to the system, such as {@link CLAClassifier} or {@link SDRClassifier}.
*/
public interface Classifier {
public <T> Classification<T> compute(int recordNum,
Map<String, Object> classification,
int[] patternNZ,
boolean learn,
boolean infer);
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
* @author David Ray
* @author Andrew Dillon
*/
public class SDRClassifier implements Persistable {
public class SDRClassifier implements Persistable, Classifier {
private static final long serialVersionUID = 1L;

int verbosity = 0;
Expand Down
74 changes: 55 additions & 19 deletions src/main/java/org/numenta/nupic/network/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.CLAClassifier;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.SDRClassifier;
import org.numenta.nupic.algorithms.CLAClassifier;
import org.numenta.nupic.encoders.DateEncoder;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.encoders.EncoderTuple;
Expand Down Expand Up @@ -231,7 +233,7 @@ public class Layer<T> implements Persistable {
private boolean hasGenericProcess;

/**
* List of {@link Encoders} used when storing bucket information see
* List of {@link Encoder}s used when storing bucket information see
* {@link #doEncoderBucketMapping(Inference, Map)}
*/
private List<EncoderTuple> encoderTuples;
Expand Down Expand Up @@ -399,7 +401,7 @@ public Layer(Parameters params, MultiEncoder e, SpatialPooler sp, TemporalMemory
(encoder == null ? "" : "MultiEncoder,"),
(spatialPooler == null ? "" : "SpatialPooler,"),
(temporalMemory == null ? "" : "TemporalMemory,"),
(autoCreateClassifiers == null ? "" : "Auto creating CLAClassifiers for each input field."),
(autoCreateClassifiers == null ? "" : "Auto creating Classifiers for each input field."),
(anomalyComputer == null ? "" : "Anomaly"));
}
}
Expand Down Expand Up @@ -1048,7 +1050,7 @@ public void start() {
/**
* Restarts this {@code Layer}
*
* {@link #restart()} is to be called after a call to {@link #halt()}, to begin
* {@link #restart} is to be called after a call to {@link #halt()}, to begin
* processing again. The {@link Network} will continue from where it previously
* left off after the last call to halt().
*
Expand Down Expand Up @@ -1180,7 +1182,7 @@ public Set<Cell> getPredictiveCells() {
}

/**
* Returns the previous predictive {@link Cells}
* Returns the previous predictive {@link Cell}s
*
* @return the binary vector representing the current prediction.
*/
Expand Down Expand Up @@ -1472,7 +1474,7 @@ void notifyError(Exception e) {
* </p>
* <p>
* If any algorithms are repeated then {@link Inference}s will
* <em><b>NOT</b></em> be shared between layers. {@link Regions}
* <em><b>NOT</b></em> be shared between layers. {@link Region}s
* <em><b>NEVER</b></em> share {@link Inference}s
* </p>
*
Expand Down Expand Up @@ -1657,7 +1659,7 @@ private Observable<ManualInput> resolveObservableSequence(T t) {

/**
* Executes the check point logic, handles the return of the serialized byte array
* by delegating the call to {@link rx.Observer#onNext(byte[])} of all the currently queued
* by delegating the call to {@link rx.Observer#onNext}(byte[]) of all the currently queued
* Observers; then clears the list of Observers.
*/
private void doCheckPoint() {
Expand Down Expand Up @@ -1712,7 +1714,15 @@ private void doEncoderBucketMapping(Inference inference, Map<String, Object> enc
int[] tempArray = new int[e.getWidth()];
System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);

inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
inference.getClassifierInput().put(
name,
new NamedTuple(
new String[] { "name", "inputValue", "bucketIdx", "encoding" },
name,
o,
bucketIdx,
tempArray
));
}
}

Expand Down Expand Up @@ -1798,9 +1808,9 @@ private Observable<ManualInput> fillInOrderedSequence(Observable<ManualInput> o)

/**
* Called internally to create a subscription on behalf of the specified
* {@link LayerObserver}
* Layer {@link Observer}
*
* @param sub the LayerObserver (subscriber).
* @param sub the Layer Observer (subscriber).
* @return
*/
private Subscription createSubscription(final Observer<Inference> sub) {
Expand Down Expand Up @@ -1909,12 +1919,36 @@ private void clearSubscriberObserverLists() {
* @return
*/
NamedTuple makeClassifiers(MultiEncoder encoder) {
Map inferredFields = (Map<String, Class<? extends Classifier>>) params.get(KEY.INFERRED_FIELDS);
if(inferredFields == null || inferredFields.entrySet().size() == 0) {
throw new IllegalStateException(
"KEY.AUTO_CLASSIFY has been set to \"true\", but KEY.INFERRED_FIELDS is null or\n\t" +
"empty. Must specify desired Classifier for at least one input field in\n\t" +
"KEY.INFERRED_FIELDS or set KEY.AUTO_CLASSIFY to \"false\" (which is its default\n\t" +
"value in Parameters)."
);
}
String[] names = new String[encoder.getEncoders(encoder).size()];
CLAClassifier[] ca = new CLAClassifier[names.length];
Classifier[] ca = new Classifier[names.length];
int i = 0;
for(EncoderTuple et : encoder.getEncoders(encoder)) {
names[i] = et.getName();
ca[i] = new CLAClassifier();
Object fieldClassifier = inferredFields.get(et.getName());
if(fieldClassifier == CLAClassifier.class) {
LOGGER.info("Classifying \"" + et.getName() + "\" input field with CLAClassifier");
ca[i] = new CLAClassifier();
} else if(fieldClassifier == SDRClassifier.class) {
LOGGER.info("Classifying \"" + et.getName() + "\" input field with SDRClassifier");
ca[i] = new SDRClassifier();
} else if(fieldClassifier != null) {
throw new IllegalStateException(
"Invalid Classifier class token, \"" + fieldClassifier + "\",\n\t" +
"specified for, \"" + et.getName() + "\", input field.\n\t" +
"Valid class tokens are CLAClassifier.class and SDRClassifier.class"
);
} else { // fieldClassifier is null
LOGGER.info("Not classifying \"" + et.getName() + "\" input field");
}
i++;
}
return new NamedTuple(names, (Object[])ca);
Expand Down Expand Up @@ -2014,8 +2048,7 @@ public void run() {
* that stores the state of this {@code Network} while keeping the Network up and running.
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
*
* @param network the {@link Network} to check point.
* @return the {@link CheckPointOp} operator
* @return the {@link CheckPointOp} operator
*/
@SuppressWarnings("unchecked")
CheckPointOp<byte[]> getCheckPointOperator() {
Expand Down Expand Up @@ -2328,10 +2361,13 @@ public ManualInput call(ManualInput t1) {
bucketIdx = inputs.get("bucketIdx");
actValue = inputs.get("inputValue");

CLAClassifier c = (CLAClassifier)t1.getClassifiers().get(key);
Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true);
Classifier c = (Classifier)t1.getClassifiers().get(key);

t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result);
// c will be null if no classifier was specified for this field in KEY.INFERRED_FIELDS map
if(c != null) {
Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true);
t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result);
}
}

return t1;
Expand Down
12 changes: 8 additions & 4 deletions src/main/java/org/numenta/nupic/network/ManualInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import java.util.Map;
import java.util.Set;

import org.numenta.nupic.algorithms.CLAClassifier;
import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
Expand Down Expand Up @@ -191,7 +191,9 @@ public ManualInput customObject(Object o) {

/**
* <p>
* Returns the {@link Map} used as input into the {@link CLAClassifier}
* Returns the {@link Map} used as input into the field's {@link Classifier}
* (it is only actually used as input if a Classifier type has specified for
* the field).
*
* This mapping contains the name of the field being classified mapped
* to a {@link NamedTuple} containing:
Expand Down Expand Up @@ -237,7 +239,7 @@ public ManualInput classifiers(NamedTuple tuple) {

/**
* Returns a {@link NamedTuple} keyed to the input field
* names, whose values are the {@link CLAClassifier} used
* names, whose values are the {@link Classifier} used
* to track the classification of a particular field
*/
@Override
Expand Down Expand Up @@ -341,10 +343,12 @@ ManualInput copy() {
* Returns the most recent {@link Classification}
*
* @param fieldName
* @return
* @return the most recent {@link Classification}, or null if none exists.
*/
@Override
public Classification<Object> getClassification(String fieldName) {
if(classification == null)
return null;
return classification.get(fieldName);
}

Expand Down

0 comments on commit 199210b

Please sign in to comment.