Skip to content

Commit

Permalink
Merge branch 'master' into sa_javacpp
Browse files Browse the repository at this point in the history
  • Loading branch information
farizrahman4u committed Mar 26, 2019
2 parents c587ed4 + 6525897 commit 39de698
Show file tree
Hide file tree
Showing 119 changed files with 4,727 additions and 1,335 deletions.
File renamed without changes.
Expand Up @@ -477,11 +477,14 @@ protected void fillNDArray(Mat image, INDArray ret) {
for (long k = 0; k < channels; k++) {
for (long i = 0; i < rows; i++) {
for (long j = 0; j < cols; j++) {
if (channels > 1) {
if (ret.rank() == 3) {
ret.putScalar(k, i, j, idx.getDouble(i, j, k));
} else {
} else if (ret.rank() == 4) {
ret.putScalar(1, k, i, j, idx.getDouble(i, j, k));
} else if (ret.rank() == 2) {
ret.putScalar(i, j, idx.getDouble(i, j));
}
} else
throw new ND4JIllegalStateException("NativeImageLoader expects 2D, 3D or 4D output array, but " + ret.rank() + "D array was given");
}
}
}
Expand Down
Expand Up @@ -16,6 +16,7 @@

package org.datavec.image.loader;

import lombok.val;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.indexer.UByteIndexer;
Expand All @@ -24,7 +25,9 @@
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;

import java.awt.image.BufferedImage;
Expand Down Expand Up @@ -192,6 +195,46 @@ public void testAsRowVector() throws Exception {
assertNotEquals(0.0, array4.sum().getDouble(0), 0.0);
}

@Test
public void testDataTypes_1() throws Exception {
val dtypes = new DataType[]{DataType.FLOAT, DataType.HALF, DataType.SHORT, DataType.INT};

val dt = Nd4j.dataType();

for (val dtype: dtypes) {
Nd4j.setDataType(dtype);
int w3 = 123, h3 = 77, ch3 = 3;
val loader = new NativeImageLoader(h3, w3, ch3);
File f3 = new ClassPathResource("datavec-data-image/testimages/class0/2.jpg").getFile();
ImageWritable iw3 = loader.asWritable(f3);

val array = loader.asMatrix(iw3);

assertEquals(dtype, array.dataType());
}

Nd4j.setDataType(dt);
}

@Test
public void testDataTypes_2() throws Exception {
val dtypes = new DataType[]{DataType.FLOAT, DataType.HALF, DataType.SHORT, DataType.INT};

val dt = Nd4j.dataType();

for (val dtype: dtypes) {
Nd4j.setDataType(dtype);
int w3 = 123, h3 = 77, ch3 = 3;
val loader = new NativeImageLoader(h3, w3, 1);
File f3 = new ClassPathResource("datavec-data-image/testimages/class0/2.jpg").getFile();
val array = loader.asMatrix(f3);

assertEquals(dtype, array.dataType());
}

Nd4j.setDataType(dt);
}

@Test
public void testAsMatrix() throws Exception {
BufferedImage img1 = makeRandomBufferedImage(0, 0, 3);
Expand Down
Expand Up @@ -24,7 +24,11 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.*;
import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.Collections;
import java.util.ServiceLoader;

/**
* Default internationalization implementation.<br>
Expand All @@ -50,9 +54,11 @@ public class DefaultI18N implements I18N {
public static final String FALLBACK_LANGUAGE = "en"; //use this if the specified language doesn't have the requested message

private static DefaultI18N instance;
private static Map<String, I18N> sessionInstances = Collections.synchronizedMap(new WeakHashMap<>());
private static Map<String, I18N> sessionInstances = Collections.synchronizedMap(new HashMap<>());
private static Throwable languageLoadingException = null;


private String currentLanguage = DEFAULT_LANGUAGE;
private Map<String, Map<String, String>> messagesByLanguage = new HashMap<>();

/**
Expand Down Expand Up @@ -88,8 +94,6 @@ public static synchronized I18N removeInstance(String sessionId) {
}


private String currentLanguage = DEFAULT_LANGUAGE;

private DefaultI18N() {
loadLanguages();
}
Expand Down
Expand Up @@ -57,7 +57,16 @@
import java.text.DateFormat;
import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.Set;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -88,7 +97,7 @@ private enum ModelType {
}

private final int maxChartPoints; //Technically, the way it's set up: won't exceed 2*maxChartPoints
private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new WeakHashMap<>());
private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new HashMap<>());
private String currentSessionID;
private int currentWorkerIdx;
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
Expand Down Expand Up @@ -309,7 +318,6 @@ public synchronized void onDetach(StatsStorage statsStorage) {
addressSupplier.get(), s, statsStorage);
}
lastUpdateForSession.remove(s);
I18NProvider.removeInstance(s);
}
getDefaultSession();
}
Expand Down
Expand Up @@ -46,7 +46,6 @@
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.function.Supplier;
import org.nd4j.linalg.primitives.Pair;
import play.Mode;
import play.api.routing.Router;
Expand Down
4 changes: 2 additions & 2 deletions libnd4j/UnderstandingGraph.md
Expand Up @@ -78,9 +78,9 @@ Despite being simple - it still provides you with time spent in various parts of

```c++
Environment::getInstance()->setProfiling(true);
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/ae_00.fb");
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb");

auto profile = GraphProfilingHelper<float>::profile(graph, 1000);
auto profile = GraphProfilingHelper::profile(graph, 1000);
profile->printOut();

delete graph;
Expand Down
3 changes: 2 additions & 1 deletion libnd4j/blas/NDArray.h
Expand Up @@ -1557,7 +1557,8 @@ namespace nd4j {
bool NDArray::isVector() const {
if (isEmpty())
return false;

if (rankOf() == 1)
return true;
return !isScalar() && shape::isVector(this->_shapeInfo);
}

Expand Down
28 changes: 28 additions & 0 deletions libnd4j/blas/NativeOpExcutioner.h
Expand Up @@ -94,6 +94,20 @@ class ND4J_EXPORT NativeOpExcutioner {
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);

static void execInverseBroadcast(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadOnlyShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);


static void execBroadcastBool(int opNum,
void *x,
Expand All @@ -109,6 +123,20 @@ class ND4J_EXPORT NativeOpExcutioner {
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);

static void execInverseBroadcastBool(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadOnlyShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);


/**
*
Expand Down

0 comments on commit 39de698

Please sign in to comment.