Skip to content
This repository has been archived by the owner on Dec 31, 2020. It is now read-only.

Update to 0.6.7 #15

Merged
merged 6 commits into from
May 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flink-htm-examples/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies {
compile 'org.scala-lang:scala-library:2.11.7'

// htm.java
compile 'org.numenta:htm.java:0.6.5'
compile 'org.numenta:htm.java:0.6.7'

// flink
compile 'org.apache.flink:flink-scala_2.11:1.0.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object NetworkDemoParameters {
//SpatialPooler specific
POTENTIAL_RADIUS -> 12, //3
POTENTIAL_PCT -> 0.5, //0.5
GLOBAL_INHIBITIONS -> false,
GLOBAL_INHIBITION -> false,
LOCAL_AREA_DENSITY -> -1.0,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 5.0,
STIMULUS_THRESHOLD -> 1.0,
Expand All @@ -49,7 +49,7 @@ object NetworkDemoParameters {
PERMANENCE_DECREMENT -> 0.05,
ACTIVATION_THRESHOLD -> 4))
.union(Parameters(
GLOBAL_INHIBITIONS -> true,
GLOBAL_INHIBITION -> true,
COLUMN_DIMENSIONS -> Array(2048),
CELLS_PER_COLUMN -> 32,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40.0,
Expand Down Expand Up @@ -80,7 +80,7 @@ trait WorkshopAnomalyParameters {
// spParams
POTENTIAL_PCT -> 0.8,
COLUMN_DIMENSIONS -> Array(2048),
GLOBAL_INHIBITIONS -> true,
GLOBAL_INHIBITION -> true,
/* inputWidth */
MAX_BOOST -> 1.0,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ object Demo extends HotGymModel {
.mapWithState { (inference, state: Option[Double]) =>

val prediction = Prediction(
inference.getInput.timestamp.toString(LOOSE_DATE_TIME),
inference.getInput.consumption,
inference._1.timestamp.toString(LOOSE_DATE_TIME),
inference._1.consumption,
state match {
case Some(prediction) => prediction
case None => 0.0
},
inference.getAnomalyScore)
inference._2.getAnomalyScore)

// store the prediction about the next value as state for the next iteration,
// so that actual vs predicted is a meaningful comparison
val predictedConsumption = inference.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match {
val predictedConsumption = inference._2.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match {
case value: Double if value != 0.0 => value
case _ => state.getOrElse(0.0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object Demo extends TrafficModel {
.filter { report => report.datetime.isBefore(investigationInterval.getEnd) }
.keyBy("streamId")
.learn(network)
.select(inference => (inference.getInput, inference.getAnomalyScore))
.select(inference => (inference._1, inference._2.getAnomalyScore))

val anomalousRoutes = anomalyScores
.filter { anomaly => investigationInterval.contains(anomaly._1.datetime) }
Expand Down
2 changes: 1 addition & 1 deletion flink-htm-streaming-java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies {
compile 'org.slf4j:slf4j-api:1.7.13'

// htm.java
compile 'org.numenta:htm.java:0.6.5'
compile 'org.numenta:htm.java:0.6.7'

// flink
compile 'org.apache.flink:flink-java:1.0.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package org.numenta.nupic.flink.serialization;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoException;
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.util.FieldAccessor;
import org.numenta.nupic.Persistable;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.serialize.HTMObjectInput;
import org.numenta.nupic.serialize.HTMObjectOutput;
import org.numenta.nupic.serialize.SerialConfig;
import org.numenta.nupic.serialize.SerializerCore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Kryo serializer for HTM network and related objects.
*
*/
public class KryoSerializer<T extends Persistable> extends Serializer<T> implements Serializable {

/** Logger for this serializer; protected so subclasses can use it as well. */
protected static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class);

/** HTM serializer core, constructed with {@code SerialConfig.DEFAULT_REGISTERED_TYPES}, used by write/read/copy. */
private final SerializerCore serializer = new SerializerCore(SerialConfig.DEFAULT_REGISTERED_TYPES);

/**
 * Serializes {@code t} with the HTM serializer and emits the result to
 * {@code output} as a length-prefixed byte payload.
 *
 * @param kryo   instance of {@link Kryo} object
 * @param output the Kryo {@link Output} receiving the length and the payload
 * @param t      instance to serialize
 * @throws KryoException if the HTM serializer raises an {@link IOException}
 */
@Override
public void write(Kryo kryo, Output output, T t) {
    preSerialize(t);

    try (ByteArrayOutputStream buffer = new ByteArrayOutputStream(4096)) {
        // Serialize into memory first so the payload size is known before framing.
        HTMObjectOutput htmOut = serializer.getObjectOutput(buffer);
        htmOut.writeObject(t, t.getClass());
        htmOut.close();

        // Frame the payload: length prefix followed by the raw bytes.
        final int length = buffer.size();
        output.writeInt(length);
        buffer.writeTo(output);

        LOGGER.debug("wrote {} bytes", length);
    } catch (IOException e) {
        throw new KryoException(e);
    }
}

/**
 * Reads a length-prefixed HTM payload from {@code input} and deserializes it
 * with the HTM serializer.
 *
 * @param kryo   instance of {@link Kryo} object
 * @param input  a Kryo {@link Input} positioned at the length prefix
 * @param aClass the class of the object to be read in
 * @return the deserialized instance of type {@code T}
 * @throws KryoException if the length prefix is negative or deserialization fails
 */
@Override
public T read(Kryo kryo, Input input, Class<T> aClass) {

    // Validate the length prefix so a corrupt stream surfaces as a KryoException
    // instead of a raw NegativeArraySizeException.
    final int length = input.readInt();
    if (length < 0) {
        throw new KryoException("invalid serialized length: " + length);
    }

    // read the serialized data
    byte[] data = new byte[length];
    input.readBytes(data);

    try (ByteArrayInputStream stream = new ByteArrayInputStream(data)) {
        HTMObjectInput reader = serializer.getObjectInput(stream);

        // The payload was written from an instance of aClass, so the cast is expected to hold.
        @SuppressWarnings("unchecked")
        T t = (T) reader.readObject(aClass);

        postDeSerialize(t);

        return t;
    } catch (Exception e) {
        throw new KryoException(e);
    }
}

/**
 * Produces a deep copy of {@code original} by serializing it with the HTM
 * serializer into memory and immediately deserializing the result.
 *
 * @param kryo     instance of {@link Kryo} object
 * @param original the object to copy
 * @return a new instance deserialized from the serialized form of {@code original}
 * @throws KryoException if serialization or deserialization fails
 */
@Override
public T copy(Kryo kryo, T original) {
    try {
        preSerialize(original);

        final Class<?> type = original.getClass();

        try (CopyStream sink = new CopyStream(4096)) {
            HTMObjectOutput htmOut = serializer.getObjectOutput(sink);
            htmOut.writeObject(original, type);
            htmOut.close();

            // Read back straight from the sink's buffer; the sink is not written to again.
            try (InputStream source = sink.toInputStream()) {
                HTMObjectInput htmIn = serializer.getObjectInput(source);
                T duplicate = (T) htmIn.readObject(type);

                postDeSerialize(duplicate);

                return duplicate;
            }
        }
    } catch (Exception e) {
        throw new KryoException(e);
    }
}

/**
 * A {@link ByteArrayOutputStream} whose written bytes can be read back
 * without copying the underlying buffer.
 */
static class CopyStream extends ByteArrayOutputStream {

    public CopyStream(int initialCapacity) {
        super(initialCapacity);
    }

    /**
     * Exposes the bytes written so far as an input stream backed by this
     * stream's own buffer. Do not use the output stream after calling this method.
     *
     * @return an {@link InputStream} over the written bytes
     */
    public InputStream toInputStream() {
        final int written = size();
        return new ByteArrayInputStream(buf, 0, written);
    }
}

/**
* Hook called by {@code write} and {@code copy} before an instance is serialized.
* The HTM serializer handles the Persistable callbacks automatically, but
* this method is for any additional actions to be taken. The default
* implementation does nothing; subclasses may override it.
* @param t the instance to be serialized.
*/
protected void preSerialize(T t) {
}

/**
* Hook called by {@code read} and {@code copy} after an instance has been deserialized.
* The HTM serializer handles the Persistable callbacks automatically, but
* this method is for any additional actions to be taken. The default
* implementation does nothing; subclasses may override it.
* @param t the instance newly deserialized.
*/
protected void postDeSerialize(T t) {
}

/**
 * Registers the HTM types with the Kryo serializer of the given execution config.
 * The default HTM types are registered first, followed by the additional types
 * declared by this class.
 *
 * @param config the execution config to register the types with
 */
public static void registerTypes(ExecutionConfig config) {
    for (Class<?> type : SerialConfig.DEFAULT_REGISTERED_TYPES) {
        registerWithSerializer(config, type);
    }
    for (Class<?> type : KryoSerializer.ADDITIONAL_REGISTERED_TYPES) {
        registerWithSerializer(config, type);
    }
}

/**
 * Registers a single type, using its entry in {@code DEFAULT_SERIALIZERS}
 * or falling back to {@code KryoSerializer} when there is none.
 */
private static void registerWithSerializer(ExecutionConfig config, Class<?> type) {
    Class<?> serializerClass = DEFAULT_SERIALIZERS.getOrDefault(type, (Class<?>) KryoSerializer.class);
    config.registerTypeWithKryoSerializer(type, (Class<? extends Serializer<?>>) serializerClass);
}

/** Types that {@code registerTypes} registers in addition to {@code SerialConfig.DEFAULT_REGISTERED_TYPES}. */
static final Class<?>[] ADDITIONAL_REGISTERED_TYPES = { Network.class };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell, there is no reason to register the Network object. Internally, all registering does is share the class name (string) among all instances of that particular class (which is slower). So pre-registering helps avoid this behavior, but since there is only ever one instance of the Network being serialized, I don't see how registering it would be any more efficient.


/**
* A map of serializers for various classes: each key is a type to be serialized
* and each value is the serializer class that {@code registerTypes} registers for it.
* Types without an entry are registered with {@code KryoSerializer} itself.
*/
static final Map<Class<?>,Class<?>> DEFAULT_SERIALIZERS = Stream.of(
new Tuple2<>(Network.class, NetworkSerializer.class)
).collect(Collectors.toMap(kv -> kv.f0, kv -> kv.f1));


public static class NetworkSerializer extends KryoSerializer<Network> {

private final static Field shouldDoHaltField;

static {
try {
shouldDoHaltField = Network.class.getDeclaredField("shouldDoHalt");
shouldDoHaltField.setAccessible(true);
} catch (NoSuchFieldException e) {
throw new UnsupportedOperationException("unable to locate Network::shouldDoHalt", e);
}

}

/**
* Sets the network's {@code shouldDoHalt} field to {@code false} via reflection
* before the network is serialized (workaround tracked as HTM.java issue #417).
*
* @param network the network about to be serialized.
* @throws UnsupportedOperationException if the field cannot be set.
*/
@Override
protected void preSerialize(Network network) {
super.preSerialize(network);
try {
// issue: HTM.java #417
Copy link
Member

@cogmission cogmission May 2, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the effect of not managing the shouldDoHalt field?

shouldDoHaltField.set(network, false);
} catch (IllegalAccessException e) {
throw new UnsupportedOperationException("unable to set Network::shouldDoHalt", e);
}
}

/**
* Performs no network-specific work after deserialization; simply delegates
* to the base implementation.
*
* @param network the network newly deserialized.
*/
@Override
protected void postDeSerialize(Network network) {
super.postDeSerialize(network);
}
}
}
Loading