Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5026 from deeplearning4j/mp_up_sub_sampling_3d
[WIP] 3D Utility layer suite
- Loading branch information
Showing
38 changed files
with
2,901 additions
and
145 deletions.
There are no files selected for viewing
385 changes: 378 additions & 7 deletions
385
...earning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
...in/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping3D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; | ||
|
||
import lombok.Data; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; | ||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||
|
||
import java.util.Map; | ||
|
||
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.getPaddingFromConfig; | ||
|
||
/** | ||
* Imports a Keras Cropping 3D layer. | ||
* | ||
* @author Max Pumperla | ||
*/ | ||
@Slf4j | ||
@Data | ||
@EqualsAndHashCode(callSuper = false) | ||
public class KerasCropping3D extends KerasLayer { | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration. | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | ||
*/ | ||
public KerasCropping3D(Map<String, Object> layerConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
this(layerConfig, true); | ||
} | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration | ||
* @param enforceTrainingConfig whether to enforce training-related configuration options | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | ||
*/ | ||
public KerasCropping3D(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
super(layerConfig, enforceTrainingConfig); | ||
String croppingField = conf.getLAYER_FIELD_CROPPING(); | ||
int[] cropping = getPaddingFromConfig(layerConfig, conf, croppingField, 3); | ||
Cropping3D.Builder builder = new Cropping3D.Builder(cropping) | ||
.name(this.layerName).dropOut(this.dropout); | ||
this.layer = builder.build(); | ||
this.vertex = null; | ||
} | ||
|
||
/** | ||
* Get DL4J Cropping3D layer. | ||
* | ||
* @return Cropping3D layer | ||
*/ | ||
public Cropping3D getCropping3DLayer() { | ||
return (Cropping3D) this.layer; | ||
} | ||
|
||
/** | ||
* Get layer output type. | ||
* | ||
* @param inputType Array of InputTypes | ||
* @return output type as InputType | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
*/ | ||
@Override | ||
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { | ||
if (inputType.length > 1) | ||
throw new InvalidKerasConfigurationException( | ||
"Keras Cropping 3D layer accepts only one input (received " + inputType.length + ")"); | ||
return this.getCropping3DLayer().getOutputType(-1, inputType[0]); | ||
} | ||
} |
97 changes: 97 additions & 0 deletions
97
.../java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling3D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/*- | ||
* | ||
* * Copyright 2017 Skymind,Inc. | ||
* * | ||
* * Licensed under the Apache License, Version 2.0 (the "License"); | ||
* * you may not use this file except in compliance with the License. | ||
* * You may obtain a copy of the License at | ||
* * | ||
* * http://www.apache.org/licenses/LICENSE-2.0 | ||
* * | ||
* * Unless required by applicable law or agreed to in writing, software | ||
* * distributed under the License is distributed on an "AS IS" BASIS, | ||
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* * See the License for the specific language governing permissions and | ||
* * limitations under the License. | ||
* | ||
*/ | ||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; | ||
|
||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.layers.Upsampling2D; | ||
import org.deeplearning4j.nn.conf.layers.Upsampling3D; | ||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||
|
||
import java.util.Map; | ||
|
||
|
||
/** | ||
* Keras Upsampling3D layer support | ||
* | ||
* @author Max Pumperla | ||
*/ | ||
public class KerasUpsampling3D extends KerasLayer { | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration. | ||
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration exception | ||
*/ | ||
public KerasUpsampling3D(Map<String, Object> layerConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
this(layerConfig, true); | ||
} | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration | ||
* @param enforceTrainingConfig whether to enforce training-related configuration options | ||
* @throws InvalidKerasConfigurationException Invalid Keras configuration exception | ||
* @throws UnsupportedKerasConfigurationException Invalid Keras configuration exception | ||
*/ | ||
public KerasUpsampling3D(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
super(layerConfig, enforceTrainingConfig); | ||
|
||
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf); | ||
// TODO: make sure to allow different sizes. | ||
|
||
Upsampling3D.Builder builder = new Upsampling3D.Builder() | ||
.name(this.layerName) | ||
.dropOut(this.dropout) | ||
.size(size[0]); | ||
|
||
this.layer = builder.build(); | ||
this.vertex = null; | ||
} | ||
|
||
/** | ||
* Get DL4J Upsampling3D layer. | ||
* | ||
* @return Upsampling3D layer | ||
*/ | ||
public Upsampling3D getUpsampling3DLayer() { | ||
return (Upsampling3D) this.layer; | ||
} | ||
|
||
/** | ||
* Get layer output type. | ||
* | ||
* @param inputType Array of InputTypes | ||
* @return output type as InputType | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
*/ | ||
@Override | ||
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { | ||
if (inputType.length > 1) | ||
throw new InvalidKerasConfigurationException( | ||
"Keras Upsampling 3D layer accepts only one input (received " + inputType.length + ")"); | ||
return this.getUpsampling3DLayer().getOutputType(-1, inputType[0]); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
...java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasZeroPadding3D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; | ||
|
||
import lombok.Data; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; | ||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; | ||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||
|
||
import java.util.Map; | ||
|
||
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.getPaddingFromConfig; | ||
|
||
/** | ||
* Imports a Keras ZeroPadding 3D layer. | ||
* | ||
* @author Max Pumperla | ||
*/ | ||
@Slf4j | ||
@Data | ||
@EqualsAndHashCode(callSuper = false) | ||
public class KerasZeroPadding3D extends KerasLayer { | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration. | ||
* | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | ||
*/ | ||
public KerasZeroPadding3D(Map<String, Object> layerConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
this(layerConfig, true); | ||
} | ||
|
||
/** | ||
* Constructor from parsed Keras layer configuration dictionary. | ||
* | ||
* @param layerConfig dictionary containing Keras layer configuration | ||
* @param enforceTrainingConfig whether to enforce training-related configuration options | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | ||
*/ | ||
public KerasZeroPadding3D(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | ||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||
super(layerConfig, enforceTrainingConfig); | ||
String paddingField = conf.getLAYER_FIELD_ZERO_PADDING(); | ||
int[] padding = getPaddingFromConfig(layerConfig, conf, paddingField,3); | ||
ZeroPadding3DLayer.Builder builder = new ZeroPadding3DLayer.Builder(padding) | ||
.name(this.layerName).dropOut(this.dropout); | ||
this.layer = builder.build(); | ||
this.vertex = null; | ||
} | ||
|
||
/** | ||
* Get DL4J ZeroPadding3DLayer. | ||
* | ||
* @return ZeroPadding3DLayer | ||
*/ | ||
public ZeroPadding3DLayer getZeroPadding3DLayer() { | ||
return (ZeroPadding3DLayer) this.layer; | ||
} | ||
|
||
/** | ||
* Get layer output type. | ||
* | ||
* @param inputType Array of InputTypes | ||
* @return output type as InputType | ||
* @throws InvalidKerasConfigurationException Invalid Keras config | ||
*/ | ||
@Override | ||
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { | ||
if (inputType.length > 1) | ||
throw new InvalidKerasConfigurationException( | ||
"Keras ZeroPadding3D layer accepts only one input (received " + inputType.length + ")"); | ||
return this.getZeroPadding3DLayer().getOutputType(-1, inputType[0]); | ||
} | ||
} |
Oops, something went wrong.