diff --git a/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java b/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java index c889ed291b1..c97dd3a91c7 100644 --- a/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java +++ b/spark/dl/src/main/java/com/intel/analytics/bigdl/serialization/Bigdl.java @@ -1313,6 +1313,69 @@ com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getAttrOrThrow( */ com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensorOrBuilder getParametersOrBuilder( int index); + + /** + * int32 inputDimMasks = 17; + */ + int getInputDimMasks(); + + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + java.util.List + getInputScalesList(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + int getInputScalesCount(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + java.util.List + getInputScalesOrBuilderList(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder( + int index); + + /** + * int32 outputDimMasks = 19; + */ + int getOutputDimMasks(); + + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + java.util.List + getOutputScalesList(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + int getOutputScalesCount(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + java.util.List + getOutputScalesOrBuilderList(); + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder( + int index); + + /** + * bool isMklInt8Enabled = 21; + */ + boolean getIsMklInt8Enabled(); } /** * Protobuf type {@code com.intel.analytics.bigdl.serialization.BigDLModule} @@ -1338,6 +1401,11 @@ private BigDLModule() { id_ = 0; hasParameters_ = false; parameters_ = java.util.Collections.emptyList(); + inputDimMasks_ = 0; + inputScales_ = java.util.Collections.emptyList(); + outputDimMasks_ = 0; + outputScales_ = java.util.Collections.emptyList(); + isMklInt8Enabled_ = false; } @java.lang.Override @@ -1508,6 +1576,39 @@ private BigDLModule( input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensor.parser(), extensionRegistry)); break; } + case 136: { + + inputDimMasks_ = input.readInt32(); + break; + } + case 146: { + if (!((mutable_bitField0_ & 0x00020000) == 0x00020000)) { + inputScales_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00020000; + } + inputScales_.add( + input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.parser(), extensionRegistry)); + break; + } + case 152: { + + outputDimMasks_ = input.readInt32(); + break; + } + case 162: { + if (!((mutable_bitField0_ & 0x00080000) == 0x00080000)) { + outputScales_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00080000; + } + outputScales_.add( + input.readMessage(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.parser(), extensionRegistry)); + break; + } + case 168: { + + isMklInt8Enabled_ = input.readBool(); + break; + } } } } catch (com.google.protobuf.InvalidProtocolBufferException e) { @@ -1528,6 +1629,12 @@ private BigDLModule( if (((mutable_bitField0_ & 0x00008000) == 0x00008000)) { parameters_ = java.util.Collections.unmodifiableList(parameters_); } + if (((mutable_bitField0_ & 0x00020000) == 0x00020000)) { + inputScales_ = java.util.Collections.unmodifiableList(inputScales_); + } + if (((mutable_bitField0_ & 0x00080000) == 0x00080000)) { + outputScales_ = java.util.Collections.unmodifiableList(outputScales_); + } this.unknownFields = unknownFields.build(); makeExtensionsImmutable(); } @@ -2187,6 +2294,103 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensorOrBuilder getPar return parameters_.get(index); } + public static final int INPUTDIMMASKS_FIELD_NUMBER = 17; + private int inputDimMasks_; + /** + * int32 inputDimMasks = 17; + */ + public int getInputDimMasks() { + return inputDimMasks_; + } + + public static final int INPUTSCALES_FIELD_NUMBER = 18; + private java.util.List inputScales_; + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public java.util.List getInputScalesList() { + return inputScales_; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public java.util.List + getInputScalesOrBuilderList() { + return inputScales_; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public int getInputScalesCount() { + return inputScales_.size(); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index) { + return inputScales_.get(index); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder( + int index) { + return inputScales_.get(index); + } + + public static final int OUTPUTDIMMASKS_FIELD_NUMBER = 19; + private int outputDimMasks_; + /** + * int32 outputDimMasks = 19; + */ + public int getOutputDimMasks() { + return outputDimMasks_; + } + + public static final int OUTPUTSCALES_FIELD_NUMBER = 20; + private java.util.List outputScales_; + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public java.util.List getOutputScalesList() { + return outputScales_; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public java.util.List + getOutputScalesOrBuilderList() { + return outputScales_; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public int getOutputScalesCount() { + return outputScales_.size(); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index) { + return outputScales_.get(index); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder( + int index) { + return outputScales_.get(index); + } + + public static final int ISMKLINT8ENABLED_FIELD_NUMBER = 21; + private boolean isMklInt8Enabled_; + /** + * bool isMklInt8Enabled = 21; + */ + public boolean getIsMklInt8Enabled() { + return isMklInt8Enabled_; + } + private byte memoizedIsInitialized = -1; public final boolean isInitialized() { byte isInitialized = memoizedIsInitialized; @@ -2250,6 +2454,21 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) for (int i = 0; i < parameters_.size(); i++) { output.writeMessage(16, parameters_.get(i)); } + if (inputDimMasks_ != 0) { + output.writeInt32(17, inputDimMasks_); + } + for (int i = 0; i < inputScales_.size(); i++) { + output.writeMessage(18, inputScales_.get(i)); + } + if (outputDimMasks_ != 0) { + output.writeInt32(19, outputDimMasks_); + } + for (int i = 0; i < outputScales_.size(); i++) { + output.writeMessage(20, outputScales_.get(i)); + } + if (isMklInt8Enabled_ != false) { + output.writeBool(21, isMklInt8Enabled_); + } unknownFields.writeTo(output); } @@ -2332,6 +2551,26 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(16, parameters_.get(i)); } + if (inputDimMasks_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(17, inputDimMasks_); + } + for (int i = 0; i < inputScales_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(18, inputScales_.get(i)); + } + if (outputDimMasks_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(19, outputDimMasks_); + } + for (int i = 0; i < outputScales_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(20, outputScales_.get(i)); + } + if (isMklInt8Enabled_ != false) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize(21, isMklInt8Enabled_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -2392,6 +2631,16 @@ public boolean equals(final java.lang.Object obj) { == other.getHasParameters()); result = result && getParametersList() .equals(other.getParametersList()); + result = result && (getInputDimMasks() + == other.getInputDimMasks()); + result = result && getInputScalesList() + .equals(other.getInputScalesList()); + result = result && (getOutputDimMasks() + == other.getOutputDimMasks()); + result = result && getOutputScalesList() + .equals(other.getOutputScalesList()); + result = result && (getIsMklInt8Enabled() + == other.getIsMklInt8Enabled()); result = result && unknownFields.equals(other.unknownFields); return result; } @@ -2455,6 +2704,21 @@ public int hashCode() { hash = (37 * hash) + PARAMETERS_FIELD_NUMBER; hash = (53 * hash) + getParametersList().hashCode(); } + hash = (37 * hash) + INPUTDIMMASKS_FIELD_NUMBER; + hash = (53 * hash) + getInputDimMasks(); + if (getInputScalesCount() > 0) { + hash = (37 * hash) + INPUTSCALES_FIELD_NUMBER; + hash = (53 * hash) + getInputScalesList().hashCode(); + } + hash = (37 * hash) + OUTPUTDIMMASKS_FIELD_NUMBER; + hash = (53 * hash) + getOutputDimMasks(); + if (getOutputScalesCount() > 0) { + hash = (37 * hash) + OUTPUTSCALES_FIELD_NUMBER; + hash = (53 * hash) + getOutputScalesList().hashCode(); + } + hash = (37 * hash) + ISMKLINT8ENABLED_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getIsMklInt8Enabled()); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -2604,6 +2868,8 @@ private void maybeForceBuilderInitialization() { .alwaysUseFieldBuilders) { getSubModulesFieldBuilder(); getParametersFieldBuilder(); + getInputScalesFieldBuilder(); + getOutputScalesFieldBuilder(); } } public Builder clear() { @@ -2663,6 +2929,24 @@ public Builder clear() { } else { parametersBuilder_.clear(); } + inputDimMasks_ = 0; + + if (inputScalesBuilder_ == null) { + inputScales_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00020000); + } else { + inputScalesBuilder_.clear(); + } + outputDimMasks_ = 0; + + if (outputScalesBuilder_ == null) { + outputScales_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00080000); + } else { + outputScalesBuilder_.clear(); + } + isMklInt8Enabled_ = false; + return this; } @@ -2744,6 +3028,27 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLModule buildPartial() } else { result.parameters_ = parametersBuilder_.build(); } + result.inputDimMasks_ = inputDimMasks_; + if (inputScalesBuilder_ == null) { + if (((bitField0_ & 0x00020000) == 0x00020000)) { + inputScales_ = java.util.Collections.unmodifiableList(inputScales_); + bitField0_ = (bitField0_ & ~0x00020000); + } + result.inputScales_ = inputScales_; + } else { + result.inputScales_ = inputScalesBuilder_.build(); + } + result.outputDimMasks_ = outputDimMasks_; + if (outputScalesBuilder_ == null) { + if (((bitField0_ & 0x00080000) == 0x00080000)) { + outputScales_ = java.util.Collections.unmodifiableList(outputScales_); + bitField0_ = (bitField0_ & ~0x00080000); + } + result.outputScales_ = outputScales_; + } else { + result.outputScales_ = outputScalesBuilder_.build(); + } + result.isMklInt8Enabled_ = isMklInt8Enabled_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -2897,6 +3202,67 @@ public Builder mergeFrom(com.intel.analytics.bigdl.serialization.Bigdl.BigDLModu } } } + if (other.getInputDimMasks() != 0) { + setInputDimMasks(other.getInputDimMasks()); + } + if (inputScalesBuilder_ == null) { + if (!other.inputScales_.isEmpty()) { + if (inputScales_.isEmpty()) { + inputScales_ = other.inputScales_; + bitField0_ = (bitField0_ & ~0x00020000); + } else { + ensureInputScalesIsMutable(); + inputScales_.addAll(other.inputScales_); + } + onChanged(); + } + } else { + if (!other.inputScales_.isEmpty()) { + if (inputScalesBuilder_.isEmpty()) { + inputScalesBuilder_.dispose(); + inputScalesBuilder_ = null; + inputScales_ = other.inputScales_; + bitField0_ = (bitField0_ & ~0x00020000); + inputScalesBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? + getInputScalesFieldBuilder() : null; + } else { + inputScalesBuilder_.addAllMessages(other.inputScales_); + } + } + } + if (other.getOutputDimMasks() != 0) { + setOutputDimMasks(other.getOutputDimMasks()); + } + if (outputScalesBuilder_ == null) { + if (!other.outputScales_.isEmpty()) { + if (outputScales_.isEmpty()) { + outputScales_ = other.outputScales_; + bitField0_ = (bitField0_ & ~0x00080000); + } else { + ensureOutputScalesIsMutable(); + outputScales_.addAll(other.outputScales_); + } + onChanged(); + } + } else { + if (!other.outputScales_.isEmpty()) { + if (outputScalesBuilder_.isEmpty()) { + outputScalesBuilder_.dispose(); + outputScalesBuilder_ = null; + outputScales_ = other.outputScales_; + bitField0_ = (bitField0_ & ~0x00080000); + outputScalesBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? + getOutputScalesFieldBuilder() : null; + } else { + outputScalesBuilder_.addAllMessages(other.outputScales_); + } + } + } + if (other.getIsMklInt8Enabled() != false) { + setIsMklInt8Enabled(other.getIsMklInt8Enabled()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -5041,6 +5407,564 @@ public com.intel.analytics.bigdl.serialization.Bigdl.BigDLTensor.Builder addPara } return parametersBuilder_; } + + private int inputDimMasks_ ; + /** + * int32 inputDimMasks = 17; + */ + public int getInputDimMasks() { + return inputDimMasks_; + } + /** + * int32 inputDimMasks = 17; + */ + public Builder setInputDimMasks(int value) { + + inputDimMasks_ = value; + onChanged(); + return this; + } + /** + * int32 inputDimMasks = 17; + */ + public Builder clearInputDimMasks() { + + inputDimMasks_ = 0; + onChanged(); + return this; + } + + private java.util.List inputScales_ = + java.util.Collections.emptyList(); + private void ensureInputScalesIsMutable() { + if (!((bitField0_ & 0x00020000) == 0x00020000)) { + inputScales_ = new java.util.ArrayList(inputScales_); + bitField0_ |= 0x00020000; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> inputScalesBuilder_; + + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public java.util.List getInputScalesList() { + if (inputScalesBuilder_ == null) { + return java.util.Collections.unmodifiableList(inputScales_); + } else { + return inputScalesBuilder_.getMessageList(); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public int getInputScalesCount() { + if (inputScalesBuilder_ == null) { + return inputScales_.size(); + } else { + return inputScalesBuilder_.getCount(); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getInputScales(int index) { + if (inputScalesBuilder_ == null) { + return inputScales_.get(index); + } else { + return inputScalesBuilder_.getMessage(index); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder setInputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (inputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureInputScalesIsMutable(); + inputScales_.set(index, value); + onChanged(); + } else { + inputScalesBuilder_.setMessage(index, value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder setInputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (inputScalesBuilder_ == null) { + ensureInputScalesIsMutable(); + inputScales_.set(index, builderForValue.build()); + onChanged(); + } else { + inputScalesBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder addInputScales(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (inputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureInputScalesIsMutable(); + inputScales_.add(value); + onChanged(); + } else { + inputScalesBuilder_.addMessage(value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder addInputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (inputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureInputScalesIsMutable(); + inputScales_.add(index, value); + onChanged(); + } else { + inputScalesBuilder_.addMessage(index, value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder addInputScales( + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (inputScalesBuilder_ == null) { + ensureInputScalesIsMutable(); + inputScales_.add(builderForValue.build()); + onChanged(); + } else { + inputScalesBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder addInputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (inputScalesBuilder_ == null) { + ensureInputScalesIsMutable(); + inputScales_.add(index, builderForValue.build()); + onChanged(); + } else { + inputScalesBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder addAllInputScales( + java.lang.Iterable values) { + if (inputScalesBuilder_ == null) { + ensureInputScalesIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, inputScales_); + onChanged(); + } else { + inputScalesBuilder_.addAllMessages(values); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder clearInputScales() { + if (inputScalesBuilder_ == null) { + inputScales_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00020000); + onChanged(); + } else { + inputScalesBuilder_.clear(); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public Builder removeInputScales(int index) { + if (inputScalesBuilder_ == null) { + ensureInputScalesIsMutable(); + inputScales_.remove(index); + onChanged(); + } else { + inputScalesBuilder_.remove(index); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder getInputScalesBuilder( + int index) { + return getInputScalesFieldBuilder().getBuilder(index); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getInputScalesOrBuilder( + int index) { + if (inputScalesBuilder_ == null) { + return inputScales_.get(index); } else { + return inputScalesBuilder_.getMessageOrBuilder(index); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public java.util.List + getInputScalesOrBuilderList() { + if (inputScalesBuilder_ != null) { + return inputScalesBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(inputScales_); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addInputScalesBuilder() { + return getInputScalesFieldBuilder().addBuilder( + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance()); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addInputScalesBuilder( + int index) { + return getInputScalesFieldBuilder().addBuilder( + index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance()); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue inputScales = 18; + */ + public java.util.List + getInputScalesBuilderList() { + return getInputScalesFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> + getInputScalesFieldBuilder() { + if (inputScalesBuilder_ == null) { + inputScalesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>( + inputScales_, + ((bitField0_ & 0x00020000) == 0x00020000), + getParentForChildren(), + isClean()); + inputScales_ = null; + } + return inputScalesBuilder_; + } + + private int outputDimMasks_ ; + /** + * int32 outputDimMasks = 19; + */ + public int getOutputDimMasks() { + return outputDimMasks_; + } + /** + * int32 outputDimMasks = 19; + */ + public Builder setOutputDimMasks(int value) { + + outputDimMasks_ = value; + onChanged(); + return this; + } + /** + * int32 outputDimMasks = 19; + */ + public Builder clearOutputDimMasks() { + + outputDimMasks_ = 0; + onChanged(); + return this; + } + + private java.util.List outputScales_ = + java.util.Collections.emptyList(); + private void ensureOutputScalesIsMutable() { + if (!((bitField0_ & 0x00080000) == 0x00080000)) { + outputScales_ = new java.util.ArrayList(outputScales_); + bitField0_ |= 0x00080000; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> outputScalesBuilder_; + + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public java.util.List getOutputScalesList() { + if (outputScalesBuilder_ == null) { + return java.util.Collections.unmodifiableList(outputScales_); + } else { + return outputScalesBuilder_.getMessageList(); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public int getOutputScalesCount() { + if (outputScalesBuilder_ == null) { + return outputScales_.size(); + } else { + return outputScalesBuilder_.getCount(); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue getOutputScales(int index) { + if (outputScalesBuilder_ == null) { + return outputScales_.get(index); + } else { + return outputScalesBuilder_.getMessage(index); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder setOutputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (outputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureOutputScalesIsMutable(); + outputScales_.set(index, value); + onChanged(); + } else { + outputScalesBuilder_.setMessage(index, value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder setOutputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (outputScalesBuilder_ == null) { + ensureOutputScalesIsMutable(); + outputScales_.set(index, builderForValue.build()); + onChanged(); + } else { + outputScalesBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder addOutputScales(com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (outputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureOutputScalesIsMutable(); + outputScales_.add(value); + onChanged(); + } else { + outputScalesBuilder_.addMessage(value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder addOutputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue value) { + if (outputScalesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureOutputScalesIsMutable(); + outputScales_.add(index, value); + onChanged(); + } else { + outputScalesBuilder_.addMessage(index, value); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder addOutputScales( + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (outputScalesBuilder_ == null) { + ensureOutputScalesIsMutable(); + outputScales_.add(builderForValue.build()); + onChanged(); + } else { + outputScalesBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder addOutputScales( + int index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder builderForValue) { + if (outputScalesBuilder_ == null) { + ensureOutputScalesIsMutable(); + outputScales_.add(index, builderForValue.build()); + onChanged(); + } else { + outputScalesBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder addAllOutputScales( + java.lang.Iterable values) { + if (outputScalesBuilder_ == null) { + ensureOutputScalesIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, outputScales_); + onChanged(); + } else { + outputScalesBuilder_.addAllMessages(values); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder clearOutputScales() { + if (outputScalesBuilder_ == null) { + outputScales_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00080000); + onChanged(); + } else { + outputScalesBuilder_.clear(); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public Builder removeOutputScales(int index) { + if (outputScalesBuilder_ == null) { + ensureOutputScalesIsMutable(); + outputScales_.remove(index); + onChanged(); + } else { + outputScalesBuilder_.remove(index); + } + return this; + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder getOutputScalesBuilder( + int index) { + return getOutputScalesFieldBuilder().getBuilder(index); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder getOutputScalesOrBuilder( + int index) { + if (outputScalesBuilder_ == null) { + return outputScales_.get(index); } else { + return outputScalesBuilder_.getMessageOrBuilder(index); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public java.util.List + getOutputScalesOrBuilderList() { + if (outputScalesBuilder_ != null) { + return outputScalesBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(outputScales_); + } + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addOutputScalesBuilder() { + return getOutputScalesFieldBuilder().addBuilder( + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance()); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder addOutputScalesBuilder( + int index) { + return getOutputScalesFieldBuilder().addBuilder( + index, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.getDefaultInstance()); + } + /** + * repeated .com.intel.analytics.bigdl.serialization.AttrValue outputScales = 20; + */ + public java.util.List + getOutputScalesBuilderList() { + return getOutputScalesFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder> + getOutputScalesFieldBuilder() { + if (outputScalesBuilder_ == null) { + outputScalesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + com.intel.analytics.bigdl.serialization.Bigdl.AttrValue, com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.Builder, com.intel.analytics.bigdl.serialization.Bigdl.AttrValueOrBuilder>( + outputScales_, + ((bitField0_ & 0x00080000) == 0x00080000), + getParentForChildren(), + isClean()); + outputScales_ = null; + } + return outputScalesBuilder_; + } + + private boolean isMklInt8Enabled_ ; + /** + * bool isMklInt8Enabled = 21; + */ + public boolean getIsMklInt8Enabled() { + return isMklInt8Enabled_; + } + /** + * bool isMklInt8Enabled = 21; + */ + public Builder setIsMklInt8Enabled(boolean value) { + + isMklInt8Enabled_ = value; + onChanged(); + return this; + } + /** + * bool isMklInt8Enabled = 21; + */ + public Builder clearIsMklInt8Enabled() { + + isMklInt8Enabled_ = false; + onChanged(); + return this; + } public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { return super.setUnknownFieldsProto3(unknownFields); @@ -22144,7 +23068,7 @@ public com.intel.analytics.bigdl.serialization.Bigdl.Shape getDefaultInstanceFor java.lang.String[] descriptorData = { "\n\013bigdl.proto\022\'com.intel.analytics.bigdl" + ".serialization\032\031google/protobuf/any.prot" + - "o\"\206\006\n\013BigDLModule\022\014\n\004name\030\001 \001(\t\022H\n\nsubMo" + + "o\"\342\007\n\013BigDLModule\022\014\n\004name\030\001 \001(\t\022H\n\nsubMo" + "dules\030\002 \003(\01324.com.intel.analytics.bigdl." + "serialization.BigDLModule\022D\n\006weight\030\003 \001(" + "\01324.com.intel.analytics.bigdl.serializat" + @@ -22161,105 +23085,110 @@ public com.intel.analytics.bigdl.serialization.Bigdl.Shape getDefaultInstanceFor "igdl.serialization.Shape\022\025\n\rhasParameter" + "s\030\017 \001(\010\022H\n\nparameters\030\020 \003(\01324.com.intel." + "analytics.bigdl.serialization.BigDLTenso" + - "r\032_\n\tAttrEntry\022\013\n\003key\030\001 \001(\t\022A\n\005value\030\002 \001", - "(\01322.com.intel.analytics.bigdl.serializa" + - "tion.AttrValue:\0028\001\"g\n\nInitMethod\022K\n\nmeth" + - "odType\030\001 \001(\01627.com.intel.analytics.bigdl" + - ".serialization.InitMethodType\022\014\n\004data\030\002 " + - "\003(\001\"\326\002\n\013BigDLTensor\022C\n\010datatype\030\001 \001(\01621." + - "com.intel.analytics.bigdl.serialization." + - "DataType\022\014\n\004size\030\002 \003(\005\022\016\n\006stride\030\003 \003(\005\022\016" + - "\n\006offset\030\004 \001(\005\022\021\n\tdimension\030\005 \001(\005\022\021\n\tnEl" + - "ements\030\006 \001(\005\022\020\n\010isScalar\030\007 \001(\010\022G\n\007storag" + - "e\030\010 \001(\01326.com.intel.analytics.bigdl.seri", - "alization.TensorStorage\022\n\n\002id\030\t \001(\005\022G\n\nt" + - "ensorType\030\n \001(\01623.com.intel.analytics.bi" + - "gdl.serialization.TensorType\"\352\001\n\rTensorS" + - "torage\022C\n\010datatype\030\001 \001(\01621.com.intel.ana" + - "lytics.bigdl.serialization.DataType\022\022\n\nf" + - "loat_data\030\002 \003(\002\022\023\n\013double_data\030\003 \003(\001\022\021\n\t" + - "bool_data\030\004 \003(\010\022\023\n\013string_data\030\005 \003(\t\022\020\n\010" + - "int_data\030\006 \003(\005\022\021\n\tlong_data\030\007 \003(\003\022\022\n\nbyt" + - "es_data\030\010 \003(\014\022\n\n\002id\030\t \001(\005\"u\n\013Regularizer" + - "\022Q\n\017regularizerType\030\001 \001(\01628.com.intel.an", - "alytics.bigdl.serialization.RegularizerT" + - "ype\022\023\n\013regularData\030\002 \003(\001\"\224\016\n\tAttrValue\022C" + - "\n\010dataType\030\001 \001(\01621.com.intel.analytics.b" + - "igdl.serialization.DataType\022\017\n\007subType\030\002" + - " \001(\t\022\024\n\nint32Value\030\003 \001(\005H\000\022\024\n\nint64Value" + - "\030\004 \001(\003H\000\022\024\n\nfloatValue\030\005 \001(\002H\000\022\025\n\013double" + - "Value\030\006 \001(\001H\000\022\025\n\013stringValue\030\007 \001(\tH\000\022\023\n\t" + - "boolValue\030\010 \001(\010H\000\022P\n\020regularizerValue\030\t " + - "\001(\01324.com.intel.analytics.bigdl.serializ" + - "ation.RegularizerH\000\022K\n\013tensorValue\030\n \001(\013", - "24.com.intel.analytics.bigdl.serializati" + - "on.BigDLTensorH\000\022Q\n\023variableFormatValue\030" + - "\013 \001(\01622.com.intel.analytics.bigdl.serial" + - "ization.VarFormatH\000\022N\n\017initMethodValue\030\014" + - " \001(\01323.com.intel.analytics.bigdl.seriali" + - "zation.InitMethodH\000\022P\n\020bigDLModuleValue\030" + - "\r \001(\01324.com.intel.analytics.bigdl.serial" + - "ization.BigDLModuleH\000\022R\n\021nameAttrListVal" + - "ue\030\016 \001(\01325.com.intel.analytics.bigdl.ser" + - "ialization.NameAttrListH\000\022S\n\narrayValue\030", - "\017 \001(\0132=.com.intel.analytics.bigdl.serial" + - "ization.AttrValue.ArrayValueH\000\022S\n\017dataFo" + - "rmatValue\030\020 \001(\01628.com.intel.analytics.bi" + - "gdl.serialization.InputDataFormatH\000\022+\n\013c" + - "ustomValue\030\021 \001(\0132\024.google.protobuf.AnyH\000" + - "\022?\n\005shape\030\022 \001(\0132..com.intel.analytics.bi" + - "gdl.serialization.ShapeH\000\032\242\006\n\nArrayValue" + - "\022\014\n\004size\030\001 \001(\005\022C\n\010datatype\030\002 \001(\01621.com.i" + - "ntel.analytics.bigdl.serialization.DataT" + - "ype\022\013\n\003i32\030\003 \003(\005\022\013\n\003i64\030\004 \003(\003\022\013\n\003flt\030\005 \003", - "(\002\022\013\n\003dbl\030\006 \003(\001\022\013\n\003str\030\007 \003(\t\022\017\n\007boolean\030" + - "\010 \003(\010\022I\n\013Regularizer\030\t \003(\01324.com.intel.a" + - "nalytics.bigdl.serialization.Regularizer" + - "\022D\n\006tensor\030\n \003(\01324.com.intel.analytics.b" + - "igdl.serialization.BigDLTensor\022J\n\016variab" + - "leFormat\030\013 \003(\01622.com.intel.analytics.big" + - "dl.serialization.VarFormat\022G\n\ninitMethod" + - "\030\014 \003(\01323.com.intel.analytics.bigdl.seria" + - "lization.InitMethod\022I\n\013bigDLModule\030\r \003(\013" + - "24.com.intel.analytics.bigdl.serializati", - "on.BigDLModule\022K\n\014nameAttrList\030\016 \003(\01325.c" + - "om.intel.analytics.bigdl.serialization.N" + - "ameAttrList\022L\n\ndataFormat\030\017 \003(\01628.com.in" + - "tel.analytics.bigdl.serialization.InputD" + - "ataFormat\022$\n\006custom\030\020 \003(\0132\024.google.proto" + - "buf.Any\022=\n\005shape\030\021 \003(\0132..com.intel.analy" + - "tics.bigdl.serialization.ShapeB\007\n\005value\"" + - "\314\001\n\014NameAttrList\022\014\n\004name\030\001 \001(\t\022M\n\004attr\030\002" + - " \003(\0132?.com.intel.analytics.bigdl.seriali" + - "zation.NameAttrList.AttrEntry\032_\n\tAttrEnt", - "ry\022\013\n\003key\030\001 \001(\t\022A\n\005value\030\002 \001(\01322.com.int" + - "el.analytics.bigdl.serialization.AttrVal" + - "ue:\0028\001\"\332\001\n\005Shape\022K\n\tshapeType\030\001 \001(\01628.co" + + "r\022\025\n\rinputDimMasks\030\021 \001(\005\022G\n\013inputScales\030", + "\022 \003(\01322.com.intel.analytics.bigdl.serial" + + "ization.AttrValue\022\026\n\016outputDimMasks\030\023 \001(" + + "\005\022H\n\014outputScales\030\024 \003(\01322.com.intel.anal" + + "ytics.bigdl.serialization.AttrValue\022\030\n\020i" + + "sMklInt8Enabled\030\025 \001(\010\032_\n\tAttrEntry\022\013\n\003ke" + + "y\030\001 \001(\t\022A\n\005value\030\002 \001(\01322.com.intel.analy" + + "tics.bigdl.serialization.AttrValue:\0028\001\"g" + + "\n\nInitMethod\022K\n\nmethodType\030\001 \001(\01627.com.i" + + "ntel.analytics.bigdl.serialization.InitM" + + "ethodType\022\014\n\004data\030\002 \003(\001\"\326\002\n\013BigDLTensor\022", + "C\n\010datatype\030\001 \001(\01621.com.intel.analytics." + + "bigdl.serialization.DataType\022\014\n\004size\030\002 \003" + + "(\005\022\016\n\006stride\030\003 \003(\005\022\016\n\006offset\030\004 \001(\005\022\021\n\tdi" + + "mension\030\005 \001(\005\022\021\n\tnElements\030\006 \001(\005\022\020\n\010isSc" + + "alar\030\007 \001(\010\022G\n\007storage\030\010 \001(\01326.com.intel." + + "analytics.bigdl.serialization.TensorStor" + + "age\022\n\n\002id\030\t \001(\005\022G\n\ntensorType\030\n \001(\01623.co" + + "m.intel.analytics.bigdl.serialization.Te" + + "nsorType\"\352\001\n\rTensorStorage\022C\n\010datatype\030\001" + + " \001(\01621.com.intel.analytics.bigdl.seriali", + "zation.DataType\022\022\n\nfloat_data\030\002 \003(\002\022\023\n\013d" + + "ouble_data\030\003 \003(\001\022\021\n\tbool_data\030\004 \003(\010\022\023\n\013s" + + "tring_data\030\005 \003(\t\022\020\n\010int_data\030\006 \003(\005\022\021\n\tlo" + + "ng_data\030\007 \003(\003\022\022\n\nbytes_data\030\010 \003(\014\022\n\n\002id\030" + + "\t \001(\005\"u\n\013Regularizer\022Q\n\017regularizerType\030" + + "\001 \001(\01628.com.intel.analytics.bigdl.serial" + + "ization.RegularizerType\022\023\n\013regularData\030\002" + + " \003(\001\"\224\016\n\tAttrValue\022C\n\010dataType\030\001 \001(\01621.c" + + "om.intel.analytics.bigdl.serialization.D" + + "ataType\022\017\n\007subType\030\002 \001(\t\022\024\n\nint32Value\030\003", + " \001(\005H\000\022\024\n\nint64Value\030\004 \001(\003H\000\022\024\n\nfloatVal" + + "ue\030\005 \001(\002H\000\022\025\n\013doubleValue\030\006 \001(\001H\000\022\025\n\013str" + + "ingValue\030\007 \001(\tH\000\022\023\n\tboolValue\030\010 \001(\010H\000\022P\n" + + "\020regularizerValue\030\t \001(\01324.com.intel.anal" + + "ytics.bigdl.serialization.RegularizerH\000\022" + + "K\n\013tensorValue\030\n \001(\01324.com.intel.analyti" + + "cs.bigdl.serialization.BigDLTensorH\000\022Q\n\023" + + "variableFormatValue\030\013 \001(\01622.com.intel.an" + + "alytics.bigdl.serialization.VarFormatH\000\022" + + "N\n\017initMethodValue\030\014 \001(\01323.com.intel.ana", + "lytics.bigdl.serialization.InitMethodH\000\022" + + "P\n\020bigDLModuleValue\030\r \001(\01324.com.intel.an" + + "alytics.bigdl.serialization.BigDLModuleH" + + "\000\022R\n\021nameAttrListValue\030\016 \001(\01325.com.intel" + + ".analytics.bigdl.serialization.NameAttrL" + + "istH\000\022S\n\narrayValue\030\017 \001(\0132=.com.intel.an" + + "alytics.bigdl.serialization.AttrValue.Ar" + + "rayValueH\000\022S\n\017dataFormatValue\030\020 \001(\01628.co" + + "m.intel.analytics.bigdl.serialization.In" + + "putDataFormatH\000\022+\n\013customValue\030\021 \001(\0132\024.g", + "oogle.protobuf.AnyH\000\022?\n\005shape\030\022 \001(\0132..co" + "m.intel.analytics.bigdl.serialization.Sh" + - "ape.ShapeType\022\r\n\005ssize\030\002 \001(\005\022\022\n\nshapeVal" + - "ue\030\003 \003(\005\022=\n\005shape\030\004 \003(\0132..com.intel.anal" + - "ytics.bigdl.serialization.Shape\"\"\n\tShape" + - "Type\022\n\n\006SINGLE\020\000\022\t\n\005MULTI\020\001*\260\001\n\tVarForma" + - "t\022\020\n\014EMPTY_FORMAT\020\000\022\013\n\007DEFAULT\020\001\022\t\n\005ONE_" + - "D\020\002\022\n\n\006IN_OUT\020\003\022\n\n\006OUT_IN\020\004\022\020\n\014IN_OUT_KW", - "_KH\020\005\022\020\n\014OUT_IN_KW_KH\020\006\022\023\n\017GP_OUT_IN_KW_" + - "KH\020\007\022\023\n\017GP_IN_OUT_KW_KH\020\010\022\023\n\017OUT_IN_KT_K" + - "H_KW\020\t*\253\001\n\016InitMethodType\022\030\n\024EMPTY_INITI" + - "ALIZATION\020\000\022\022\n\016RANDOM_UNIFORM\020\001\022\030\n\024RANDO" + - "M_UNIFORM_PARAM\020\002\022\021\n\rRANDOM_NORMAL\020\003\022\t\n\005" + - "ZEROS\020\004\022\010\n\004ONES\020\005\022\t\n\005CONST\020\006\022\n\n\006XAVIER\020\007" + - "\022\022\n\016BILINEARFILLER\020\010*L\n\017RegularizerType\022" + - "\023\n\017L1L2Regularizer\020\000\022\021\n\rL1Regularizer\020\001\022" + - "\021\n\rL2Regularizer\020\002*%\n\017InputDataFormat\022\010\n" + - "\004NCHW\020\000\022\010\n\004NHWC\020\001*\"\n\nTensorType\022\t\n\005DENSE", - "\020\000\022\t\n\005QUANT\020\001*\210\002\n\010DataType\022\t\n\005INT32\020\000\022\t\n" + - "\005INT64\020\001\022\t\n\005FLOAT\020\002\022\n\n\006DOUBLE\020\003\022\n\n\006STRIN" + - "G\020\004\022\010\n\004BOOL\020\005\022\010\n\004CHAR\020\006\022\t\n\005SHORT\020\007\022\t\n\005BY" + - "TES\020\010\022\017\n\013REGULARIZER\020\t\022\n\n\006TENSOR\020\n\022\023\n\017VA" + - "RIABLE_FORMAT\020\013\022\016\n\nINITMETHOD\020\014\022\n\n\006MODUL" + - "E\020\r\022\022\n\016NAME_ATTR_LIST\020\016\022\017\n\013ARRAY_VALUE\020\017" + - "\022\017\n\013DATA_FORMAT\020\020\022\n\n\006CUSTOM\020\021\022\t\n\005SHAPE\020\022" + - "b\006proto3" + "apeH\000\032\242\006\n\nArrayValue\022\014\n\004size\030\001 \001(\005\022C\n\010da" + + "tatype\030\002 \001(\01621.com.intel.analytics.bigdl" + + ".serialization.DataType\022\013\n\003i32\030\003 \003(\005\022\013\n\003" + + "i64\030\004 \003(\003\022\013\n\003flt\030\005 \003(\002\022\013\n\003dbl\030\006 \003(\001\022\013\n\003s" + + "tr\030\007 \003(\t\022\017\n\007boolean\030\010 \003(\010\022I\n\013Regularizer" + + "\030\t \003(\01324.com.intel.analytics.bigdl.seria" + + "lization.Regularizer\022D\n\006tensor\030\n \003(\01324.c" + + "om.intel.analytics.bigdl.serialization.B", + "igDLTensor\022J\n\016variableFormat\030\013 \003(\01622.com" + + ".intel.analytics.bigdl.serialization.Var" + + "Format\022G\n\ninitMethod\030\014 \003(\01323.com.intel.a" + + "nalytics.bigdl.serialization.InitMethod\022" + + "I\n\013bigDLModule\030\r \003(\01324.com.intel.analyti" + + "cs.bigdl.serialization.BigDLModule\022K\n\014na" + + "meAttrList\030\016 \003(\01325.com.intel.analytics.b" + + "igdl.serialization.NameAttrList\022L\n\ndataF" + + "ormat\030\017 \003(\01628.com.intel.analytics.bigdl." + + "serialization.InputDataFormat\022$\n\006custom\030", + "\020 \003(\0132\024.google.protobuf.Any\022=\n\005shape\030\021 \003" + + "(\0132..com.intel.analytics.bigdl.serializa" + + "tion.ShapeB\007\n\005value\"\314\001\n\014NameAttrList\022\014\n\004" + + "name\030\001 \001(\t\022M\n\004attr\030\002 \003(\0132?.com.intel.ana" + + "lytics.bigdl.serialization.NameAttrList." + + "AttrEntry\032_\n\tAttrEntry\022\013\n\003key\030\001 \001(\t\022A\n\005v" + + "alue\030\002 \001(\01322.com.intel.analytics.bigdl.s" + + "erialization.AttrValue:\0028\001\"\332\001\n\005Shape\022K\n\t" + + "shapeType\030\001 \001(\01628.com.intel.analytics.bi" + + "gdl.serialization.Shape.ShapeType\022\r\n\005ssi", + "ze\030\002 \001(\005\022\022\n\nshapeValue\030\003 \003(\005\022=\n\005shape\030\004 " + + "\003(\0132..com.intel.analytics.bigdl.serializ" + + "ation.Shape\"\"\n\tShapeType\022\n\n\006SINGLE\020\000\022\t\n\005" + + "MULTI\020\001*\260\001\n\tVarFormat\022\020\n\014EMPTY_FORMAT\020\000\022" + + "\013\n\007DEFAULT\020\001\022\t\n\005ONE_D\020\002\022\n\n\006IN_OUT\020\003\022\n\n\006O" + + "UT_IN\020\004\022\020\n\014IN_OUT_KW_KH\020\005\022\020\n\014OUT_IN_KW_K" + + "H\020\006\022\023\n\017GP_OUT_IN_KW_KH\020\007\022\023\n\017GP_IN_OUT_KW" + + "_KH\020\010\022\023\n\017OUT_IN_KT_KH_KW\020\t*\253\001\n\016InitMetho" + + "dType\022\030\n\024EMPTY_INITIALIZATION\020\000\022\022\n\016RANDO" + + "M_UNIFORM\020\001\022\030\n\024RANDOM_UNIFORM_PARAM\020\002\022\021\n", + "\rRANDOM_NORMAL\020\003\022\t\n\005ZEROS\020\004\022\010\n\004ONES\020\005\022\t\n" + + "\005CONST\020\006\022\n\n\006XAVIER\020\007\022\022\n\016BILINEARFILLER\020\010" + + "*L\n\017RegularizerType\022\023\n\017L1L2Regularizer\020\000" + + "\022\021\n\rL1Regularizer\020\001\022\021\n\rL2Regularizer\020\002*%" + + "\n\017InputDataFormat\022\010\n\004NCHW\020\000\022\010\n\004NHWC\020\001*\"\n" + + "\nTensorType\022\t\n\005DENSE\020\000\022\t\n\005QUANT\020\001*\210\002\n\010Da" + + "taType\022\t\n\005INT32\020\000\022\t\n\005INT64\020\001\022\t\n\005FLOAT\020\002\022" + + "\n\n\006DOUBLE\020\003\022\n\n\006STRING\020\004\022\010\n\004BOOL\020\005\022\010\n\004CHA" + + "R\020\006\022\t\n\005SHORT\020\007\022\t\n\005BYTES\020\010\022\017\n\013REGULARIZER" + + "\020\t\022\n\n\006TENSOR\020\n\022\023\n\017VARIABLE_FORMAT\020\013\022\016\n\nI", + "NITMETHOD\020\014\022\n\n\006MODULE\020\r\022\022\n\016NAME_ATTR_LIS" + + "T\020\016\022\017\n\013ARRAY_VALUE\020\017\022\017\n\013DATA_FORMAT\020\020\022\n\n" + + "\006CUSTOM\020\021\022\t\n\005SHAPE\020\022b\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -22279,7 +23208,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_descriptor, - new java.lang.String[] { "Name", "SubModules", "Weight", "Bias", "PreModules", "NextModules", "ModuleType", "Attr", "Version", "Train", "NamePostfix", "Id", "InputShape", "OutputShape", "HasParameters", "Parameters", }); + new java.lang.String[] { "Name", "SubModules", "Weight", "Bias", "PreModules", "NextModules", "ModuleType", "Attr", "Version", "Train", "NamePostfix", "Id", "InputShape", "OutputShape", "HasParameters", "Parameters", "InputDimMasks", "InputScales", "OutputDimMasks", "OutputScales", "IsMklInt8Enabled", }); internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_AttrEntry_descriptor = internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_descriptor.getNestedTypes().get(0); internal_static_com_intel_analytics_bigdl_serialization_BigDLModule_AttrEntry_fieldAccessorTable = new diff --git a/spark/dl/src/main/resources/serialization/bigdl.proto b/spark/dl/src/main/resources/serialization/bigdl.proto index 9f0e276304d..5e53c6e32c3 100644 --- a/spark/dl/src/main/resources/serialization/bigdl.proto +++ b/spark/dl/src/main/resources/serialization/bigdl.proto @@ -21,6 +21,11 @@ message BigDLModule Shape outputShape = 14; //output shape bool hasParameters = 15; // indicator if module has parameters repeated BigDLTensor parameters = 16; // parameters, e.g., weight and bias + int32 inputDimMasks = 17; + repeated AttrValue inputScales = 18; + int32 outputDimMasks = 19; + repeated AttrValue outputScales = 20; + bool isMklInt8Enabled = 21; } enum VarFormat { EMPTY_FORMAT = 0; diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala index 3dc2e4832b0..8cccc7d0123 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/abstractnn/AbstractModule.scala @@ -59,8 +59,6 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag, implicit ev: TensorNumeric[T]) extends Serializable with InferShape{ // ================================= Public APIs ============================================= - - /** * The cached output. So we don't compute it again when need it */ @@ -1183,3 +1181,5 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag, } } + + diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala index 0cca0ba0941..9bfbaeba118 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Linear.scala @@ -32,7 +32,8 @@ class Linear( private val initWeight: Tensor[Float] = null, private val initBias: Tensor[Float] = null, private val initGradWeight: Tensor[Float] = null, - private val initGradBias: Tensor[Float] = null) extends MklDnnLayer with Initializable { + private val initGradBias: Tensor[Float] = null +) extends MklDnnLayer with Initializable with MklInt8Convertible { private[mkldnn] val weight: Blob = new Blob(Array(outputSize, inputSize)) private[mkldnn] val bias: Blob = new Blob(Array(outputSize)) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala new file mode 100644 index 00000000000..b944870528c --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/MklInt8Convertible.scala @@ -0,0 +1,164 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.nn.mkldnn + +import scala.collection.mutable.ArrayBuffer + +/** +* Trait which provides MKL-DNN functionality to convert FP32 model to INT8 model +*/ +trait MklInt8Convertible { + // input dimension mask + protected var inDimMask: Int = 0 + // output dimension mask + protected var outDimMask: Int = 0 + // input scales + private[mkldnn] var inScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]] + // output scales + private[mkldnn] var outScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]] + + + /** + * Get dimension mask of input + * @return inDimMask field which stores value of input dimension mask + */ + def getInputDimMask(): Int = { + inDimMask + } + + /** + * Set dimension mask of input + * @param mask value of input dimension mask to be set + * @return Unit + */ + def setInputDimMask(mask: Int) : Unit = { + inDimMask = mask + } + + /** + * Get dimension mask of output + * @return outDimMask field which stores value of output dimension mask + */ + def getOutputDimMask(): Int = { + outDimMask + } + + /** + * Set dimension mask of output + * @param mask value of output dimension mask to be set + * @return Unit + */ + def setOutputDimMask(mask: Int): Unit = { + outDimMask = mask + } + + /** + * Get input scales + * @return field which stores value of input scales + */ + def getInputScales(): Array[Array[Float]] = { + inScalesBuffer.toArray + } + + /** + * Set input scales + * Clear existing buffer of input scales, and place updated scales into the cleared buffer + * @param inScales value of input scales to be set + * @return Unit + */ + def setInputScales(inScales: Array[Array[Float]]): Unit = { + inScalesBuffer.clear() + inScales.foreach(appendInputScales) + } + + /** + * Get output scales + * @return field which stores value of output scales + */ + def getOutputScales(): Array[Array[Float]] = { + outScalesBuffer.toArray + } + + /** + * Set output scales + * Clear existing buffer of output scales, and place updated scales into the cleared buffer + * @param outScales value of output scales to be set + * @return Unit + */ + def setOutputScales(outScales: Array[Array[Float]]): Unit = { + outScalesBuffer.clear() + outScales.foreach(appendOutputScales) + } + + /** + * Append a scale, an array of float, into input scales buffer + * @param scale value of an input scale to be appended + * @return Unit + */ + private def appendInputScales(scale: Array[Float]): Unit = { + inScalesBuffer.append(scale) + } + + /** + * Append a scale, an array of float, into output scales buffer + * @param scale value of an output scale to be appended + * @return Unit + */ + private def appendOutputScales(scale: Array[Float]): Unit = { + outScalesBuffer.append(scale) + } + + /** + * Update input scales at specific index with provided new scale + * @param scale the new scale + * @param index the index of which the scale need to be updated + * @return Unit + */ + def updateInputScales(scale: Array[Float], index: Int): Unit = { + updateScalesHelper(inScalesBuffer, scale, index) + } + + /** + * Update output scales at specific index with provided new scale + * @param scale the new scale + * @param index the index of which the scale need to be updated + * @return Unit + */ + def updateOutputSclaes(scale: Array[Float], index: Int): Unit = { + updateScalesHelper(outScalesBuffer, scale, index) + } + + + /** + * Scales update helper. Replace scale at specific index with provided new scale + * @param scales the scales arrayBuffer to be updated + * @param scale the new scale + * @param index the index of which the scale need to be updated + * @return Unit + */ + private def updateScalesHelper(scales: ArrayBuffer[Array[Float]], + scale: Array[Float], index: Int): Unit = { + if (scales.length - 1 < index) { + scales.append(scale) + } + + scales(index).indices.foreach(i => + if (scale(i) > scales(index)(i)) { + scales(index)(i) = scale(i) + }) + } + +} diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala index 56bb2ea1686..61442dfabb6 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializable.scala @@ -15,12 +15,15 @@ */ package com.intel.analytics.bigdl.utils.serializer + import java.lang.reflect.Field +import scala.collection.JavaConverters._ import com.intel.analytics.bigdl.nn.Container import scala.collection.JavaConverters._ import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} +import com.intel.analytics.bigdl.nn.mkldnn.MklInt8Convertible import com.intel.analytics.bigdl.serialization.Bigdl.AttrValue.ArrayValue import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric @@ -29,7 +32,6 @@ import com.intel.analytics.bigdl.utils.serializer.converters.{DataConverter, Sha import com.intel.analytics.bigdl.utils.serializer.ModuleSerializer._ import com.intel.analytics.bigdl.serialization.Bigdl._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.reflect.runtime.universe @@ -85,7 +87,6 @@ trait ModuleSerializable extends Loadable with Savable{ // step2 : module specific logic to load module, either default, cell, container or graph val moduleId = context.bigdlModule.getId - val storages = context.storages val module = if (storages.contains(moduleId)) { @@ -121,6 +122,7 @@ trait ModuleSerializable extends Loadable with Savable{ val constructorFullParams = constructorMirror.symbol.paramss val args = new Array[Object](constructorFullParams.map(_.size).sum) var i = 0 + constructorFullParams.foreach(map => { map.foreach(param => { val name = param.name.decodedName.toString @@ -185,8 +187,10 @@ trait ModuleSerializable extends Loadable with Savable{ } getLock.synchronized { + // step 4 : set data types (ClassTag and TensorNumric) setDataTypes(context, bigDLModelBuilder) + // step 5 : apply module specific logic to create module doSerializeModule(context, bigDLModelBuilder) } @@ -218,6 +222,7 @@ trait ModuleSerializable extends Loadable with Savable{ val fullParams = getCostructorMirror(cls).symbol.paramss val constructorParams = fullParams(0) constructorParams.foreach(param => { + val paramName = param.name.decodedName.toString var ptype = param.typeSignature val attrBuilder = AttrValue.newBuilder @@ -232,28 +237,36 @@ trait ModuleSerializable extends Loadable with Savable{ field.setAccessible(true) val fieldValue = field.get(module) DataConverter.setAttributeValue(context, attrBuilder, fieldValue, ptype) - bigDLModelBuilder.putAttr(paramName, attrBuilder.build) }) } + /* + * Re-create BigDL module by deserializing protobuf context. + * @param DeserializeContext: Deserialization context + * @param AbstractModule: The BigDL module to be re-created + * @return ModuleData: Tuple3 contains information of current module and modules adjacent to it + */ protected def createBigDLModule[T: ClassTag](context: DeserializeContext, module : AbstractModule[Activity, Activity, T]) - (implicit ev: TensorNumeric[T]) - : ModuleData[T] = { + (implicit ev: TensorNumeric[T]): ModuleData[T] = { val model = context.bigdlModule val preModules = model.getPreModulesList.asScala val nextModules = model.getNextModulesList.asScala val bigDLModule = ModuleData(module, preModules, nextModules) + if (model.getName != "") { module.setName(model.getName) } + module.setNamePostfix(model.getNamePostfix) + if (model.getTrain) { module.training() } else { module.evaluate() } + module.inputShapeValue = ShapeConverter.shapeToBigDL(context, model, "input") module.outputShapeValue = ShapeConverter.shapeToBigDL(context, model, "output") @@ -261,9 +274,23 @@ trait ModuleSerializable extends Loadable with Savable{ if (_copyWeightAndBias && context.bigdlModule.getSubModulesCount == 0) { copy2BigDL(context, bigDLModule) } + + // Load MKL-DNN INT8 attributes (scales&mask of input&output) into + // BigDL Module from protobuf definition if the MKL-DNN INT8 flag is ON + if (model.getIsMklInt8Enabled) { + loadMklInt8Attr(context, module.asInstanceOf[MklInt8Convertible]) + } + bigDLModule } + + /* + * Create BigDL model's protobuf definition by serializing BigDL Module object + * @param BigDLModule.Builder: BigDL model builder of protobuf definition + * @param SerializeContext: Serialized context of BigDL module + * @return SerializeResult: + */ protected def createSerializeBigDLModule[T: ClassTag]( modelBuilder : BigDLModule.Builder, context: SerializeContext[T])(implicit ev: TensorNumeric[T]) : SerializeResult = { @@ -288,6 +315,16 @@ trait ModuleSerializable extends Loadable with Savable{ if (_copyWeightAndBias && !module.isInstanceOf[Container[_, _, _]]) { copyFromBigDL(context, modelBuilder) } + + // Save MKL-DNN attributes (scales and masks) into model of protobuf definition if + // the module is with trait of MklInt8COnvertible, and set the MKL-DNN INT8 flag to true + if (module.module.isInstanceOf[MklInt8Convertible]) { + saveMklInt8Attr(context.moduleData.module.asInstanceOf[MklInt8Convertible], modelBuilder) + modelBuilder.setIsMklInt8Enabled(true) + } else { + modelBuilder.setIsMklInt8Enabled(false) + } + SerializeResult(modelBuilder, context.storages) } @@ -298,7 +335,6 @@ trait ModuleSerializable extends Loadable with Savable{ */ protected def copy2BigDL[T: ClassTag](context: DeserializeContext, module : ModuleData[T]) (implicit ev: TensorNumeric[T]): Unit = { - if (context.bigdlModule.getHasParameters) { copyParameters2BigDL(context, module) } else { @@ -356,12 +392,50 @@ trait ModuleSerializable extends Loadable with Savable{ } /** - * copy BigDL module data (weight and bias if exist) to BigDL Model to be persisted + * Deserialize MKL-DNN INT8 attributes from protobuf context + * and load them into BigDL Module object + * @param context deserialized context + * @param module bigDL Module with relationships + */ + private def loadMklInt8Attr[T: ClassTag](context: DeserializeContext, + module: MklInt8Convertible) + (implicit ev: TensorNumeric[T]): Unit = { + + val protobufModel = context.bigdlModule + + // Extract ArrayValue for each AttrValue, and then get FltList as input scales + val inputScales = protobufModel.getInputScalesList.iterator().asScala + .map(attrValueToFloatArray) + + // Extract ArrayValue for each AttrValue, and then get FltList as output scales + val outputScales = protobufModel.getOutputScalesList.iterator().asScala + .map(attrValueToFloatArray) + + module.setInputDimMask(protobufModel.getInputDimMasks) + module.setInputScales(inputScales.toArray) + module.setOutputDimMask(protobufModel.getOutputDimMasks) + module.setOutputScales(outputScales.toArray) + } + + /** + * Convert Attr Value object to Array of Float + * @param AttrValue + * @return Array[Float] + */ + def attrValueToFloatArray(attr: AttrValue): Array[Float] = { + attr.getArrayValue.getFltList.asScala.toArray.map(_.asInstanceOf[Float]) + } + + + + /** + * Copy BigDL module data (weight and bias if exist) into BigDL Model's protobuf definition * @param modelBuilder serialized module builder * @param context serialization context */ protected def copyFromBigDL[T: ClassTag](context : SerializeContext[T], modelBuilder : BigDLModule.Builder)(implicit ev : TensorNumeric[T]) : Unit = { + val parameters = context.moduleData.module.parameters if (parameters != null && parameters._1 != null) { modelBuilder.setHasParameters(true) @@ -373,6 +447,57 @@ trait ModuleSerializable extends Loadable with Savable{ } } + + /** + * Serialize and save MKL DNN INT8 attributes into BigDL Model of protobuf definition + * @param modelBuilder serialized module builder + * @param context serialization context + */ + protected def saveMklInt8Attr[T: ClassTag](module : MklInt8Convertible, + modelBuilder : BigDLModule.Builder) + (implicit ev : TensorNumeric[T]) : Unit = { + + + // Save scale and mask of input into BigDL model builder + val inputScales : Array[Array[Float]] = module.getInputScales() + val inputMasks : Int = module.getInputDimMask() + + + val inputScalesAttrList = inputScales.map(floatArrayToAttrValue) + + modelBuilder.addAllInputScales(inputScalesAttrList.toIterable.asJava) + modelBuilder.setInputDimMasks(inputMasks) + + + // Save scale and mask of output into BigDL model builder + val outputScales : Array[Array[Float]] = module.getOutputScales() + val outputMasks : Int = module.getOutputDimMask() + + val outputScalesAttrList = outputScales.map(floatArrayToAttrValue) + + modelBuilder.addAllOutputScales(outputScalesAttrList.toIterable.asJava) + modelBuilder.setOutputDimMasks(outputMasks) + } + + + /** + * Convert an array of float into an attr value object + * @param Array[Float] + * @return AttrValue + */ + private def floatArrayToAttrValue(arry : Array[Float]) : AttrValue = { + val tempAttrValBuilder = AttrValue.newBuilder() + tempAttrValBuilder.setDataType(DataType.ARRAY_VALUE) + + val tempArryValBuilder = ArrayValue.newBuilder() + tempArryValBuilder.setSize(arry.length) + tempArryValBuilder.setDatatype(DataType.FLOAT) + + arry.foreach(tempArryValBuilder.addFlt) + tempAttrValBuilder.setArrayValue(tempArryValBuilder).build() + } + + } trait ContainerSerializable extends ModuleSerializable { diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala index 734c1a9369a..fb0845c4f05 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/ModuleSerializer.scala @@ -119,7 +119,7 @@ object ModuleSerializer extends ModuleSerializable{ (implicit ev: TensorNumeric[T]) : ModuleData[T] = { try { val model = context.bigdlModule - val deSerializer = if (serializerMaps.contains(model.getModuleType)) { + val deserializer = if (serializerMaps.contains(model.getModuleType)) { serializerMaps(model.getModuleType) } else { val attrMap = model.getAttrMap @@ -145,8 +145,8 @@ object ModuleSerializer extends ModuleSerializable{ } } } - deSerializer.setCopyWeightAndBias(context.copyWeightAndBias). - loadModule(context) + deserializer.setCopyWeightAndBias(context.copyWeightAndBias).loadModule(context) + } catch { case e: Exception => throw new RuntimeException( diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala index edeb4b85a54..11aee0ea88a 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/serializer/converters/DataConverter.scala @@ -166,7 +166,8 @@ object DataConverter extends DataConverter{ VariableFormatConverter.setAttributeValue(context, attributeBuilder, value) } else if (valueType =:= universe.typeOf[InitializationMethod]) { InitMethodConverter.setAttributeValue(context, attributeBuilder, value) - } else if (valueType.toString == ModuleSerializer.regularizerType.toString) { + } else if (valueType.toString == ModuleSerializer.regularizerType.toString + || valueType <:< universe.typeOf[Regularizer[_]]) { RegularizerConverter.setAttributeValue(context, attributeBuilder, value) } else if (valueType <:< universe.typeOf[Tensor[_]]) { TensorConverter.setAttributeValue(context, attributeBuilder, value) diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala new file mode 100644 index 00000000000..a3ffe09a114 --- /dev/null +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/Fp32ToInt8Spec.scala @@ -0,0 +1,81 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.bigdl.nn.mkldnn + +import java.io.File +import java.util.UUID + +import com.intel.analytics.bigdl.nn.Module +import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} + +import scala.util.Random + +class Fp32ToInt8Spec extends FlatSpec with Matchers with BeforeAndAfter { + + val modelPath: String = "myTestModel" + UUID.randomUUID().toString + val weightPath: String = "myTestModelWeight" + UUID.randomUUID().toString + + "Saving and loading scale and mask" should "work properly" in { + + val myModel = Linear(3, 4) + + val assignedInputMask: Int = Random.nextInt(100) + val assignedInputScales: Array[Array[Float]] = Array.ofDim[Float](3, 4).map( + (arry: Array[Float]) => { + arry.map((x: Float) => { + Random.nextFloat() + }) + } + ) + + val assignedOutputMask: Int = Random.nextInt() + val assignedOutputScales: Array[Array[Float]] = Array.ofDim[Float](3, 4).map( + (arry: Array[Float]) => { + arry.map((x: Float) => { + Random.nextFloat() + }) + } + ) + + myModel.setInputDimMask(assignedInputMask) + myModel.setInputScales(assignedInputScales) + + myModel.setOutputDimMask(assignedOutputMask) + myModel.setOutputScales(assignedOutputScales) + + myModel.saveModule(modelPath, weightPath, true) + + val loadedModel = Module.loadModule[Float](modelPath, weightPath).asInstanceOf[Linear] + + val loadedInputMask = loadedModel.getInputDimMask() + val loadedInputScales = loadedModel.getInputScales() + val loadedOutputMask = loadedModel.getOutputDimMask() + val loadedOutputScales = loadedModel.getOutputScales() + + loadedInputMask should be (assignedInputMask) + loadedInputScales should be (assignedInputScales) + + loadedOutputMask should be (assignedOutputMask) + loadedOutputScales should be (assignedOutputScales) + + } + + after { + new File(modelPath).delete() + new File(weightPath).delete() + } +}