diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 70925e32acce..fe0aafe9c990 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -133,6 +133,11 @@ namespace helpers { return ND4J_STATUS_OK; } + // Special case for TF compatibility + if((center && inHeight < 2) || (center && inWidth < 2)){ + center = false; + } + if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { // wrong input data diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestPad.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java similarity index 61% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestPad.java rename to nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java index 932600f43f23..4e533e81f273 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestPad.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java @@ -10,7 +10,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class TestPad { +public class CustomOpTests { @Test public void testPad(){ @@ -31,4 +31,22 @@ public void testPad(){ Nd4j.getExecutioner().exec(op); //Crash here } + + @Test + public void testResizeBilinearEdgeCase(){ + INDArray in = Nd4j.ones(DataType.FLOAT, 1, 1, 1, 3); + INDArray size = Nd4j.createFromArray(8, 8); + INDArray out = Nd4j.create(DataType.FLOAT, 1, 8, 8, 3); + + DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear") + .addInputs(in, size) + .addOutputs(out) + .addIntegerArguments(1) //1 = center. Though TF works with align_corners == false or true + .build(); + + Nd4j.getExecutioner().exec(op); + + INDArray exp = Nd4j.ones(DataType.FLOAT, 1, 8, 8, 3); + assertEquals(exp, out); + } }