-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial commit * good draft * additional loadbalance mode * small javadoc update * small javadoc update * OutputAdapter prototype * couple of tests * javadoc update
- Loading branch information
Showing
13 changed files
with
807 additions
and
27 deletions.
There are no files selected for viewing
55 changes: 55 additions & 0 deletions
55
...rning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.deeplearning4j.nn.adapters; | ||
|
||
import lombok.val; | ||
import org.deeplearning4j.nn.api.OutputAdapter; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
|
||
/** | ||
* This OutputAdapter implementation is suited for silent conversion of 2D SoftMax output | ||
* | ||
* @author raver119@gmail.com | ||
*/ | ||
public class ArgmaxAdapter implements OutputAdapter<int[]> { | ||
|
||
/** | ||
* This method does conversion from INDArrays to int[], where each element will represents position of the highest element in output INDArray | ||
* I.e. Array of {0.25, 0.1, 0.5, 0.15} will return int array with length of 1, and value {2} | ||
* | ||
* @param outputs | ||
* @return | ||
*/ | ||
@Override | ||
public int[] apply(INDArray... outputs) { | ||
Preconditions.checkArgument(outputs.length == 1, "Argmax adapter can have only 1 output"); | ||
val array = outputs[0]; | ||
Preconditions.checkArgument(array.rank() < 3, "Argmax adapter requires 2D or 1D output"); | ||
val result = array.rank() == 2 ? new int[(int) array.size(0)] : new int[1]; | ||
|
||
if (array.rank() == 2) { | ||
val t = Nd4j.argMax(array, 1); | ||
for (int e = 0; e < t.length(); e++) | ||
result[e] = (int) t.getDouble(e); | ||
} else | ||
result[0] = (int) Nd4j.argMax(array, Integer.MAX_VALUE).getDouble(0); | ||
|
||
return result; | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
...j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.deeplearning4j.nn.adapters; | ||
|
||
import lombok.extern.slf4j.Slf4j; | ||
import lombok.val; | ||
import org.deeplearning4j.nn.api.OutputAdapter; | ||
import org.nd4j.base.Preconditions; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
|
||
/** | ||
* This OutputAdapter implementation takes single 2D nn output in, and returns JVM double[][] array | ||
* | ||
* @author raver119@gmail.com | ||
*/ | ||
@Slf4j | ||
public class Regression2dAdapter implements OutputAdapter<double[][]> { | ||
@Override | ||
public double[][] apply(INDArray... outputs) { | ||
Preconditions.checkArgument(outputs.length == 1, "Argmax adapter can have only 1 output"); | ||
val array = outputs[0]; | ||
Preconditions.checkArgument(array.rank() < 3, "Argmax adapter requires 2D or 1D output"); | ||
|
||
if (array.rank() == 2 && !array.isVector()) { | ||
return array.toDoubleMatrix(); | ||
} else { | ||
val result = new double[1][(int) array.length()]; | ||
|
||
for (int e = 0; e< array.length(); e++) | ||
result[0][e] = array.getDouble(e); | ||
|
||
return result; | ||
} | ||
} | ||
} |
42 changes: 42 additions & 0 deletions
42
deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.deeplearning4j.nn.api; | ||
|
||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
|
||
import java.io.Serializable; | ||
|
||
/** | ||
* This interface describes entity used to conver neural network output to specified class. | ||
* I.e. INDArray -> int[] on the fly. | ||
* | ||
* PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference. | ||
* This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly. | ||
* If you need a copy of the output array, use standard network output methods, or use INDArray.detach() before storing the array | ||
* | ||
* @param <T> | ||
*/ | ||
public interface OutputAdapter<T> extends Serializable { | ||
|
||
/** | ||
* This method provides conversion from multiple INDArrays to T | ||
* | ||
* @param outputs | ||
* @return | ||
*/ | ||
T apply(INDArray... outputs); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
...g4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.deeplearning4j.nn.adapters; | ||
|
||
import lombok.val; | ||
import org.junit.Test; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
|
||
import static org.junit.Assert.*; | ||
|
||
public class ArgmaxAdapterTest { | ||
@Test | ||
public void testSoftmax_2D_1() { | ||
val in = new double[][] {{1, 3, 2}, { 4, 5, 6}}; | ||
|
||
val adapter = new ArgmaxAdapter(); | ||
val result = adapter.apply(Nd4j.create(in)); | ||
|
||
assertArrayEquals(new int[]{1, 2}, result); | ||
} | ||
|
||
@Test | ||
public void testSoftmax_1D_1() { | ||
val in = new double[] {1, 3, 2}; | ||
|
||
val adapter = new ArgmaxAdapter(); | ||
val result = adapter.apply(Nd4j.create(in)); | ||
|
||
assertArrayEquals(new int[]{1}, result); | ||
} | ||
} |
46 changes: 46 additions & 0 deletions
46
...eplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.deeplearning4j.nn.adapters; | ||
|
||
import lombok.val; | ||
import org.junit.Test; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.nd4j.linalg.util.ArrayUtil; | ||
|
||
import static org.junit.Assert.*; | ||
|
||
public class Regression2dAdapterTest { | ||
@Test | ||
public void testRegressionAdapter_2D_1() throws Exception { | ||
val in = new double[][] {{1, 2, 3}, { 4, 5, 6}}; | ||
|
||
val adapter = new Regression2dAdapter(); | ||
val result = adapter.apply(Nd4j.create(in)); | ||
|
||
assertArrayEquals(ArrayUtil.flatten(in), ArrayUtil.flatten(result), 1e-5); | ||
} | ||
|
||
@Test | ||
public void testRegressionAdapter_2D_2() throws Exception { | ||
val in = new double[]{1, 2, 3}; | ||
|
||
val adapter = new Regression2dAdapter(); | ||
val result = adapter.apply(Nd4j.create(in)); | ||
|
||
assertArrayEquals(in, ArrayUtil.flatten(result), 1e-5); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.