diff --git a/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java b/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java
index c889ed291b1..c97dd3a91c7 100644
--- a/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java
+++ b/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java
@@ -1313,6 +1313,69 @@ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getAttrOrThrow(
*/
com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensorOrBuilder getParametersOrBuilder(
int index);
+
+ /**
+ * int32 inputDimMasks = 17;
+ */
+ int getInputDimMasks();
+
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+  java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>
+      getInputScalesList();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index);
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ int getInputScalesCount();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+  java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getInputScalesOrBuilderList();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder(
+ int index);
+
+ /**
+ * int32 outputDimMasks = 19;
+ */
+ int getOutputDimMasks();
+
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+  java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>
+      getOutputScalesList();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index);
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ int getOutputScalesCount();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+  java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getOutputScalesOrBuilderList();
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder(
+ int index);
+
+ /**
+ * bool isMklInt8Enabled = 21;
+ */
+ boolean getIsMklInt8Enabled();
}
/**
* Protobuf type {@code com.intel.analytics.bigdl.serialization.BigDLModule}
@@ -1338,6 +1401,11 @@ private BigDLModule() {
id_ = 0;
hasParameters_ = false;
parameters_ = java.util.Collections.emptyList();
+ inputDimMasks_ = 0;
+ inputScales_ = java.util.Collections.emptyList();
+ outputDimMasks_ = 0;
+ outputScales_ = java.util.Collections.emptyList();
+ isMklInt8Enabled_ = false;
}
@java.lang.Override
@@ -1508,6 +1576,39 @@ private BigDLModule(
input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensor.parser(), extensionRegistry));
break;
}
+ case 136: {
+
+ inputDimMasks_ = input.readInt32();
+ break;
+ }
+ case 146: {
+ if (!((mutable_bitField0_ & 0x00020000) == 0x00020000)) {
+              inputScales_ = new java.util.ArrayList<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>();
+ mutable_bitField0_ |= 0x00020000;
+ }
+ inputScales_.add(
+ input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.parser(), extensionRegistry));
+ break;
+ }
+ case 152: {
+
+ outputDimMasks_ = input.readInt32();
+ break;
+ }
+ case 162: {
+ if (!((mutable_bitField0_ & 0x00080000) == 0x00080000)) {
+              outputScales_ = new java.util.ArrayList<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>();
+ mutable_bitField0_ |= 0x00080000;
+ }
+ outputScales_.add(
+ input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.parser(), extensionRegistry));
+ break;
+ }
+ case 168: {
+
+ isMklInt8Enabled_ = input.readBool();
+ break;
+ }
}
}
} catch (com.google.protobuf.InvalidProtocolBufferException e) {
@@ -1528,6 +1629,12 @@ private BigDLModule(
if (((mutable_bitField0_ & 0x00008000) == 0x00008000)) {
parameters_ = java.util.Collections.unmodifiableList(parameters_);
}
+ if (((mutable_bitField0_ & 0x00020000) == 0x00020000)) {
+ inputScales_ = java.util.Collections.unmodifiableList(inputScales_);
+ }
+ if (((mutable_bitField0_ & 0x00080000) == 0x00080000)) {
+ outputScales_ = java.util.Collections.unmodifiableList(outputScales_);
+ }
this.unknownFields = unknownFields.build();
makeExtensionsImmutable();
}
@@ -2187,6 +2294,103 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensorOrBuilder getPar
return parameters_.get(index);
}
+ public static final int INPUTDIMMASKS_FIELD_NUMBER = 17;
+ private int inputDimMasks_;
+ /**
+ * int32 inputDimMasks = 17;
+ */
+ public int getInputDimMasks() {
+ return inputDimMasks_;
+ }
+
+ public static final int INPUTSCALES_FIELD_NUMBER = 18;
+    private java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> inputScales_;
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+    public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> getInputScalesList() {
+ return inputScales_;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+    public java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getInputScalesOrBuilderList() {
+ return inputScales_;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public int getInputScalesCount() {
+ return inputScales_.size();
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index) {
+ return inputScales_.get(index);
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder(
+ int index) {
+ return inputScales_.get(index);
+ }
+
+ public static final int OUTPUTDIMMASKS_FIELD_NUMBER = 19;
+ private int outputDimMasks_;
+ /**
+ * int32 outputDimMasks = 19;
+ */
+ public int getOutputDimMasks() {
+ return outputDimMasks_;
+ }
+
+ public static final int OUTPUTSCALES_FIELD_NUMBER = 20;
+    private java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> outputScales_;
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+    public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> getOutputScalesList() {
+ return outputScales_;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+    public java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getOutputScalesOrBuilderList() {
+ return outputScales_;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public int getOutputScalesCount() {
+ return outputScales_.size();
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index) {
+ return outputScales_.get(index);
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder(
+ int index) {
+ return outputScales_.get(index);
+ }
+
+ public static final int ISMKLINT8ENABLED_FIELD_NUMBER = 21;
+ private boolean isMklInt8Enabled_;
+ /**
+ * bool isMklInt8Enabled = 21;
+ */
+ public boolean getIsMklInt8Enabled() {
+ return isMklInt8Enabled_;
+ }
+
private byte memoizedIsInitialized = -1;
public final boolean isInitialized() {
byte isInitialized = memoizedIsInitialized;
@@ -2250,6 +2454,21 @@ public void writeTo(com.google.protobuf.CodedOutputStream output)
for (int i = 0; i < parameters_.size(); i++) {
output.writeMessage(16, parameters_.get(i));
}
+ if (inputDimMasks_ != 0) {
+ output.writeInt32(17, inputDimMasks_);
+ }
+ for (int i = 0; i < inputScales_.size(); i++) {
+ output.writeMessage(18, inputScales_.get(i));
+ }
+ if (outputDimMasks_ != 0) {
+ output.writeInt32(19, outputDimMasks_);
+ }
+ for (int i = 0; i < outputScales_.size(); i++) {
+ output.writeMessage(20, outputScales_.get(i));
+ }
+ if (isMklInt8Enabled_ != false) {
+ output.writeBool(21, isMklInt8Enabled_);
+ }
unknownFields.writeTo(output);
}
@@ -2332,6 +2551,26 @@ public int getSerializedSize() {
size += com.google.protobuf.CodedOutputStream
.computeMessageSize(16, parameters_.get(i));
}
+ if (inputDimMasks_ != 0) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeInt32Size(17, inputDimMasks_);
+ }
+ for (int i = 0; i < inputScales_.size(); i++) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeMessageSize(18, inputScales_.get(i));
+ }
+ if (outputDimMasks_ != 0) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeInt32Size(19, outputDimMasks_);
+ }
+ for (int i = 0; i < outputScales_.size(); i++) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeMessageSize(20, outputScales_.get(i));
+ }
+ if (isMklInt8Enabled_ != false) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeBoolSize(21, isMklInt8Enabled_);
+ }
size += unknownFields.getSerializedSize();
memoizedSize = size;
return size;
@@ -2392,6 +2631,16 @@ public boolean equals(final java.lang.Object obj) {
== other.getHasParameters());
result = result && getParametersList()
.equals(other.getParametersList());
+ result = result && (getInputDimMasks()
+ == other.getInputDimMasks());
+ result = result && getInputScalesList()
+ .equals(other.getInputScalesList());
+ result = result && (getOutputDimMasks()
+ == other.getOutputDimMasks());
+ result = result && getOutputScalesList()
+ .equals(other.getOutputScalesList());
+ result = result && (getIsMklInt8Enabled()
+ == other.getIsMklInt8Enabled());
result = result && unknownFields.equals(other.unknownFields);
return result;
}
@@ -2455,6 +2704,21 @@ public int hashCode() {
hash = (37 * hash) + PARAMETERS_FIELD_NUMBER;
hash = (53 * hash) + getParametersList().hashCode();
}
+ hash = (37 * hash) + INPUTDIMMASKS_FIELD_NUMBER;
+ hash = (53 * hash) + getInputDimMasks();
+ if (getInputScalesCount() > 0) {
+ hash = (37 * hash) + INPUTSCALES_FIELD_NUMBER;
+ hash = (53 * hash) + getInputScalesList().hashCode();
+ }
+ hash = (37 * hash) + OUTPUTDIMMASKS_FIELD_NUMBER;
+ hash = (53 * hash) + getOutputDimMasks();
+ if (getOutputScalesCount() > 0) {
+ hash = (37 * hash) + OUTPUTSCALES_FIELD_NUMBER;
+ hash = (53 * hash) + getOutputScalesList().hashCode();
+ }
+ hash = (37 * hash) + ISMKLINT8ENABLED_FIELD_NUMBER;
+ hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean(
+ getIsMklInt8Enabled());
hash = (29 * hash) + unknownFields.hashCode();
memoizedHashCode = hash;
return hash;
@@ -2604,6 +2868,8 @@ private void maybeForceBuilderInitialization() {
.alwaysUseFieldBuilders) {
getSubModulesFieldBuilder();
getParametersFieldBuilder();
+ getInputScalesFieldBuilder();
+ getOutputScalesFieldBuilder();
}
}
public Builder clear() {
@@ -2663,6 +2929,24 @@ public Builder clear() {
} else {
parametersBuilder_.clear();
}
+ inputDimMasks_ = 0;
+
+ if (inputScalesBuilder_ == null) {
+ inputScales_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00020000);
+ } else {
+ inputScalesBuilder_.clear();
+ }
+ outputDimMasks_ = 0;
+
+ if (outputScalesBuilder_ == null) {
+ outputScales_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00080000);
+ } else {
+ outputScalesBuilder_.clear();
+ }
+ isMklInt8Enabled_ = false;
+
return this;
}
@@ -2744,6 +3028,27 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLModule buildPartial()
} else {
result.parameters_ = parametersBuilder_.build();
}
+ result.inputDimMasks_ = inputDimMasks_;
+ if (inputScalesBuilder_ == null) {
+ if (((bitField0_ & 0x00020000) == 0x00020000)) {
+ inputScales_ = java.util.Collections.unmodifiableList(inputScales_);
+ bitField0_ = (bitField0_ & ~0x00020000);
+ }
+ result.inputScales_ = inputScales_;
+ } else {
+ result.inputScales_ = inputScalesBuilder_.build();
+ }
+ result.outputDimMasks_ = outputDimMasks_;
+ if (outputScalesBuilder_ == null) {
+ if (((bitField0_ & 0x00080000) == 0x00080000)) {
+ outputScales_ = java.util.Collections.unmodifiableList(outputScales_);
+ bitField0_ = (bitField0_ & ~0x00080000);
+ }
+ result.outputScales_ = outputScales_;
+ } else {
+ result.outputScales_ = outputScalesBuilder_.build();
+ }
+ result.isMklInt8Enabled_ = isMklInt8Enabled_;
result.bitField0_ = to_bitField0_;
onBuilt();
return result;
@@ -2897,6 +3202,67 @@ public Builder mergeFrom(com.intel.analytics.bigdl.serialization.Bigdl.BigDLModu
}
}
}
+ if (other.getInputDimMasks() != 0) {
+ setInputDimMasks(other.getInputDimMasks());
+ }
+ if (inputScalesBuilder_ == null) {
+ if (!other.inputScales_.isEmpty()) {
+ if (inputScales_.isEmpty()) {
+ inputScales_ = other.inputScales_;
+ bitField0_ = (bitField0_ & ~0x00020000);
+ } else {
+ ensureInputScalesIsMutable();
+ inputScales_.addAll(other.inputScales_);
+ }
+ onChanged();
+ }
+ } else {
+ if (!other.inputScales_.isEmpty()) {
+ if (inputScalesBuilder_.isEmpty()) {
+ inputScalesBuilder_.dispose();
+ inputScalesBuilder_ = null;
+ inputScales_ = other.inputScales_;
+ bitField0_ = (bitField0_ & ~0x00020000);
+ inputScalesBuilder_ =
+ com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ?
+ getInputScalesFieldBuilder() : null;
+ } else {
+ inputScalesBuilder_.addAllMessages(other.inputScales_);
+ }
+ }
+ }
+ if (other.getOutputDimMasks() != 0) {
+ setOutputDimMasks(other.getOutputDimMasks());
+ }
+ if (outputScalesBuilder_ == null) {
+ if (!other.outputScales_.isEmpty()) {
+ if (outputScales_.isEmpty()) {
+ outputScales_ = other.outputScales_;
+ bitField0_ = (bitField0_ & ~0x00080000);
+ } else {
+ ensureOutputScalesIsMutable();
+ outputScales_.addAll(other.outputScales_);
+ }
+ onChanged();
+ }
+ } else {
+ if (!other.outputScales_.isEmpty()) {
+ if (outputScalesBuilder_.isEmpty()) {
+ outputScalesBuilder_.dispose();
+ outputScalesBuilder_ = null;
+ outputScales_ = other.outputScales_;
+ bitField0_ = (bitField0_ & ~0x00080000);
+ outputScalesBuilder_ =
+ com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ?
+ getOutputScalesFieldBuilder() : null;
+ } else {
+ outputScalesBuilder_.addAllMessages(other.outputScales_);
+ }
+ }
+ }
+ if (other.getIsMklInt8Enabled() != false) {
+ setIsMklInt8Enabled(other.getIsMklInt8Enabled());
+ }
this.mergeUnknownFields(other.unknownFields);
onChanged();
return this;
@@ -5041,6 +5407,564 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensor.Builder addPara
}
return parametersBuilder_;
}
+
+ private int inputDimMasks_ ;
+ /**
+ * int32 inputDimMasks = 17;
+ */
+ public int getInputDimMasks() {
+ return inputDimMasks_;
+ }
+ /**
+ * int32 inputDimMasks = 17;
+ */
+ public Builder setInputDimMasks(int value) {
+
+ inputDimMasks_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * int32 inputDimMasks = 17;
+ */
+ public Builder clearInputDimMasks() {
+
+ inputDimMasks_ = 0;
+ onChanged();
+ return this;
+ }
+
+      private java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> inputScales_ =
+ java.util.Collections.emptyList();
+ private void ensureInputScalesIsMutable() {
+ if (!((bitField0_ & 0x00020000) == 0x00020000)) {
+          inputScales_ = new java.util.ArrayList<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>(inputScales_);
+ bitField0_ |= 0x00020000;
+ }
+ }
+
+ private com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> inputScalesBuilder_;
+
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+      public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> getInputScalesList() {
+ if (inputScalesBuilder_ == null) {
+ return java.util.Collections.unmodifiableList(inputScales_);
+ } else {
+ return inputScalesBuilder_.getMessageList();
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public int getInputScalesCount() {
+ if (inputScalesBuilder_ == null) {
+ return inputScales_.size();
+ } else {
+ return inputScalesBuilder_.getCount();
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index) {
+ if (inputScalesBuilder_ == null) {
+ return inputScales_.get(index);
+ } else {
+ return inputScalesBuilder_.getMessage(index);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder setInputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (inputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureInputScalesIsMutable();
+ inputScales_.set(index, value);
+ onChanged();
+ } else {
+ inputScalesBuilder_.setMessage(index, value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder setInputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (inputScalesBuilder_ == null) {
+ ensureInputScalesIsMutable();
+ inputScales_.set(index, builderForValue.build());
+ onChanged();
+ } else {
+ inputScalesBuilder_.setMessage(index, builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder addInputScales(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (inputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureInputScalesIsMutable();
+ inputScales_.add(value);
+ onChanged();
+ } else {
+ inputScalesBuilder_.addMessage(value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder addInputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (inputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureInputScalesIsMutable();
+ inputScales_.add(index, value);
+ onChanged();
+ } else {
+ inputScalesBuilder_.addMessage(index, value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder addInputScales(
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (inputScalesBuilder_ == null) {
+ ensureInputScalesIsMutable();
+ inputScales_.add(builderForValue.build());
+ onChanged();
+ } else {
+ inputScalesBuilder_.addMessage(builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder addInputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (inputScalesBuilder_ == null) {
+ ensureInputScalesIsMutable();
+ inputScales_.add(index, builderForValue.build());
+ onChanged();
+ } else {
+ inputScalesBuilder_.addMessage(index, builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder addAllInputScales(
+          java.lang.Iterable<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> values) {
+ if (inputScalesBuilder_ == null) {
+ ensureInputScalesIsMutable();
+ com.google.protobuf.AbstractMessageLite.Builder.addAll(
+ values, inputScales_);
+ onChanged();
+ } else {
+ inputScalesBuilder_.addAllMessages(values);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder clearInputScales() {
+ if (inputScalesBuilder_ == null) {
+ inputScales_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00020000);
+ onChanged();
+ } else {
+ inputScalesBuilder_.clear();
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public Builder removeInputScales(int index) {
+ if (inputScalesBuilder_ == null) {
+ ensureInputScalesIsMutable();
+ inputScales_.remove(index);
+ onChanged();
+ } else {
+ inputScalesBuilder_.remove(index);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder getInputScalesBuilder(
+ int index) {
+ return getInputScalesFieldBuilder().getBuilder(index);
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder(
+ int index) {
+ if (inputScalesBuilder_ == null) {
+ return inputScales_.get(index); } else {
+ return inputScalesBuilder_.getMessageOrBuilder(index);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+      public java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getInputScalesOrBuilderList() {
+ if (inputScalesBuilder_ != null) {
+ return inputScalesBuilder_.getMessageOrBuilderList();
+ } else {
+ return java.util.Collections.unmodifiableList(inputScales_);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addInputScalesBuilder() {
+ return getInputScalesFieldBuilder().addBuilder(
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance());
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addInputScalesBuilder(
+ int index) {
+ return getInputScalesFieldBuilder().addBuilder(
+ index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance());
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18;
+ */
+      public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder>
+           getInputScalesBuilderList() {
+ return getInputScalesFieldBuilder().getBuilderList();
+ }
+ private com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getInputScalesFieldBuilder() {
+ if (inputScalesBuilder_ == null) {
+ inputScalesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>(
+ inputScales_,
+ ((bitField0_ & 0x00020000) == 0x00020000),
+ getParentForChildren(),
+ isClean());
+ inputScales_ = null;
+ }
+ return inputScalesBuilder_;
+ }
+
+ private int outputDimMasks_ ;
+ /**
+ * int32 outputDimMasks = 19;
+ */
+ public int getOutputDimMasks() {
+ return outputDimMasks_;
+ }
+ /**
+ * int32 outputDimMasks = 19;
+ */
+ public Builder setOutputDimMasks(int value) {
+
+ outputDimMasks_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * int32 outputDimMasks = 19;
+ */
+ public Builder clearOutputDimMasks() {
+
+ outputDimMasks_ = 0;
+ onChanged();
+ return this;
+ }
+
+      private java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> outputScales_ =
+ java.util.Collections.emptyList();
+ private void ensureOutputScalesIsMutable() {
+ if (!((bitField0_ & 0x00080000) == 0x00080000)) {
+          outputScales_ = new java.util.ArrayList<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue>(outputScales_);
+ bitField0_ |= 0x00080000;
+ }
+ }
+
+ private com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> outputScalesBuilder_;
+
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+      public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> getOutputScalesList() {
+ if (outputScalesBuilder_ == null) {
+ return java.util.Collections.unmodifiableList(outputScales_);
+ } else {
+ return outputScalesBuilder_.getMessageList();
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public int getOutputScalesCount() {
+ if (outputScalesBuilder_ == null) {
+ return outputScales_.size();
+ } else {
+ return outputScalesBuilder_.getCount();
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index) {
+ if (outputScalesBuilder_ == null) {
+ return outputScales_.get(index);
+ } else {
+ return outputScalesBuilder_.getMessage(index);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder setOutputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (outputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureOutputScalesIsMutable();
+ outputScales_.set(index, value);
+ onChanged();
+ } else {
+ outputScalesBuilder_.setMessage(index, value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder setOutputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (outputScalesBuilder_ == null) {
+ ensureOutputScalesIsMutable();
+ outputScales_.set(index, builderForValue.build());
+ onChanged();
+ } else {
+ outputScalesBuilder_.setMessage(index, builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder addOutputScales(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (outputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureOutputScalesIsMutable();
+ outputScales_.add(value);
+ onChanged();
+ } else {
+ outputScalesBuilder_.addMessage(value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder addOutputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) {
+ if (outputScalesBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ensureOutputScalesIsMutable();
+ outputScales_.add(index, value);
+ onChanged();
+ } else {
+ outputScalesBuilder_.addMessage(index, value);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder addOutputScales(
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (outputScalesBuilder_ == null) {
+ ensureOutputScalesIsMutable();
+ outputScales_.add(builderForValue.build());
+ onChanged();
+ } else {
+ outputScalesBuilder_.addMessage(builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder addOutputScales(
+ int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) {
+ if (outputScalesBuilder_ == null) {
+ ensureOutputScalesIsMutable();
+ outputScales_.add(index, builderForValue.build());
+ onChanged();
+ } else {
+ outputScalesBuilder_.addMessage(index, builderForValue.build());
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder addAllOutputScales(
+          java.lang.Iterable<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValue> values) {
+ if (outputScalesBuilder_ == null) {
+ ensureOutputScalesIsMutable();
+ com.google.protobuf.AbstractMessageLite.Builder.addAll(
+ values, outputScales_);
+ onChanged();
+ } else {
+ outputScalesBuilder_.addAllMessages(values);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder clearOutputScales() {
+ if (outputScalesBuilder_ == null) {
+ outputScales_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00080000);
+ onChanged();
+ } else {
+ outputScalesBuilder_.clear();
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public Builder removeOutputScales(int index) {
+ if (outputScalesBuilder_ == null) {
+ ensureOutputScalesIsMutable();
+ outputScales_.remove(index);
+ onChanged();
+ } else {
+ outputScalesBuilder_.remove(index);
+ }
+ return this;
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder getOutputScalesBuilder(
+ int index) {
+ return getOutputScalesFieldBuilder().getBuilder(index);
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder(
+ int index) {
+ if (outputScalesBuilder_ == null) {
+ return outputScales_.get(index); } else {
+ return outputScalesBuilder_.getMessageOrBuilder(index);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+      public java.util.List<? extends com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getOutputScalesOrBuilderList() {
+ if (outputScalesBuilder_ != null) {
+ return outputScalesBuilder_.getMessageOrBuilderList();
+ } else {
+ return java.util.Collections.unmodifiableList(outputScales_);
+ }
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addOutputScalesBuilder() {
+ return getOutputScalesFieldBuilder().addBuilder(
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance());
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+ public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addOutputScalesBuilder(
+ int index) {
+ return getOutputScalesFieldBuilder().addBuilder(
+ index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance());
+ }
+ /**
+ * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20;
+ */
+      public java.util.List<com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder>
+           getOutputScalesBuilderList() {
+ return getOutputScalesFieldBuilder().getBuilderList();
+ }
+ private com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>
+ getOutputScalesFieldBuilder() {
+ if (outputScalesBuilder_ == null) {
+ outputScalesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3<
+ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>(
+ outputScales_,
+ ((bitField0_ & 0x00080000) == 0x00080000),
+ getParentForChildren(),
+ isClean());
+ outputScales_ = null;
+ }
+ return outputScalesBuilder_;
+ }
+
+ private boolean isMklInt8Enabled_ ;
+ /**
+ * bool isMklInt8Enabled = 21;
+ */
+ public boolean getIsMklInt8Enabled() {
+ return isMklInt8Enabled_;
+ }
+ /**
+ * bool isMklInt8Enabled = 21;
+ */
+ public Builder setIsMklInt8Enabled(boolean value) {
+
+ isMklInt8Enabled_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * bool isMklInt8Enabled = 21;
+ */
+ public Builder clearIsMklInt8Enabled() {
+
+ isMklInt8Enabled_ = false;
+ onChanged();
+ return this;
+ }
public final Builder setUnknownFields(
final com.google.protobuf.UnknownFieldSet unknownFields) {
return super.setUnknownFieldsProto3(unknownFields);
@@ -22144,7 +23068,7 @@ public com.intel.analytics.bigdl.serialization.Bigdl.Shape getDefaultInstanceFor
java.lang.String[] descriptorData = {
"\n\013bigdl.proto\022\'com.intel.analytics.bigdl" +
".serialization\032\031google/protobuf/any.prot" +
- "o\"\206\006\n\013BigDLModule\022\014\n\004name\030\001 \001(\t\022H\n\nsubMo" +
+ "o\"\342\007\n\013BigDLModule\022\014\n\004name\030\001 \001(\t\022H\n\nsubMo" +
"dules\030\002 \003(\01324.com.intel.analytics.bigdl." +
"serialization.BigDLModule\022D\n\006weight\030\003 \001(" +
"\01324.com.intel.analytics.bigdl.serializat" +
@@ -22161,105 +23085,110 @@ public com.intel.analytics.bigdl.serialization.Bigdl.Shape getDefaultInstanceFor
"igdl.serialization.Shape\022\025\n\rhasParameter" +
"s\030\017 \001(\010\022H\n\nparameters\030\020 \003(\01324.com.intel." +
"analytics.bigdl.serialization.BigDLTenso" +
- "r\032_\n\tAttrEntry\022\013\n\003key\030\001 \001(\t\022A\n\005value\030\002 \001",
- "(\01322.com.intel.analytics.bigdl.serializa" +
- "tion.AttrValue:\0028\001\"g\n\nInitMethod\022K\n\nmeth" +
- "odType\030\001 \001(\01627.com.intel.analytics.bigdl" +
- ".serialization.InitMethodType\022\014\n\004data\030\002 " +
- "\003(\001\"\326\002\n\013BigDLTensor\022C\n\010datatype\030\001 \001(\01621." +
- "com.intel.analytics.bigdl.serialization." +
- "DataType\022\014\n\004size\030\002 \003(\005\022\016\n\006stride\030\003 \003(\005\022\016" +
- "\n\006offset\030\004 \001(\005\022\021\n\tdimension\030\005 \001(\005\022\021\n\tnEl" +
- "ements\030\006 \001(\005\022\020\n\010isScalar\030\007 \001(\010\022G\n\007storag" +
- "e\030\010 \001(\01326.com.intel.analytics.bigdl.seri",
- "alization.TensorStorage\022\n\n\002id\030\t \001(\005\022G\n\nt" +
- "ensorType\030\n \001(\01623.com.intel.analytics.bi" +
- "gdl.serialization.TensorType\"\352\001\n\rTensorS" +
- "torage\022C\n\010datatype\030\001 \001(\01621.com.intel.ana" +
- "lytics.bigdl.serialization.DataType\022\022\n\nf" +
- "loat_data\030\002 \003(\002\022\023\n\013double_data\030\003 \003(\001\022\021\n\t" +
- "bool_data\030\004 \003(\010\022\023\n\013string_data\030\005 \003(\t\022\020\n\010" +
- "int_data\030\006 \003(\005\022\021\n\tlong_data\030\007 \003(\003\022\022\n\nbyt" +
- "es_data\030\010 \003(\014\022\n\n\002id\030\t \001(\005\"u\n\013Regularizer" +
- "\022Q\n\017regularizerType\030\001 \001(\01628.com.intel.an",
- "alytics.bigdl.serialization.RegularizerT" +
- "ype\022\023\n\013regularData\030\002 \003(\001\"\224\016\n\tAttrValue\022C" +
- "\n\010dataType\030\001 \001(\01621.com.intel.analytics.b" +
- "igdl.serialization.DataType\022\017\n\007subType\030\002" +
- " \001(\t\022\024\n\nint32Value\030\003 \001(\005H\000\022\024\n\nint64Value" +
- "\030\004 \001(\003H\000\022\024\n\nfloatValue\030\005 \001(\002H\000\022\025\n\013double" +
- "Value\030\006 \001(\001H\000\022\025\n\013stringValue\030\007 \001(\tH\000\022\023\n\t" +
- "boolValue\030\010 \001(\010H\000\022P\n\020regularizerValue\030\t " +
- "\001(\01324.com.intel.analytics.bigdl.serializ" +
- "ation.RegularizerH\000\022K\n\013tensorValue\030\n \001(\013",
- "24.com.intel.analytics.bigdl.serializati" +
- "on.BigDLTensorH\000\022Q\n\023variableFormatValue\030" +
- "\013 \001(\01622.com.intel.analytics.bigdl.serial" +
- "ization.VarFormatH\000\022N\n\017initMethodValue\030\014" +
- " \001(\01323.com.intel.analytics.bigdl.seriali" +
- "zation.InitMethodH\000\022P\n\020bigDLModuleValue\030" +
- "\r \001(\01324.com.intel.analytics.bigdl.serial" +
- "ization.BigDLModuleH\000\022R\n\021nameAttrListVal" +
- "ue\030\016 \001(\01325.com.intel.analytics.bigdl.ser" +
- "ialization.NameAttrListH\000\022S\n\narrayValue\030",
- "\017 \001(\0132=.com.intel.analytics.bigdl.serial" +
- "ization.AttrValue.ArrayValueH\000\022S\n\017dataFo" +
- "rmatValue\030\020 \001(\01628.com.intel.analytics.bi" +
- "gdl.serialization.InputDataFormatH\000\022+\n\013c" +
- "ustomValue\030\021 \001(\0132\024.google.protobuf.AnyH\000" +
- "\022?\n\005shape\030\022 \001(\0132..com.intel.analytics.bi" +
- "gdl.serialization.ShapeH\000\032\242\006\n\nArrayValue" +
- "\022\014\n\004size\030\001 \001(\005\022C\n\010datatype\030\002 \001(\01621.com.i" +
- "ntel.analytics.bigdl.serialization.DataT" +
- "ype\022\013\n\003i32\030\003 \003(\005\022\013\n\003i64\030\004 \003(\003\022\013\n\003flt\030\005 \003",
- "(\002\022\013\n\003dbl\030\006 \003(\001\022\013\n\003str\030\007 \003(\t\022\017\n\007boolean\030" +
- "\010 \003(\010\022I\n\013Regularizer\030\t \003(\01324.com.intel.a" +
- "nalytics.bigdl.serialization.Regularizer" +
- "\022D\n\006tensor\030\n \003(\01324.com.intel.analytics.b" +
- "igdl.serialization.BigDLTensor\022J\n\016variab" +
- "leFormat\030\013 \003(\01622.com.intel.analytics.big" +
- "dl.serialization.VarFormat\022G\n\ninitMethod" +
- "\030\014 \003(\01323.com.intel.analytics.bigdl.seria" +
- "lization.InitMethod\022I\n\013bigDLModule\030\r \003(\013" +
- "24.com.intel.analytics.bigdl.serializati",
- "on.BigDLModule\022K\n\014nameAttrList\030\016 \003(\01325.c" +
- "om.intel.analytics.bigdl.serialization.N" +
- "ameAttrList\022L\n\ndataFormat\030\017 \003(\01628.com.in" +
- "tel.analytics.bigdl.serialization.InputD" +
- "ataFormat\022$\n\006custom\030\020 \003(\0132\024.google.proto" +
- "buf.Any\022=\n\005shape\030\021 \003(\0132..com.intel.analy" +
- "tics.bigdl.serialization.ShapeB\007\n\005value\"" +
- "\314\001\n\014NameAttrList\022\014\n\004name\030\001 \001(\t\022M\n\004attr\030\002" +
- " \003(\0132?.com.intel.analytics.bigdl.seriali" +
- "zation.NameAttrList.AttrEntry\032_\n\tAttrEnt",
- "ry\022\013\n\003key\030\001 \001(\t\022A\n\005value\030\002 \001(\01322.com.int" +
- "el.analytics.bigdl.serialization.AttrVal" +
- "ue:\0028\001\"\332\001\n\005Shape\022K\n\tshapeType\030\001 \001(\01628.co" +
+ "r\022\025\n\rinputDimMasks\030\021 \001(\005\022G\n\013inputScales\030",
+ "\022 \003(\01322.com.intel.analytics.bigdl.serial" +
+ "ization.AttrValue\022\026\n\016outputDimMasks\030\023 \001(" +
+ "\005\022H\n\014outputScales\030\024 \003(\01322.com.intel.anal" +
+ "ytics.bigdl.serialization.AttrValue\022\030\n\020i" +
+ "sMklInt8Enabled\030\025 \001(\010\032_\n\tAttrEntry\022\013\n\003ke" +
+ "y\030\001 \001(\t\022A\n\005value\030\002 \001(\01322.com.intel.analy" +
+ "tics.bigdl.serialization.AttrValue:\0028\001\"g" +
+ "\n\nInitMethod\022K\n\nmethodType\030\001 \001(\01627.com.i" +
+ "ntel.analytics.bigdl.serialization.InitM" +
+ "ethodType\022\014\n\004data\030\002 \003(\001\"\326\002\n\013BigDLTensor\022",
+ "C\n\010datatype\030\001 \001(\01621.com.intel.analytics." +
+ "bigdl.serialization.DataType\022\014\n\004size\030\002 \003" +
+ "(\005\022\016\n\006stride\030\003 \003(\005\022\016\n\006offset\030\004 \001(\005\022\021\n\tdi" +
+ "mension\030\005 \001(\005\022\021\n\tnElements\030\006 \001(\005\022\020\n\010isSc" +
+ "alar\030\007 \001(\010\022G\n\007storage\030\010 \001(\01326.com.intel." +
+ "analytics.bigdl.serialization.TensorStor" +
+ "age\022\n\n\002id\030\t \001(\005\022G\n\ntensorType\030\n \001(\01623.co" +
+ "m.intel.analytics.bigdl.serialization.Te" +
+ "nsorType\"\352\001\n\rTensorStorage\022C\n\010datatype\030\001" +
+ " \001(\01621.com.intel.analytics.bigdl.seriali",
+ "zation.DataType\022\022\n\nfloat_data\030\002 \003(\002\022\023\n\013d" +
+ "ouble_data\030\003 \003(\001\022\021\n\tbool_data\030\004 \003(\010\022\023\n\013s" +
+ "tring_data\030\005 \003(\t\022\020\n\010int_data\030\006 \003(\005\022\021\n\tlo" +
+ "ng_data\030\007 \003(\003\022\022\n\nbytes_data\030\010 \003(\014\022\n\n\002id\030" +
+ "\t \001(\005\"u\n\013Regularizer\022Q\n\017regularizerType\030" +
+ "\001 \001(\01628.com.intel.analytics.bigdl.serial" +
+ "ization.RegularizerType\022\023\n\013regularData\030\002" +
+ " \003(\001\"\224\016\n\tAttrValue\022C\n\010dataType\030\001 \001(\01621.c" +
+ "om.intel.analytics.bigdl.serialization.D" +
+ "ataType\022\017\n\007subType\030\002 \001(\t\022\024\n\nint32Value\030\003",
+ " \001(\005H\000\022\024\n\nint64Value\030\004 \001(\003H\000\022\024\n\nfloatVal" +
+ "ue\030\005 \001(\002H\000\022\025\n\013doubleValue\030\006 \001(\001H\000\022\025\n\013str" +
+ "ingValue\030\007 \001(\tH\000\022\023\n\tboolValue\030\010 \001(\010H\000\022P\n" +
+ "\020regularizerValue\030\t \001(\01324.com.intel.anal" +
+ "ytics.bigdl.serialization.RegularizerH\000\022" +
+ "K\n\013tensorValue\030\n \001(\01324.com.intel.analyti" +
+ "cs.bigdl.serialization.BigDLTensorH\000\022Q\n\023" +
+ "variableFormatValue\030\013 \001(\01622.com.intel.an" +
+ "alytics.bigdl.serialization.VarFormatH\000\022" +
+ "N\n\017initMethodValue\030\014 \001(\01323.com.intel.ana",
+ "lytics.bigdl.serialization.InitMethodH\000\022" +
+ "P\n\020bigDLModuleValue\030\r \001(\01324.com.intel.an" +
+ "alytics.bigdl.serialization.BigDLModuleH" +
+ "\000\022R\n\021nameAttrListValue\030\016 \001(\01325.com.intel" +
+ ".analytics.bigdl.serialization.NameAttrL" +
+ "istH\000\022S\n\narrayValue\030\017 \001(\0132=.com.intel.an" +
+ "alytics.bigdl.serialization.AttrValue.Ar" +
+ "rayValueH\000\022S\n\017dataFormatValue\030\020 \001(\01628.co" +
+ "m.intel.analytics.bigdl.serialization.In" +
+ "putDataFormatH\000\022+\n\013customValue\030\021 \001(\0132\024.g",
+ "oogle.protobuf.AnyH\000\022?\n\005shape\030\022 \001(\0132..co" +
"m.intel.analytics.bigdl.serialization.Sh" +
- "ape.ShapeType\022\r\n\005ssize\030\002 \001(\005\022\022\n\nshapeVal" +
- "ue\030\003 \003(\005\022=\n\005shape\030\004 \003(\0132..com.intel.anal" +
- "ytics.bigdl.serialization.Shape\"\"\n\tShape" +
- "Type\022\n\n\006SINGLE\020\000\022\t\n\005MULTI\020\001*\260\001\n\tVarForma" +
- "t\022\020\n\014EMPTY_FORMAT\020\000\022\013\n\007DEFAULT\020\001\022\t\n\005ONE_" +
- "D\020\002\022\n\n\006IN_OUT\020\003\022\n\n\006OUT_IN\020\004\022\020\n\014IN_OUT_KW",
- "_KH\020\005\022\020\n\014OUT_IN_KW_KH\020\006\022\023\n\017GP_OUT_IN_KW_" +
- "KH\020\007\022\023\n\017GP_IN_OUT_KW_KH\020\010\022\023\n\017OUT_IN_KT_K" +
- "H_KW\020\t*\253\001\n\016InitMethodType\022\030\n\024EMPTY_INITI" +
- "ALIZATION\020\000\022\022\n\016RANDOM_UNIFORM\020\001\022\030\n\024RANDO" +
- "M_UNIFORM_PARAM\020\002\022\021\n\rRANDOM_NORMAL\020\003\022\t\n\005" +
- "ZEROS\020\004\022\010\n\004ONES\020\005\022\t\n\005CONST\020\006\022\n\n\006XAVIER\020\007" +
- "\022\022\n\016BILINEARFILLER\020\010*L\n\017RegularizerType\022" +
- "\023\n\017L1L2Regularizer\020\000\022\021\n\rL1Regularizer\020\001\022" +
- "\021\n\rL2Regularizer\020\002*%\n\017InputDataFormat\022\010\n" +
- "\004NCHW\020\000\022\010\n\004NHWC\020\001*\"\n\nTensorType\022\t\n\005DENSE",
- "\020\000\022\t\n\005QUANT\020\001*\210\002\n\010DataType\022\t\n\005INT32\020\000\022\t\n" +
- "\005INT64\020\001\022\t\n\005FLOAT\020\002\022\n\n\006DOUBLE\020\003\022\n\n\006STRIN" +
- "G\020\004\022\010\n\004BOOL\020\005\022\010\n\004CHAR\020\006\022\t\n\005SHORT\020\007\022\t\n\005BY" +
- "TES\020\010\022\017\n\013REGULARIZER\020\t\022\n\n\006TENSOR\020\n\022\023\n\017VA" +
- "RIABLE_FORMAT\020\013\022\016\n\nINITMETHOD\020\014\022\n\n\006MODUL" +
- "E\020\r\022\022\n\016NAME_ATTR_LIST\020\016\022\017\n\013ARRAY_VALUE\020\017" +
- "\022\017\n\013DATA_FORMAT\020\020\022\n\n\006CUSTOM\020\021\022\t\n\005SHAPE\020\022" +
- "b\006proto3"
+ "apeH\000\032\242\006\n\nArrayValue\022\014\n\004size\030\001 \001(\005\022C\n\010da" +
+ "tatype\030\002 \001(\01621.com.intel.analytics.bigdl" +
+ ".serialization.DataType\022\013\n\003i32\030\003 \003(\005\022\013\n\003" +
+ "i64\030\004 \003(\003\022\013\n\003flt\030\005 \003(\002\022\013\n\003dbl\030\006 \003(\001\022\013\n\003s" +
+ "tr\030\007 \003(\t\022\017\n\007boolean\030\010 \003(\010\022I\n\013Regularizer" +
+ "\030\t \003(\01324.com.intel.analytics.bigdl.seria" +
+ "lization.Regularizer\022D\n\006tensor\030\n \003(\01324.c" +
+ "om.intel.analytics.bigdl.serialization.B",
+ "igDLTensor\022J\n\016variableFormat\030\013 \003(\01622.com" +
+ ".intel.analytics.bigdl.serialization.Var" +
+ "Format\022G\n\ninitMethod\030\014 \003(\01323.com.intel.a" +
+ "nalytics.bigdl.serialization.InitMethod\022" +
+ "I\n\013bigDLModule\030\r \003(\01324.com.intel.analyti" +
+ "cs.bigdl.serialization.BigDLModule\022K\n\014na" +
+ "meAttrList\030\016 \003(\01325.com.intel.analytics.b" +
+ "igdl.serialization.NameAttrList\022L\n\ndataF" +
+ "ormat\030\017 \003(\01628.com.intel.analytics.bigdl." +
+ "serialization.InputDataFormat\022$\n\006custom\030",
+ "\020 \003(\0132\024.google.protobuf.Any\022=\n\005shape\030\021 \003" +
+ "(\0132..com.intel.analytics.bigdl.serializa" +
+ "tion.ShapeB\007\n\005value\"\314\001\n\014NameAttrList\022\014\n\004" +
+ "name\030\001 \001(\t\022M\n\004attr\030\002 \003(\0132?.com.intel.ana" +
+ "lytics.bigdl.serialization.NameAttrList." +
+ "AttrEntry\032_\n\tAttrEntry\022\013\n\003key\030\001 \001(\t\022A\n\005v" +
+ "alue\030\002 \001(\01322.com.intel.analytics.bigdl.s" +
+ "erialization.AttrValue:\0028\001\"\332\001\n\005Shape\022K\n\t" +
+ "shapeType\030\001 \001(\01628.com.intel.analytics.bi" +
+ "gdl.serialization.Shape.ShapeType\022\r\n\005ssi",
+ "ze\030\002 \001(\005\022\022\n\nshapeValue\030\003 \003(\005\022=\n\005shape\030\004 " +
+ "\003(\0132..com.intel.analytics.bigdl.serializ" +
+ "ation.Shape\"\"\n\tShapeType\022\n\n\006SINGLE\020\000\022\t\n\005" +
+ "MULTI\020\001*\260\001\n\tVarFormat\022\020\n\014EMPTY_FORMAT\020\000\022" +
+ "\013\n\007DEFAULT\020\001\022\t\n\005ONE_D\020\002\022\n\n\006IN_OUT\020\003\022\n\n\006O" +
+ "UT_IN\020\004\022\020\n\014IN_OUT_KW_KH\020\005\022\020\n\014OUT_IN_KW_K" +
+ "H\020\006\022\023\n\017GP_OUT_IN_KW_KH\020\007\022\023\n\017GP_IN_OUT_KW" +
+ "_KH\020\010\022\023\n\017OUT_IN_KT_KH_KW\020\t*\253\001\n\016InitMetho" +
+ "dType\022\030\n\024EMPTY_INITIALIZATION\020\000\022\022\n\016RANDO" +
+ "M_UNIFORM\020\001\022\030\n\024RANDOM_UNIFORM_PARAM\020\002\022\021\n",
+ "\rRANDOM_NORMAL\020\003\022\t\n\005ZEROS\020\004\022\010\n\004ONES\020\005\022\t\n" +
+ "\005CONST\020\006\022\n\n\006XAVIER\020\007\022\022\n\016BILINEARFILLER\020\010" +
+ "*L\n\017RegularizerType\022\023\n\017L1L2Regularizer\020\000" +
+ "\022\021\n\rL1Regularizer\020\001\022\021\n\rL2Regularizer\020\002*%" +
+ "\n\017InputDataFormat\022\010\n\004NCHW\020\000\022\010\n\004NHWC\020\001*\"\n" +
+ "\nTensorType\022\t\n\005DENSE\020\000\022\t\n\005QUANT\020\001*\210\002\n\010Da" +
+ "taType\022\t\n\005INT32\020\000\022\t\n\005INT64\020\001\022\t\n\005FLOAT\020\002\022" +
+ "\n\n\006DOUBLE\020\003\022\n\n\006STRING\020\004\022\010\n\004BOOL\020\005\022\010\n\004CHA" +
+ "R\020\006\022\t\n\005SHORT\020\007\022\t\n\005BYTES\020\010\022\017\n\013REGULARIZER" +
+ "\020\t\022\n\n\006TENSOR\020\n\022\023\n\017VARIABLE_FORMAT\020\013\022\016\n\nI",
+ "NITMETHOD\020\014\022\n\n\006MODULE\020\r\022\022\n\016NAME_ATTR_LIS" +
+ "T\020\016\022\017\n\013ARRAY_VALUE\020\017\022\017\n\013DATA_FORMAT\020\020\022\n\n" +
+ "\006CUSTOM\020\021\022\t\n\005SHAPE\020\022b\006proto3"
};
com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner =
new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() {
@@ -22279,7 +23208,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors(
internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_fieldAccessorTable = new
com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_descriptor,
- new java.lang.String[] { "Name", "SubModules", "Weight", "Bias", "PreModules", "NextModules", "ModuleType", "Attr", "Version", "Train", "NamePostfix", "Id", "InputShape", "OutputShape", "HasParameters", "Parameters", });
+ new java.lang.String[] { "Name", "SubModules", "Weight", "Bias", "PreModules", "NextModules", "ModuleType", "Attr", "Version", "Train", "NamePostfix", "Id", "InputShape", "OutputShape", "HasParameters", "Parameters", "InputDimMasks", "InputScales", "OutputDimMasks", "OutputScales", "IsMklInt8Enabled", });
internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_AttrEntry_descriptor =
internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_descriptor.getNestedTypes().get(0);
internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_AttrEntry_fieldAccessorTable = new
diff --git a/spark/dl/src/main/resources/serialization/bigdl.proto b/spark/dl/src/main/resources/serialization/bigdl.proto
index 9f0e276304d..5e53c6e32c3 100644
--- a/spark/dl/src/main/resources/serialization/bigdl.proto
+++ b/spark/dl/src/main/resources/serialization/bigdl.proto
@@ -21,6 +21,11 @@ message BigDLModule
Shape outputShape = 14; //output shape
bool hasParameters = 15; // indicator if module has parameters
repeated BigDLTensor parameters = 16; // parameters, e.g., weight and bias
+    int32 inputDimMasks = 17; // dimension mask of input scales
+    repeated AttrValue inputScales = 18; // input scales for int8 conversion
+    int32 outputDimMasks = 19; // dimension mask of output scales
+    repeated AttrValue outputScales = 20; // output scales for int8 conversion
+    bool isMklInt8Enabled = 21; // indicator if mkl-dnn int8 attributes are present
}
enum VarFormat {
EMPTY_FORMAT = 0;
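For orientation, the five new `BigDLModule` fields (17-21) mirror the state kept by `MklInt8Convertible` on the Scala side and are written and read through the generated builder methods added above. Below is a minimal sketch of how they map onto that API; `writeInt8Attrs`, the mask value `1`, and the pre-built `AttrValue` arguments are illustrative only, not part of this patch (the serializer itself converts the `Array[Float]` scales into `AttrValue` via `DataConverter`).

```scala
import com.intel.analytics.bigdl.serialization.Bigdl._

object Int8ProtoSketch {
  // Populate the five new BigDLModule fields through the generated builder API.
  def writeInt8Attrs(builder: BigDLModule.Builder,
                     inputScale: AttrValue,
                     outputScale: AttrValue): BigDLModule.Builder = {
    builder
      .setInputDimMasks(1)          // field 17: input dimension mask (value illustrative)
      .addInputScales(inputScale)   // field 18: repeated AttrValue, one per input
      .setOutputDimMasks(1)         // field 19: output dimension mask (value illustrative)
      .addOutputScales(outputScale) // field 20: repeated AttrValue, one per output
      .setIsMklInt8Enabled(true)    // field 21: marks that int8 metadata is present
  }
}
```

Because these are proto3 fields, models saved before this change deserialize with the defaults (0, empty, false), so `isMklInt8Enabled` stays false and the new attributes are simply skipped on load.
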
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala
index 3dc2e4832b0..8cccc7d0123 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala
@@ -59,8 +59,6 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
implicit ev: TensorNumeric[T]) extends Serializable with InferShape{
// ================================= Public APIs =============================================
-
-
/**
* The cached output. So we don't compute it again when need it
*/
@@ -1183,3 +1181,5 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
}
}
+
+
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala
index 0cca0ba0941..9bfbaeba118 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala
@@ -32,7 +32,8 @@ class Linear(
private val initWeight: Tensor[Float] = null,
private val initBias: Tensor[Float] = null,
private val initGradWeight: Tensor[Float] = null,
- private val initGradBias: Tensor[Float] = null) extends MklDnnLayer with Initializable {
+ private val initGradBias: Tensor[Float] = null
+) extends MklDnnLayer with Initializable with MklInt8Convertible {
private[mkldnn] val weight: Blob = new Blob(Array(outputSize, inputSize))
private[mkldnn] val bias: Blob = new Blob(Array(outputSize))
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala
new file mode 100644
index 00000000000..b944870528c
--- /dev/null
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.nn.mkldnn
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Trait that provides MKL-DNN functionality to convert an FP32 model into an INT8 model.
+ */
+trait MklInt8Convertible {
+ // input dimension mask
+ protected var inDimMask: Int = 0
+ // output dimension mask
+ protected var outDimMask: Int = 0
+ // input scales
+ private[mkldnn] var inScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]]
+ // output scales
+ private[mkldnn] var outScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]]
+
+
+ /**
+ * Get dimension mask of input
+ * @return inDimMask field which stores value of input dimension mask
+ */
+ def getInputDimMask(): Int = {
+ inDimMask
+ }
+
+ /**
+ * Set dimension mask of input
+ * @param mask value of input dimension mask to be set
+ * @return Unit
+ */
+ def setInputDimMask(mask: Int) : Unit = {
+ inDimMask = mask
+ }
+
+ /**
+ * Get dimension mask of output
+ * @return outDimMask field which stores value of output dimension mask
+ */
+ def getOutputDimMask(): Int = {
+ outDimMask
+ }
+
+ /**
+ * Set dimension mask of output
+ * @param mask value of output dimension mask to be set
+ * @return Unit
+ */
+ def setOutputDimMask(mask: Int): Unit = {
+ outDimMask = mask
+ }
+
+ /**
+ * Get input scales
+ * @return field which stores value of input scales
+ */
+ def getInputScales(): Array[Array[Float]] = {
+ inScalesBuffer.toArray
+ }
+
+ /**
+ * Set input scales
+ * Clear existing buffer of input scales, and place updated scales into the cleared buffer
+ * @param inScales value of input scales to be set
+ * @return Unit
+ */
+ def setInputScales(inScales: Array[Array[Float]]): Unit = {
+ inScalesBuffer.clear()
+ inScales.foreach(appendInputScales)
+ }
+
+ /**
+ * Get output scales
+ * @return field which stores value of output scales
+ */
+ def getOutputScales(): Array[Array[Float]] = {
+ outScalesBuffer.toArray
+ }
+
+ /**
+ * Set output scales
+ * Clear existing buffer of output scales, and place updated scales into the cleared buffer
+ * @param outScales value of output scales to be set
+ * @return Unit
+ */
+ def setOutputScales(outScales: Array[Array[Float]]): Unit = {
+ outScalesBuffer.clear()
+ outScales.foreach(appendOutputScales)
+ }
+
+ /**
+ * Append a scale, an array of float, into input scales buffer
+ * @param scale value of an input scale to be appended
+ * @return Unit
+ */
+ private def appendInputScales(scale: Array[Float]): Unit = {
+ inScalesBuffer.append(scale)
+ }
+
+ /**
+ * Append a scale, an array of float, into output scales buffer
+ * @param scale value of an output scale to be appended
+ * @return Unit
+ */
+ private def appendOutputScales(scale: Array[Float]): Unit = {
+ outScalesBuffer.append(scale)
+ }
+
+ /**
+ * Update input scales at specific index with provided new scale
+ * @param scale the new scale
+   * @param index the index at which the scale needs to be updated
+ * @return Unit
+ */
+ def updateInputScales(scale: Array[Float], index: Int): Unit = {
+ updateScalesHelper(inScalesBuffer, scale, index)
+ }
+
+ /**
+ * Update output scales at specific index with provided new scale
+ * @param scale the new scale
+   * @param index the index at which the scale needs to be updated
+ * @return Unit
+ */
+  def updateOutputScales(scale: Array[Float], index: Int): Unit = {
+ updateScalesHelper(outScalesBuffer, scale, index)
+ }
+
+
+ /**
+   * Scales update helper. Appends the scale when the index is beyond the buffer; otherwise keeps
+   * the element-wise maximum of the existing and the new scale at that index
+ * @param scales the scales arrayBuffer to be updated
+ * @param scale the new scale
+   * @param index the index at which the scale needs to be updated
+ * @return Unit
+ */
+ private def updateScalesHelper(scales: ArrayBuffer[Array[Float]],
+ scale: Array[Float], index: Int): Unit = {
+ if (scales.length - 1 < index) {
+ scales.append(scale)
+ }
+
+ scales(index).indices.foreach(i =>
+ if (scale(i) > scales(index)(i)) {
+ scales(index)(i) = scale(i)
+ })
+ }
+
+}
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala
index 56bb2ea1686..61442dfabb6 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala
@@ -15,12 +15,15 @@
*/
package com.intel.analytics.bigdl.utils.serializer
+
import java.lang.reflect.Field
+import scala.collection.JavaConverters._
import com.intel.analytics.bigdl.nn.Container
import scala.collection.JavaConverters._
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
+import com.intel.analytics.bigdl.nn.mkldnn.MklInt8Convertible
import com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.ArrayValue
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
@@ -29,7 +32,6 @@ import com.intel.analytics.bigdl.utils.serializer.converters.{DataConverter, Sha
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializer._
import com.intel.analytics.bigdl.serialization.Bigdl._
-import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.reflect.runtime.universe
@@ -85,7 +87,6 @@ trait ModuleSerializable extends Loadable with Savable{
// step2 : module specific logic to load module, either default, cell, container or graph
val moduleId = context.bigdlModule.getId
-
val storages = context.storages
val module = if (storages.contains(moduleId)) {
@@ -121,6 +122,7 @@ trait ModuleSerializable extends Loadable with Savable{
val constructorFullParams = constructorMirror.symbol.paramss
val args = new Array[Object](constructorFullParams.map(_.size).sum)
var i = 0
+
constructorFullParams.foreach(map => {
map.foreach(param => {
val name = param.name.decodedName.toString
@@ -185,8 +187,10 @@ trait ModuleSerializable extends Loadable with Savable{
}
getLock.synchronized {
+
// step 4 : set data types (ClassTag and TensorNumric)
setDataTypes(context, bigDLModelBuilder)
+
// step 5 : apply module specific logic to create module
doSerializeModule(context, bigDLModelBuilder)
}
@@ -218,6 +222,7 @@ trait ModuleSerializable extends Loadable with Savable{
val fullParams = getCostructorMirror(cls).symbol.paramss
val constructorParams = fullParams(0)
constructorParams.foreach(param => {
+
val paramName = param.name.decodedName.toString
var ptype = param.typeSignature
val attrBuilder = AttrValue.newBuilder
@@ -232,28 +237,36 @@ trait ModuleSerializable extends Loadable with Savable{
field.setAccessible(true)
val fieldValue = field.get(module)
DataConverter.setAttributeValue(context, attrBuilder, fieldValue, ptype)
-
bigDLModelBuilder.putAttr(paramName, attrBuilder.build)
})
}
+ /**
+ * Re-create a BigDL module by deserializing the protobuf context
+ * @param context deserialization context
+ * @param module the BigDL module to be re-created
+ * @return ModuleData holding the module together with its preceding and following modules
+ */
protected def createBigDLModule[T: ClassTag](context: DeserializeContext,
module : AbstractModule[Activity, Activity, T])
- (implicit ev: TensorNumeric[T])
- : ModuleData[T] = {
+ (implicit ev: TensorNumeric[T]): ModuleData[T] = {
val model = context.bigdlModule
val preModules = model.getPreModulesList.asScala
val nextModules = model.getNextModulesList.asScala
val bigDLModule = ModuleData(module, preModules, nextModules)
+
if (model.getName != "") {
module.setName(model.getName)
}
+
module.setNamePostfix(model.getNamePostfix)
+
if (model.getTrain) {
module.training()
} else {
module.evaluate()
}
+
module.inputShapeValue = ShapeConverter.shapeToBigDL(context, model, "input")
module.outputShapeValue = ShapeConverter.shapeToBigDL(context, model, "output")
@@ -261,9 +274,23 @@ trait ModuleSerializable extends Loadable with Savable{
if (_copyWeightAndBias && context.bigdlModule.getSubModulesCount == 0) {
copy2BigDL(context, bigDLModule)
}
+
+ // Load MKL-DNN INT8 attributes (input/output scales and dim masks) into the
+ // BigDL module from the protobuf definition if the MKL-DNN INT8 flag is on
+ if (model.getIsMklInt8Enabled) {
+ loadMklInt8Attr(context, module.asInstanceOf[MklInt8Convertible])
+ }
+
bigDLModule
}
+
+ /**
+ * Create the protobuf definition of a BigDL model by serializing the BigDL module
+ * @param modelBuilder protobuf builder of the BigDL model
+ * @param context serialization context of the BigDL module
+ * @return SerializeResult wrapping the populated builder and the tensor storages
+ */
protected def createSerializeBigDLModule[T: ClassTag](
modelBuilder : BigDLModule.Builder, context: SerializeContext[T])(implicit ev: TensorNumeric[T])
: SerializeResult = {
@@ -288,6 +315,16 @@ trait ModuleSerializable extends Loadable with Savable{
if (_copyWeightAndBias && !module.isInstanceOf[Container[_, _, _]]) {
copyFromBigDL(context, modelBuilder)
}
+
+ // Save MKL-DNN INT8 attributes (scales and masks) into the protobuf model if the
+ // module extends the MklInt8Convertible trait, and set the MKL-DNN INT8 flag accordingly
+ if (module.module.isInstanceOf[MklInt8Convertible]) {
+ saveMklInt8Attr(context.moduleData.module.asInstanceOf[MklInt8Convertible], modelBuilder)
+ modelBuilder.setIsMklInt8Enabled(true)
+ } else {
+ modelBuilder.setIsMklInt8Enabled(false)
+ }
+
SerializeResult(modelBuilder, context.storages)
}
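For reference, a hedged sketch of the new protobuf fields this flag guards; the builder setters are the same ones used by saveMklInt8Attr below, and the field values here are illustrative only.

```scala
import com.intel.analytics.bigdl.serialization.Bigdl.BigDLModule

// Sketch: populate the new INT8-related fields directly on a protobuf builder.
val builder = BigDLModule.newBuilder()
builder.setInputDimMasks(1)         // illustrative dim masks
builder.setOutputDimMasks(1)
builder.setIsMklInt8Enabled(true)   // the flag createBigDLModule checks on load
// inputScales / outputScales are repeated AttrValue messages; see
// floatArrayToAttrValue below for how a single scale array is wrapped.
```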
@@ -298,7 +335,6 @@ trait ModuleSerializable extends Loadable with Savable{
*/
protected def copy2BigDL[T: ClassTag](context: DeserializeContext, module : ModuleData[T])
(implicit ev: TensorNumeric[T]): Unit = {
-
if (context.bigdlModule.getHasParameters) {
copyParameters2BigDL(context, module)
} else {
@@ -356,12 +392,50 @@ trait ModuleSerializable extends Loadable with Savable{
}
/**
- * copy BigDL module data (weight and bias if exist) to BigDL Model to be persisted
+ * Deserialize MKL-DNN INT8 attributes from the protobuf context
+ * and load them into the BigDL module
+ * @param context deserialization context
+ * @param module the target module, which must implement MklInt8Convertible
+ */
+ private def loadMklInt8Attr[T: ClassTag](context: DeserializeContext,
+ module: MklInt8Convertible)
+ (implicit ev: TensorNumeric[T]): Unit = {
+
+ val protobufModel = context.bigdlModule
+
+ // Extract ArrayValue for each AttrValue, and then get FltList as input scales
+ val inputScales = protobufModel.getInputScalesList.iterator().asScala
+ .map(attrValueToFloatArray)
+
+ // Extract ArrayValue for each AttrValue, and then get FltList as output scales
+ val outputScales = protobufModel.getOutputScalesList.iterator().asScala
+ .map(attrValueToFloatArray)
+
+ module.setInputDimMask(protobufModel.getInputDimMasks)
+ module.setInputScales(inputScales.toArray)
+ module.setOutputDimMask(protobufModel.getOutputDimMasks)
+ module.setOutputScales(outputScales.toArray)
+ }
+
+ /**
+ * Convert an AttrValue object to an Array of Float
+ * @param attr the AttrValue holding an ArrayValue of floats
+ * @return the scale values as Array[Float]
+ */
+ def attrValueToFloatArray(attr: AttrValue): Array[Float] = {
+ attr.getArrayValue.getFltList.asScala.toArray.map(_.asInstanceOf[Float])
+ }
+
+ /**
+ * Copy BigDL module data (weight and bias if exist) into BigDL Model's protobuf definition
* @param modelBuilder serialized module builder
* @param context serialization context
*/
protected def copyFromBigDL[T: ClassTag](context : SerializeContext[T],
modelBuilder : BigDLModule.Builder)(implicit ev : TensorNumeric[T]) : Unit = {
+
val parameters = context.moduleData.module.parameters
if (parameters != null && parameters._1 != null) {
modelBuilder.setHasParameters(true)
@@ -373,6 +447,57 @@ trait ModuleSerializable extends Loadable with Savable{
}
}
+
+ /**
+ * Serialize MKL-DNN INT8 attributes (scales and masks) and save them into the protobuf model
+ * @param module the MklInt8Convertible module providing the attributes
+ * @param modelBuilder protobuf builder of the BigDL model
+ */
+ protected def saveMklInt8Attr[T: ClassTag](module : MklInt8Convertible,
+ modelBuilder : BigDLModule.Builder)
+ (implicit ev : TensorNumeric[T]) : Unit = {
+
+ // Save scale and mask of input into BigDL model builder
+ val inputScales : Array[Array[Float]] = module.getInputScales()
+ val inputMasks : Int = module.getInputDimMask()
+
+ val inputScalesAttrList = inputScales.map(floatArrayToAttrValue)
+
+ modelBuilder.addAllInputScales(inputScalesAttrList.toIterable.asJava)
+ modelBuilder.setInputDimMasks(inputMasks)
+
+ // Save scale and mask of output into BigDL model builder
+ val outputScales : Array[Array[Float]] = module.getOutputScales()
+ val outputMasks : Int = module.getOutputDimMask()
+
+ val outputScalesAttrList = outputScales.map(floatArrayToAttrValue)
+
+ modelBuilder.addAllOutputScales(outputScalesAttrList.toIterable.asJava)
+ modelBuilder.setOutputDimMasks(outputMasks)
+ }
+
+ /**
+ * Convert an array of floats into an AttrValue object
+ * @param arry the scale values to be wrapped
+ * @return the AttrValue holding an ArrayValue of floats
+ */
+ private def floatArrayToAttrValue(arry : Array[Float]) : AttrValue = {
+ val tempAttrValBuilder = AttrValue.newBuilder()
+ tempAttrValBuilder.setDataType(DataType.ARRAY_VALUE)
+
+ val tempArryValBuilder = ArrayValue.newBuilder()
+ tempArryValBuilder.setSize(arry.length)
+ tempArryValBuilder.setDatatype(DataType.FLOAT)
+
+ arry.foreach(tempArryValBuilder.addFlt)
+ tempAttrValBuilder.setArrayValue(tempArryValBuilder).build()
+ }
+
}
trait ContainerSerializable extends ModuleSerializable {
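attrValueToFloatArray and floatArrayToAttrValue are intended as inverses. Since floatArrayToAttrValue is private, the sketch below reproduces the same wrapping with the protobuf builder calls visible above, then decodes it back; the sample scale values are illustrative.

```scala
import scala.collection.JavaConverters._
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, DataType}
import com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.ArrayValue

// One scale array <-> one ARRAY_VALUE AttrValue holding floats.
val scale = Array(0.5f, 1.25f, 8.0f)

// Encoding, mirroring floatArrayToAttrValue
val arrayBuilder = ArrayValue.newBuilder()
arrayBuilder.setSize(scale.length)
arrayBuilder.setDatatype(DataType.FLOAT)
scale.foreach(arrayBuilder.addFlt)

val attr = AttrValue.newBuilder()
  .setDataType(DataType.ARRAY_VALUE)
  .setArrayValue(arrayBuilder)
  .build()

// Decoding, mirroring attrValueToFloatArray
val decoded = attr.getArrayValue.getFltList.asScala.toArray.map(_.asInstanceOf[Float])
assert(decoded.sameElements(scale))
```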
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala
index 734c1a9369a..fb0845c4f05 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala
@@ -119,7 +119,7 @@ object ModuleSerializer extends ModuleSerializable{
(implicit ev: TensorNumeric[T]) : ModuleData[T] = {
try {
val model = context.bigdlModule
- val deSerializer = if (serializerMaps.contains(model.getModuleType)) {
+ val deserializer = if (serializerMaps.contains(model.getModuleType)) {
serializerMaps(model.getModuleType)
} else {
val attrMap = model.getAttrMap
@@ -145,8 +145,8 @@ object ModuleSerializer extends ModuleSerializable{
}
}
}
- deSerializer.setCopyWeightAndBias(context.copyWeightAndBias).
- loadModule(context)
+ deserializer.setCopyWeightAndBias(context.copyWeightAndBias).loadModule(context)
+
} catch {
case e: Exception =>
throw new RuntimeException(
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala
index edeb4b85a54..11aee0ea88a 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala
@@ -166,7 +166,8 @@ object DataConverter extends DataConverter{
VariableFormatConverter.setAttributeValue(context, attributeBuilder, value)
} else if (valueType =:= universe.typeOf[InitializationMethod]) {
InitMethodConverter.setAttributeValue(context, attributeBuilder, value)
- } else if (valueType.toString == ModuleSerializer.regularizerType.toString) {
+ } else if (valueType.toString == ModuleSerializer.regularizerType.toString
+ || valueType <:< universe.typeOf[Regularizer[_]]) {
RegularizerConverter.setAttributeValue(context, attributeBuilder, value)
} else if (valueType <:< universe.typeOf[Tensor[_]]) {
TensorConverter.setAttributeValue(context, attributeBuilder, value)
diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala
new file mode 100644
index 00000000000..a3ffe09a114
--- /dev/null
+++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.nn.mkldnn
+
+import java.io.File
+import java.util.UUID
+
+import com.intel.analytics.bigdl.nn.Module
+import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
+
+import scala.util.Random
+
+class Fp32ToInt8Spec extends FlatSpec with Matchers with BeforeAndAfter {
+
+ val modelPath: String = "myTestModel" + UUID.randomUUID().toString
+ val weightPath: String = "myTestModelWeight" + UUID.randomUUID().toString
+
+ "Saving and loading scale and mask" should "work properly" in {
+
+ val myModel = Linear(3, 4)
+
+ val assignedInputMask: Int = Random.nextInt(100)
+ val assignedInputScales: Array[Array[Float]] = Array.fill(3, 4)(Random.nextFloat())
+
+ val assignedOutputMask: Int = Random.nextInt()
+ val assignedOutputScales: Array[Array[Float]] = Array.fill(3, 4)(Random.nextFloat())
+
+ myModel.setInputDimMask(assignedInputMask)
+ myModel.setInputScales(assignedInputScales)
+
+ myModel.setOutputDimMask(assignedOutputMask)
+ myModel.setOutputScales(assignedOutputScales)
+
+ myModel.saveModule(modelPath, weightPath, true)
+
+ val loadedModel = Module.loadModule[Float](modelPath, weightPath).asInstanceOf[Linear]
+
+ val loadedInputMask = loadedModel.getInputDimMask()
+ val loadedInputScales = loadedModel.getInputScales()
+ val loadedOutputMask = loadedModel.getOutputDimMask()
+ val loadedOutputScales = loadedModel.getOutputScales()
+
+ loadedInputMask should be (assignedInputMask)
+ loadedInputScales should be (assignedInputScales)
+
+ loadedOutputMask should be (assignedOutputMask)
+ loadedOutputScales should be (assignedOutputScales)
+
+ }
+
+ after {
+ new File(modelPath).delete()
+ new File(weightPath).delete()
+ }
+}