From c499dc962f95f6426c89204c882640c53a3fc6b8 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 17 Jul 2019 15:19:38 +0300 Subject: [PATCH] - numpy import fix for CUDA (#64) - skip tagLocation for empty arrays Signed-off-by: raver119 --- .../jita/allocator/impl/AtomicAllocator.java | 2 +- .../jita/concurrency/CudaAffinityManager.java | 4 ++ .../jcublas/buffer/BaseCudaDataBuffer.java | 40 +++++-------------- .../nd4j/linalg/serde/NumpyFormatTests.java | 7 ++++ 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index e8e20b9be049..bbefcb0fcd3a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -529,7 +529,7 @@ public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape require * @param objectId * @return */ - protected AllocationPoint getAllocationPoint(Long objectId) { + protected AllocationPoint getAllocationPoint(@NonNull Long objectId) { return allocationsMap.get(objectId); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 2b3e57f5cff4..0655e7cb1455 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -339,6 +339,10 @@ public DataBuffer replicateToDevice(Integer deviceId, DataBuffer buffer) { */ @Override public void tagLocation(INDArray array, Location location) { + // we can't tag empty arrays. + if (array.isEmpty()) + return; + if (location == Location.HOST) AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); else if (location == Location.DEVICE) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index c85202e35808..8d29e2b7b9ac 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -116,6 +116,7 @@ public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) { //cuda specific bits this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); + this.trackingPoint = allocationPoint.getObjectId(); Nd4j.getDeallocatorService().pickObject(this); @@ -124,40 +125,19 @@ public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); + if (allocationPoint.getHostPointer() != null) { + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); + } else { + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); + } context.getSpecialStream().synchronize(); - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST); - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - - this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length * getElementSize(), 0); - - switch (dataType()) { - case INT: { - setIndexer(IntIndexer.create(((CudaPointer) this.pointer).asIntPointer())); - } - break; - case FLOAT: { - setIndexer(FloatIndexer.create(((CudaPointer) this.pointer).asFloatPointer())); - } - break; - case DOUBLE: { - setIndexer(DoubleIndexer.create(((CudaPointer) this.pointer).asDoublePointer())); - } - break; - case HALF: { - setIndexer(ShortIndexer.create(((CudaPointer) this.pointer).asShortPointer())); - } - break; - case LONG: { - setIndexer(LongIndexer.create(((CudaPointer) this.pointer).asLongPointer())); - } - break; - } + if (allocationPoint.getHostPointer() != null) + PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST); - this.trackingPoint = allocationPoint.getObjectId(); + PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 217f9780f912..22cc2a753bfd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -310,6 +310,13 @@ public void testAbsentNumpyFile_1() throws Exception { INDArray act1 = Nd4j.createFromNpyFile(f); } + @Test + public void testAbsentNumpyFile_2() throws Exception { + val f = new File("c:/develop/batch-x-1.npy"); + INDArray act1 = Nd4j.createFromNpyFile(f); + log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue()); + } + @Override public char ordering() { return 'c';