From 5531a87657737411f4767b0c3e6ec547e8179f83 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 17 Dec 2018 13:47:06 +0300 Subject: [PATCH 01/19] initial commit --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 31 ++++++++++++++++++ .../org/nd4j/linalg/api/ndarray/INDArray.java | 16 +++++++++- .../linalg/api/buffer/BaseDataBuffer.java | 32 +++++++++++++++++++ .../nd4j/linalg/api/buffer/DataBuffer.java | 16 +++++++++- 4 files changed, 93 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index b22e63a4fc7a..262f002bcc09 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -112,6 +112,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { //protected transient DataBuffer stride; protected transient boolean compressed = false; + protected transient boolean released = false; + // this field holds jvm copy of shapeInfo protected transient JvmShapeInfo jvmShapeInfo; @@ -6463,4 +6465,33 @@ protected void validateNumericalArray(String opName){ if(dataType() == DataType.BOOL || dataType() == DataType.UTF8) throw new IllegalStateException("Cannot apply operation " + opName + " to array with " + dataType() + " datatype. Array shape: " + Arrays.toString(shape())); } + + @Override + public boolean closeable() { + if (released) + return false; + + // empty arrays have no buffer at all + if (isEmpty()) + return true; + + if (isView()) + return false; + + return data.closeable(); + } + + @Override + public void close() throws Exception { + // empty arrays have no buffer at all + if (released || isEmpty()) + return; + + if (!closeable()) + throw new ND4JIllegalStateException("Can't release view array"); + + data.close(); + + released = true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 64c4c7152cda..16e5a6d522c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -35,7 +35,7 @@ * * @author Adam Gibson */ -public interface INDArray extends Serializable { +public interface INDArray extends Serializable, AutoCloseable { /** * Returns the shape information debugging * information @@ -2686,4 +2686,18 @@ public interface INDArray extends Serializable { * @return */ boolean none(); + + /** + * This method checks, if this INDArray instalce can use close() method + * @return true if array can be released, false otherwise + */ + boolean closeable(); + + /** + * This method releases exclusive off-heap resources uses by this INDArray instance. + * If INDArray relies on shared resources, exception will be thrown instead + * + * PLEASE NOTE: This method is NOT safe by any means + */ + void close() throws Exception; } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index c799466d3698..431f96f7abf5 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.util.AllocUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.primitives.Triple; import org.nd4j.linalg.util.ArrayUtil; @@ -94,6 +95,9 @@ public abstract class BaseDataBuffer implements DataBuffer { protected transient Long trackingPoint; protected transient boolean constant = false; + protected transient boolean released = false; + + protected transient AtomicBoolean referenced = new AtomicBoolean(false); public BaseDataBuffer() {} @@ -168,6 +172,7 @@ protected BaseDataBuffer(DataBuffer underlyingBuffer, long length, long offset) this.elementSize = (byte) underlyingBuffer.getElementSize(); this.underlyingLength = underlyingBuffer.underlyingLength(); this.wrappedDataBuffer = underlyingBuffer; + ((BaseDataBuffer) underlyingBuffer).referenced.compareAndSet(false, true); // Adding link to original databuffer if (underlyingBuffer.originalDataBuffer() == null) { @@ -2246,4 +2251,31 @@ public DataBuffer reallocate(long length) { public long capacity() { return pointer().capacity(); } + + @Override + public boolean closeable() { + if (released) + return false; + + if (wrappedDataBuffer != null && wrappedDataBuffer != this) + return false; + + if (referenced.get()) + return false; + + return true; + } + + @Override + public void close() throws Exception { + if (!closeable()) + throw new IllegalStateException("Can't release view data buffer"); + + release(); + } + + protected void release() { + this.pointer.deallocate(); + this.indexer = null; + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java index 8873634c73e9..9e648d683b50 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java @@ -32,7 +32,7 @@ * * @author Adam Gibson */ -public interface DataBuffer extends Serializable { +public interface DataBuffer extends Serializable, AutoCloseable { enum TypeEx { } @@ -681,4 +681,18 @@ enum AllocationMode { * @return the capacity of the databuffer * */ long capacity(); + + /** + * This method checks, if this DataBuffer instalce can use close() method + * @return true if DataBuffer can be released, false otherwise + */ + boolean closeable(); + + /** + * This method releases exclusive off-heap resources uses by this DataBuffer instance. + * If DataBuffer relies on shared resources, exception will be thrown instead + * + * PLEASE NOTE: This method is NOT safe by any means + */ + void close() throws Exception; } From 907ee5a9a34146c01123814178abb3ba644bef3a Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 17 Dec 2018 13:50:42 +0300 Subject: [PATCH 02/19] not closeable for IsAttached() --- .../main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java | 4 ++-- .../org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java | 5 +++++ .../main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 262f002bcc09..474e6204dc4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -6468,7 +6468,7 @@ protected void validateNumericalArray(String opName){ @Override public boolean closeable() { - if (released) + if (released || isAttached()) return false; // empty arrays have no buffer at all @@ -6488,7 +6488,7 @@ public void close() throws Exception { return; if (!closeable()) - throw new ND4JIllegalStateException("Can't release view array"); + throw new ND4JIllegalStateException("Can't release this INDArray"); data.close(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 5b317454d25b..82eb94a62919 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1648,6 +1648,11 @@ public long capacity() { return pointer.capacity(); } + @Override + protected void release() { + AtomicAllocator.getInstance().freeMemory(allocationPoint); + } + /* protected short fromFloat( float fval ) { int fbits = Float.floatToIntBits( fval ); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 431f96f7abf5..6153f94c67e6 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -2254,7 +2254,7 @@ public long capacity() { @Override public boolean closeable() { - if (released) + if (released || isAttached() || isConstant()) return false; if (wrappedDataBuffer != null && wrappedDataBuffer != this) @@ -2269,7 +2269,7 @@ public boolean closeable() { @Override public void close() throws Exception { if (!closeable()) - throw new IllegalStateException("Can't release view data buffer"); + throw new IllegalStateException("Can't release this data buffer"); release(); } From fdd50cc463172068bb66adcd8009146b36c96a33 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 17 Dec 2018 14:06:01 +0300 Subject: [PATCH 03/19] couple of simple tests --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 2 +- .../api/ndarray/BaseSparseNDArrayCOO.java | 10 +++ .../org/nd4j/linalg/api/ndarray/INDArray.java | 2 +- .../cpu/nativecpu/SparseNDArrayCSR.java | 10 +++ .../nd4j/linalg/memory/CloseableTests.java | 78 +++++++++++++++++++ .../linalg/api/buffer/BaseDataBuffer.java | 2 +- .../nd4j/linalg/api/buffer/DataBuffer.java | 2 +- 7 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 474e6204dc4f..a071ad2be36e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -6482,7 +6482,7 @@ public boolean closeable() { } @Override - public void close() throws Exception { + public void close() { // empty arrays have no buffer at all if (released || isEmpty()) return; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index 18e4c0149542..27949bb66dda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -1308,6 +1308,16 @@ public boolean none() { return false; } + @Override + public boolean closeable() { + return false; + } + + @Override + public void close() { + + } + @Override public boolean isS() { return false; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 16e5a6d522c4..85ab328e7e0f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2699,5 +2699,5 @@ public interface INDArray extends Serializable, AutoCloseable { * * PLEASE NOTE: This method is NOT safe by any means */ - void close() throws Exception; + void close(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java index aa26c1661c14..6c24957c4d5d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java @@ -162,6 +162,16 @@ public boolean none() { throw new UnsupportedOperationException(); } + @Override + public boolean closeable() { + return false; + } + + @Override + public void close() { + + } + @Override public boolean isS() { return false; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java new file mode 100644 index 000000000000..fefd8db5a4ec --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.memory; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.cpu.nativecpu.NDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author raver119@gmail.com + */ +@Slf4j +@RunWith(Parameterized.class) +public class CloseableTests extends BaseNd4jTest { + public CloseableTests(Nd4jBackend backend) { + super(backend); + } + + @Test + public void testSimpleRelease_1() { + val array = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); + assertTrue(array.closeable()); + + array.close(); + + assertFalse(array.closeable()); + } + + @Test + public void testCyclicRelease_1() { + for (int e = 0; e < 1000; e++) { + try (val array = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5})) { + array.addi(1.0f); + } + System.gc(); + } + } + + @Test + public void testViewRelease_1() { + val array = Nd4j.create(5, 5); + assertTrue(array.closeable()); + + val view = array.get(NDArrayIndex.point(1), NDArrayIndex.all()); + + assertFalse(array.closeable()); + assertFalse(view.closeable()); + } + + @Override + public char ordering() { + return 'c'; + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 6153f94c67e6..34f8cf54050f 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -2267,7 +2267,7 @@ public boolean closeable() { } @Override - public void close() throws Exception { + public void close() { if (!closeable()) throw new IllegalStateException("Can't release this data buffer"); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java index 9e648d683b50..10147633bbf0 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java @@ -694,5 +694,5 @@ enum AllocationMode { * * PLEASE NOTE: This method is NOT safe by any means */ - void close() throws Exception; + void close(); } From 0f4eca0978284940a6c9514ebbef582e0b5b081f Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 17 Dec 2018 14:20:32 +0300 Subject: [PATCH 04/19] one more test --- .../java/org/nd4j/linalg/memory/CloseableTests.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java index fefd8db5a4ec..98a5b77f3fd6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java @@ -22,6 +22,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.cpu.nativecpu.NDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -71,6 +72,16 @@ public void testViewRelease_1() { assertFalse(view.closeable()); } + @Test + public void testAttachedRelease_1() { + val wsconf = WorkspaceConfiguration.builder().build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, "haha72yjhfdfs")) { + val array = Nd4j.create(5, 5); + assertFalse(array.closeable()); + } + } + @Override public char ordering() { return 'c'; From e18632d1f2f78778e5acf6fe6333d4cc1e45cb5c Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 17 Dec 2018 14:31:29 +0300 Subject: [PATCH 05/19] CUDA backend --- .../linalg/jcublas/JcusparseNDArrayCSR.java | 10 ++++++++++ .../org/nd4j/linalg/memory/CloseableTests.java | 3 +-- .../nd4j/linalg/api/buffer/BaseDataBuffer.java | 18 +++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java index 2545327502a7..d7c28d493bd0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java @@ -161,4 +161,14 @@ public boolean any() { public boolean none() { return false; } + + @Override + public boolean closeable() { + return false; + } + + @Override + public void close() { + + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java index 98a5b77f3fd6..d6140b5d8a4f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java @@ -23,7 +23,6 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.cpu.nativecpu.NDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -68,7 +67,7 @@ public void testViewRelease_1() { val view = array.get(NDArrayIndex.point(1), NDArrayIndex.all()); - assertFalse(array.closeable()); + assertTrue(array.closeable()); assertFalse(view.closeable()); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 34f8cf54050f..3c7eb3cf88f8 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -36,6 +36,7 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.util.ArrayList; import java.util.Collection; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -98,6 +99,7 @@ public abstract class BaseDataBuffer implements DataBuffer { protected transient boolean released = false; protected transient AtomicBoolean referenced = new AtomicBoolean(false); + protected transient Collection references = new ArrayList<>(); public BaseDataBuffer() {} @@ -173,6 +175,7 @@ protected BaseDataBuffer(DataBuffer underlyingBuffer, long length, long offset) this.underlyingLength = underlyingBuffer.underlyingLength(); this.wrappedDataBuffer = underlyingBuffer; ((BaseDataBuffer) underlyingBuffer).referenced.compareAndSet(false, true); + ((BaseDataBuffer) underlyingBuffer).references.add(this); // Adding link to original databuffer if (underlyingBuffer.originalDataBuffer() == null) { @@ -2260,22 +2263,31 @@ public boolean closeable() { if (wrappedDataBuffer != null && wrappedDataBuffer != this) return false; - if (referenced.get()) - return false; - return true; } + protected void markReleased() { + this.released = true; + + for (val r:references) + r.markReleased(); + } + @Override public void close() { if (!closeable()) throw new IllegalStateException("Can't release this data buffer"); + // notifying other databuffers that their underlying + for (val r:references) + r.markReleased(); + release(); } protected void release() { this.pointer.deallocate(); this.indexer = null; + this.pointer = null; } } From 5d935d8bbe6a2f4c18e33cdbd30b819c76170e05 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 07:10:24 +0300 Subject: [PATCH 06/19] batched_gemm additional validation --- .../declarable/generic/blas/batched_gemm.cpp | 9 +++++ .../layers_tests/DeclarableOpsTests3.cpp | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp index f8e82ff6a950..cb5014873453 100644 --- a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp @@ -64,11 +64,15 @@ namespace nd4j { std::vector vB(batchSize); std::vector vC(batchSize); + auto firstType = INPUT_VARIABLE(0)->dataType(); for(int e = 0; e < batchSize; e++) { vA[e] = INPUT_VARIABLE(e+2); vB[e] = INPUT_VARIABLE(e+2+batchSize); vC[e] = OUTPUT_VARIABLE(e); + + REQUIRE_TRUE(firstType == vC[e]->dataType(), 0, "BatchedGemm: all inputs and outputs must have same data type"); + REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of A should be equal to 2", e); REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of B should be equal to 2", e); REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of C should be equal to 2", e); @@ -98,6 +102,11 @@ namespace nd4j { int ldC = INT_ARG(7); int batchSize = INT_ARG(8); + auto firstType = ArrayOptions::dataType(inputShape->at(0)); + for (int e = 1; e < block.width(); e++) { + REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(1)), 0, "BatchedGemm: all inputs must have same data type"); + } + if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) { Nd4jLong *newShape; ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 8da563f08fc1..a9923a329a7e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -734,6 +734,8 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { auto exp = MmulHelper::mmul(&x, &y); + exp->printShapeInfo("exp shape"); + nd4j::ops::batched_gemm op; auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -754,6 +756,38 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { delete result; } +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + nd4j::ops::batched_gemm op; + try { + auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + auto z = NDArrayFactory::create('c', {2, 3}); + + nd4j::ops::batched_gemm op; + try { + auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {&z}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } +} + TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) { auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); From 2ae3704be4a4eb8ab7f79a2967cf0cbe461fca35 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 07:31:08 +0300 Subject: [PATCH 07/19] normalize_moments same mode fix --- libnd4j/blas/cpu/NDArray.cpp | 11 ++++++++--- .../layers_tests/DeclarableOpsTests15.cpp | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index 8bbd05b1ba83..af2e42c4cf12 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -2283,10 +2283,15 @@ void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArra target = this; if(target->_dataType != DataTypeUtils::pickPairwiseResultType(_shapeInfo, scalar->_shapeInfo) && !(target->_dataType == this->_dataType || target->_dataType == scalar->_dataType)) throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - if (!Environment::getInstance()->isExperimentalBuild()) { - if (scalar->dataType() != this->dataType()); + + if (this->dataType() == scalar->dataType() || Environment::getInstance()->isExperimentalBuild()) + NativeOpExcutioner::execScalar(op, _buffer, _shapeInfo, target->_buffer, target->_shapeInfo, scalar->_buffer, scalar->_shapeInfo, extraParams); + else { + auto tmp = const_cast(scalar)->cast(this->dataType()); + NativeOpExcutioner::execScalar(op, _buffer, _shapeInfo, target->_buffer, target->_shapeInfo, tmp->_buffer, tmp->_shapeInfo, extraParams); + + delete tmp; } - NativeOpExcutioner::execScalar(op, _buffer, _shapeInfo, target->_buffer, target->_shapeInfo, scalar->_buffer, scalar->_shapeInfo, extraParams); } template diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index ea9198ae52a5..a6ef15d04dba 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -38,6 +38,20 @@ class DeclarableOpsTests15 : public testing::Test { } }; +TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) { + auto d = NDArrayFactory::create('c', {10, 10}); + auto w = NDArrayFactory::create(10); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); + + auto z0 = NDArrayFactory::create('c', {10}); + auto z1 = NDArrayFactory::create('c', {10}); + + nd4j::ops::normalize_moments op; + auto result = op.execute({&w, &x, &y}, {&z0, &z1}, {1e-4}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + TEST_F(DeclarableOpsTests15, Test_Half_assign_1) { auto x = NDArrayFactory::create('c', {2, 5}); int y = 1; From 7ecda26c4ccb37fdaf774f7f43c21538a5e179b9 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 09:26:57 +0300 Subject: [PATCH 08/19] - varargs for few Nd4j.createFromArray signatures - new method for memory tracking --- .../java/org/nd4j/linalg/factory/Nd4j.java | 14 ++--- .../linalg/memory/BasicMemoryManager.java | 2 +- .../org/nd4j/linalg/memory/MemoryManager.java | 7 +++ .../nd4j/jita/memory/CudaMemoryManager.java | 5 ++ .../cpu/nativecpu/CpuMemoryManager.java | 5 ++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 29 ++++++++++ .../nd4j/linalg/memory/AccountingTests.java | 54 +++++++++++++++++++ 7 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 3c7256f51155..4037146da334 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -6819,7 +6819,7 @@ public static INDArray create(@NonNull Collection strings, long[] shape, * @param array * @return 1D INDArray with DOUBLE data type */ - public static INDArray createFromArray(double[] array) { + public static INDArray createFromArray(double... array) { return create(array, new long[]{array.length}, DataType.DOUBLE); } @@ -6828,7 +6828,7 @@ public static INDArray createFromArray(double[] array) { * @param array * @return 1D INDArray with FLOAT data type */ - public static INDArray createFromArray(float[] array) { + public static INDArray createFromArray(float... array) { return create(array, new long[]{array.length}, DataType.FLOAT); } @@ -6837,7 +6837,7 @@ public static INDArray createFromArray(float[] array) { * @param array * @return 1D INDArray with INT32 data type */ - public static INDArray createFromArray(int[] array) { + public static INDArray createFromArray(int... array) { return create(array, new long[]{array.length}, DataType.INT); } @@ -6846,7 +6846,7 @@ public static INDArray createFromArray(int[] array) { * @param array * @return 1D INDArray with INT16 data type */ - public static INDArray createFromArray(short[] array) { + public static INDArray createFromArray(short... array) { return create(array, new long[]{array.length}, DataType.SHORT); } @@ -6855,7 +6855,7 @@ public static INDArray createFromArray(short[] array) { * @param array * @return 1D INDArray with INT8 data type */ - public static INDArray createFromArray(byte[] array) { + public static INDArray createFromArray(byte... array) { return create(array, new long[]{array.length}, DataType.BYTE); } @@ -6864,7 +6864,7 @@ public static INDArray createFromArray(byte[] array) { * @param array * @return 1D INDArray with INT64 data type */ - public static INDArray createFromArray(long[] array) { + public static INDArray createFromArray(long... array) { return create(array, new long[]{array.length}, DataType.LONG); } @@ -6873,7 +6873,7 @@ public static INDArray createFromArray(long[] array) { * @param array * @return 1D INDArray with BOOL data type */ - public static INDArray createFromArray(boolean[] array) { + public static INDArray createFromArray(boolean... array) { return create(array, new long[]{array.length}, DataType.BOOL); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/BasicMemoryManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/BasicMemoryManager.java index 633461df6758..fd727b79b8f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/BasicMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/BasicMemoryManager.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.abstracts.DummyWorkspace; +import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; @@ -57,7 +58,6 @@ public abstract class BasicMemoryManager implements MemoryManager { private ThreadLocal tempWorkspace = new ThreadLocal<>(); - /** * This method returns * PLEASE NOTE: Cache options diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryManager.java index 0d98bf98bf0d..584ae7bfeaaf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryManager.java @@ -185,4 +185,11 @@ public interface MemoryManager { * This method returns per-device bandwidth use for memory transfers */ Map getBandwidthUse(); + + /** + * This method returns number of bytes allocated on specified device + * @param deviceId + * @return + */ + long allocatedMemory(Integer deviceId); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index b0a1e0f69ee6..0527c8b84797 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -263,4 +263,9 @@ public void memset(INDArray array) { public Map getBandwidthUse() { return null; } + + @Override + public long allocatedMemory(Integer deviceId) { + return 0; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java index d8a365ff884d..adcd92f4d25b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java @@ -105,4 +105,9 @@ public void memset(INDArray array) { public Map getBandwidthUse() { return null; } + + @Override + public long allocatedMemory(Integer deviceId) { + return Pointer.totalBytes(); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index d17bdf0773d1..d1c4bf200aea 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -17708,6 +17708,35 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran } // #endif + /** + * log_matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ + +// #if NOT_EXCLUDED(OP_log_matrix_determinant) + @Namespace("nd4j::ops") public static class log_matrix_determinant extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_matrix_determinant(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_matrix_determinant position(long position) { + return (log_matrix_determinant)super.position(position); + } + + public log_matrix_determinant() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor * diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java new file mode 100644 index 000000000000..2d97351eaa70 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.memory; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author raver119@gmail.com + */ +@Slf4j +@RunWith(Parameterized.class) +public class AccountingTests extends BaseNd4jTest { + public AccountingTests(Nd4jBackend backend) { + super(backend); + } + + @Test + public void testDetached_1() { + val array = Nd4j.createFromArray(1, 2, 3, 4, 5); + assertEquals(DataType.INT, array.dataType()); + + assertTrue(Nd4j.getMemoryManager().allocatedMemory(0) > 0L); + } + + @Override + public char ordering() { + return 'c'; + } +} From 8f3425380b9f1a79c6dc711779a2507a28d58a0f Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 10:07:20 +0300 Subject: [PATCH 09/19] one more test --- nd4j/nd4j-backends/nd4j-tests/pom.xml | 6 ++--- .../nd4j/linalg/memory/AccountingTests.java | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index c4a68b4f111d..a52c45bd2dc5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -103,20 +103,20 @@ ${project.version} - + - + ch.qos.logback logback-classic diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 2d97351eaa70..2ada71c8e30a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -23,6 +23,9 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import org.nd4j.linalg.api.memory.enums.AllocationPolicy; +import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -47,6 +50,28 @@ public void testDetached_1() { assertTrue(Nd4j.getMemoryManager().allocatedMemory(0) > 0L); } + @Test + public void testWorkspaceAccounting_1() { + val wsConf = WorkspaceConfiguration.builder() + .initialSize(10 * 1024 * 1024) + .policyAllocation(AllocationPolicy.STRICT) + .policyLearning(LearningPolicy.FIRST_LOOP) + .build(); + + val before = Nd4j.getMemoryManager().allocatedMemory(0); + + val workspace = Nd4j.getWorkspaceManager().createNewWorkspace(wsConf, "random_name_here"); + + val middle = Nd4j.getMemoryManager().allocatedMemory(0); + + Nd4j.getWorkspaceManager().destroyWorkspace(workspace); + + val after = Nd4j.getMemoryManager().allocatedMemory(0); + + assertTrue(middle > before); + assertTrue(after < middle); + } + @Override public char ordering() { return 'c'; From cc19ee2190a6a695afefd2d77df49962304b3595 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 11:55:43 +0300 Subject: [PATCH 10/19] one more test --- .../nd4j/linalg/memory/AccountingTests.java | 28 +++++++- .../linalg/api/memory/AllocationsTracker.java | 71 +++++++++++++++++++ .../api/memory/DeviceAllocationsTracker.java | 48 +++++++++++++ .../api/memory/enums/AllocationKind.java | 33 +++++++++ 4 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java create mode 100644 nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java create mode 100644 nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 2ada71c8e30a..d0a953e08f12 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -23,14 +23,15 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.DeviceAllocationsTracker; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author raver119@gmail.com @@ -72,6 +73,29 @@ public void testWorkspaceAccounting_1() { assertTrue(after < middle); } + @Test + public void testTracker_1() { + val tracker = new DeviceAllocationsTracker(); + + for (val e: AllocationKind.values()) { + for (int v = 1; v <= 100; v++) { + tracker.updateState(e, v); + } + + assertNotEquals(0, tracker.getState(e)); + } + + for (val e: AllocationKind.values()) { + for (int v = 1; v <= 100; v++) { + tracker.updateState(e, -v); + } + + assertEquals(0, tracker.getState(e)); + } + + + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java new file mode 100644 index 000000000000..3897a521d13a --- /dev/null +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.memory; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import lombok.var; +import org.nd4j.linalg.api.memory.enums.AllocationKind; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class provides methods for tracking different memory allocations + * @author raver119@gmail.com + */ +@Slf4j +public class AllocationsTracker { + private static final AllocationsTracker INSTANCE = new AllocationsTracker(); + private Map devices = new ConcurrentHashMap<>(); + + protected AllocationsTracker() { + + } + + public static AllocationsTracker getInstance() { + return INSTANCE; + } + + protected DeviceAllocationsTracker trackerForDevice(Integer deviceId) { + var tracker = devices.get(deviceId); + if (tracker == null) { + synchronized (this) { + tracker = devices.get(deviceId); + if (tracker == null) { + tracker = new DeviceAllocationsTracker(); + devices.put(deviceId, tracker); + } + } + } + + return tracker; + } + + public void markAllocated(AllocationKind kind, Integer deviceId, long bytes) { + val tracker = trackerForDevice(deviceId); + + tracker.updateState(kind, bytes); + } + + public void markReleased(AllocationKind kind, Integer deviceId, long bytes) { + val tracker = trackerForDevice(deviceId); + + tracker.updateState(kind, -bytes); + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java new file mode 100644 index 000000000000..3e42d05e6ce9 --- /dev/null +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.memory; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.nd4j.linalg.api.memory.enums.AllocationKind; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +/** + * @author raver119@gmail.com + */ +@Slf4j +public class DeviceAllocationsTracker { + private Map bytesMap = new HashMap<>(); + + public DeviceAllocationsTracker() { + for (val e:AllocationKind.values()) { + bytesMap.put(e, new AtomicLong(0)); + } + } + + public void updateState(@NonNull AllocationKind kind, long bytes) { + bytesMap.get(kind).addAndGet(bytes); + } + + public long getState(@NonNull AllocationKind kind) { + return bytesMap.get(kind).get(); + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java new file mode 100644 index 000000000000..abb6235896ba --- /dev/null +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.memory.enums; + +/** + * This enum describes different allocation kinds + * @author raver119@gmail.com + */ +public enum AllocationKind { + /** + * General allocations + */ + GENERAL, + + /** + * Allocations that will be never released, and reused during session + */ + CONSTANT, +} From 9056b94cfc5b51026e771020e7ebcbc1b6fd0101 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 12:51:24 +0300 Subject: [PATCH 11/19] workspace attached --- .../java/org/nd4j/jita/memory/CudaMemoryManager.java | 7 ++++--- .../java/org/nd4j/jita/workspace/CudaWorkspace.java | 11 ++++++++--- .../java/org/nd4j/linalg/memory/AccountingTests.java | 12 ++++++------ .../nd4j/linalg/api/memory/AllocationsTracker.java | 9 +++++++++ 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index 0527c8b84797..ecd1e6c575a9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -17,12 +17,15 @@ package org.nd4j.jita.memory; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -68,8 +71,6 @@ public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) { return ptr;//allocator.getMemoryHandler().alloc(AllocationStatus.HOST, null, null, initialize).getHostPointer(); } else if (kind == MemoryKind.DEVICE) { Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(bytes, null, 0); - - //log.info("Allocating {} bytes for device_{}", bytes, Nd4j.getAffinityManager().getDeviceForCurrentThread()); if (ptr == null) @@ -266,6 +267,6 @@ public Map getBandwidthUse() { @Override public long allocatedMemory(Integer deviceId) { - return 0; + return AllocationsTracker.getInstance().bytesOnDevice(deviceId); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index a75d14f3283e..d7f520dee641 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -24,6 +24,7 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.AllocationsTracker; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.memory.pointers.PagedPointer; @@ -85,8 +86,10 @@ protected void init() { workspace.setHostPointer(new PagedPointer(ptr)); - if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) + if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) { workspace.setDevicePointer(new PagedPointer(memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.DEVICE, false))); + AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); + } //log.info("Workspace [{}] initialized successfully", id); } @@ -100,7 +103,7 @@ public PagedPointer alloc(long requiredMemory, DataType type, boolean initialize @Override public synchronized void destroyWorkspace(boolean extended) { - currentSize.set(0); + val size = currentSize.getAndSet(0); reset(); if (extended) @@ -111,8 +114,10 @@ public synchronized void destroyWorkspace(boolean extended) { if (workspace.getHostPointer() != null) NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); - if (workspace.getDevicePointer() != null) + if (workspace.getDevicePointer() != null) { NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(workspace.getDevicePointer(), null); + AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); + } workspace.setDevicePointer(null); workspace.setHostPointer(null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index d0a953e08f12..1e7dc060e31a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -53,22 +53,24 @@ public void testDetached_1() { @Test public void testWorkspaceAccounting_1() { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); val wsConf = WorkspaceConfiguration.builder() .initialSize(10 * 1024 * 1024) .policyAllocation(AllocationPolicy.STRICT) .policyLearning(LearningPolicy.FIRST_LOOP) .build(); - val before = Nd4j.getMemoryManager().allocatedMemory(0); + val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); val workspace = Nd4j.getWorkspaceManager().createNewWorkspace(wsConf, "random_name_here"); - val middle = Nd4j.getMemoryManager().allocatedMemory(0); + val middle = Nd4j.getMemoryManager().allocatedMemory(deviceId); - Nd4j.getWorkspaceManager().destroyWorkspace(workspace); + workspace.destroyWorkspace(true); - val after = Nd4j.getMemoryManager().allocatedMemory(0); + val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); + log.info("Before: {}; Middle: {}; After: {}", before, middle, after); assertTrue(middle > before); assertTrue(after < middle); } @@ -92,8 +94,6 @@ public void testTracker_1() { assertEquals(0, tracker.getState(e)); } - - } @Override diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java index 3897a521d13a..2f2c6efac103 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java @@ -68,4 +68,13 @@ public void markReleased(AllocationKind kind, Integer deviceId, long bytes) { tracker.updateState(kind, -bytes); } + + public long bytesOnDevice(Integer deviceId) { + return bytesOnDevice(AllocationKind.GENERAL, deviceId); + } + + public long bytesOnDevice(AllocationKind kind, Integer deviceId) { + val tracker = trackerForDevice(deviceId); + return tracker.getState(kind); + } } From 97d408189edd4feeeb934b0616384f583c2c93c2 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 14:14:02 +0300 Subject: [PATCH 12/19] one more test --- .../nd4j/linalg/memory/AccountingTests.java | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 1e7dc060e31a..63417c2b5dc0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -75,6 +75,35 @@ public void testWorkspaceAccounting_1() { assertTrue(after < middle); } + @Test + public void testWorkspaceAccounting_2() { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + val wsConf = WorkspaceConfiguration.builder() + .initialSize(0) + .policyAllocation(AllocationPolicy.STRICT) + .policyLearning(LearningPolicy.OVER_TIME) + .cyclesBeforeInitialization(3) + .build(); + + val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + long middle1 = 0; + try (val workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "random_name_here")) { + val array = Nd4j.create(DataType.DOUBLE, 5, 5); + middle1 = Nd4j.getMemoryManager().allocatedMemory(deviceId); + } + + val middle2 = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + log.info("Before: {}; Middle1: {}; Middle2: {}; After: {}", before, middle1, middle2, after); + assertTrue(middle1 > before); + assertTrue(after < middle1); + } + @Test public void testTracker_1() { val tracker = new DeviceAllocationsTracker(); From de786e0dc9f08f3093d5529601c9d59903b880c1 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 15:00:35 +0300 Subject: [PATCH 13/19] - legacy tracking - constant tracking --- .../constant/ProtectedCudaConstantHandler.java | 4 ++++ .../nd4j/jita/memory/impl/CudaDirectProvider.java | 5 +++++ .../org/nd4j/jita/workspace/CudaWorkspace.java | 4 +++- .../org/nd4j/linalg/memory/AccountingTests.java | 14 ++++++++++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 09546933f1d6..5a446e11e1b3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -28,6 +28,8 @@ import org.nd4j.jita.flow.FlowController; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.cache.ArrayDescriptor; import org.nd4j.linalg.cache.ConstantHandler; @@ -134,6 +136,8 @@ public synchronized long moveToConstantSpace(DataBuffer dataBuffer) { //logger.info("shape: " + point.getShape()); // and release device memory :) + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, deviceId, requiredMemoryBytes); + long currentOffset = constantOffsets.get(deviceId).get(); CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java index 85ab91206b10..b200b9d904c7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java @@ -25,6 +25,9 @@ import org.nd4j.jita.allocator.pointers.PointersPair; import org.nd4j.jita.allocator.utils.AllocationUtils; import org.nd4j.jita.memory.MemoryProvider; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.slf4j.Logger; @@ -100,6 +103,7 @@ public PointersPair malloc(AllocationShape shape, AllocationPoint point, Allocat // throw new RuntimeException("Device allocation happened"); + AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem); Pointer pointer = nativeOps.mallocDevice(reqMem, null, 0); //log.info("Device [{}] allocation, Thread id: {}, ReqMem: {}, Pointer: {}", AtomicAllocator.getInstance().getDeviceId(), Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null); @@ -159,6 +163,7 @@ public void free(AllocationPoint point) { // log.info("Deallocating {} bytes on [DEVICE]", reqMem); NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), reqMem); long result = nativeOps.freeDevice(point.getPointers().getDevicePointer(), new CudaPointer(0)); if (result == 0) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index d7f520dee641..c50f6f38d090 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -322,7 +322,9 @@ protected void clearPinnedAllocations(boolean extended) { log.info("deleting external host allocation "); } - pinnedAllocationsSize.addAndGet(pair.getRequiredMemory() * -1); + val sizez = pair.getRequiredMemory() * -1; + pinnedAllocationsSize.addAndGet(sizez); + //AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); } else { break; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 63417c2b5dc0..22dc5163703e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -51,6 +51,20 @@ public void testDetached_1() { assertTrue(Nd4j.getMemoryManager().allocatedMemory(0) > 0L); } + @Test + public void testDetached_2() { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + + val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + val array = Nd4j.createFromArray(1, 2, 3, 4, 5); + assertEquals(DataType.INT, array.dataType()); + + val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + assertTrue(after > before); + } + @Test public void testWorkspaceAccounting_1() { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); From d9f71a50727851f097daf8936732810aa7c2033c Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 15:26:49 +0300 Subject: [PATCH 14/19] cpu constant space --- .../cpu/nativecpu/cache/ConstantBuffersCache.java | 11 +++++++++-- .../java/org/nd4j/linalg/memory/AccountingTests.java | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java index a672b6d29a6b..778e01cc23d3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java @@ -18,6 +18,8 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.cache.ArrayDescriptor; import org.nd4j.linalg.cache.BasicConstantHandler; import org.nd4j.linalg.factory.Nd4j; @@ -56,6 +58,7 @@ public DataBuffer getConstantBuffer(int[] array, DataType dataType) { buffersCache.put(descriptor, buffer); bytes.addAndGet(array.length * Nd4j.sizeOfDataType()); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -75,6 +78,7 @@ public DataBuffer getConstantBuffer(boolean[] array, DataType dataType) { buffersCache.put(descriptor, buffer); bytes.addAndGet(array.length * Nd4j.sizeOfDataType()); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -94,6 +98,7 @@ public DataBuffer getConstantBuffer(double[] array, DataType dataType) { buffersCache.put(descriptor, buffer); bytes.addAndGet(array.length * Nd4j.sizeOfDataType()); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -112,7 +117,8 @@ public DataBuffer getConstantBuffer(float[] array, DataType dataType) { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType()); + bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -131,7 +137,8 @@ public DataBuffer getConstantBuffer(long[] array, DataType dataType) { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType()); + bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 22dc5163703e..52d1a40f7dbe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -23,6 +23,7 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.AllocationsTracker; import org.nd4j.linalg.api.memory.DeviceAllocationsTracker; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationKind; @@ -63,6 +64,7 @@ public void testDetached_2() { val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); assertTrue(after > before); + assertTrue(AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.CONSTANT, Nd4j.getAffinityManager().getDeviceForCurrentThread()) > 0); } @Test From d1eccdd41c34592ae09892a90bc4580109b7f886 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 16:56:46 +0300 Subject: [PATCH 15/19] minor tweaks --- .../main/java/org/nd4j/jita/workspace/CudaWorkspace.java | 7 ++++++- .../test/java/org/nd4j/linalg/memory/AccountingTests.java | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index c50f6f38d090..879fcae16466 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -218,7 +218,9 @@ public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType type, b //pointer.setLeaked(true); pointer.isLeaked(); - externalAllocations.add(new PointersPair(null, pointer)); + val pp = new PointersPair(null, pointer); + pp.setRequiredMemory(requiredMemory); + externalAllocations.add(pp); return pointer; } else { @@ -352,6 +354,9 @@ protected void clearExternalAllocations() { if (isDebug.get()) log.info("deleting external device allocation... "); + + val sizez = pair.getRequiredMemory(); + AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); } } } catch (Exception e) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 52d1a40f7dbe..82e9a310f246 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -58,7 +58,7 @@ public void testDetached_2() { val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); - val array = Nd4j.createFromArray(1, 2, 3, 4, 5); + val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7); assertEquals(DataType.INT, array.dataType()); val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); From f822e1a47b57c25f93c43b1169158ad382051909 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 18:17:49 +0300 Subject: [PATCH 16/19] important pico fix --- .../main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java index b200b9d904c7..1526576df9bf 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java @@ -163,7 +163,7 @@ public void free(AllocationPoint point) { // log.info("Deallocating {} bytes on [DEVICE]", reqMem); NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), reqMem); + AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem); long result = nativeOps.freeDevice(point.getPointers().getDevicePointer(), new CudaPointer(0)); if (result == 0) From 0aeef0aa942aae2737c9d7e2792d35a5eea6fe1a Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 20:19:12 +0300 Subject: [PATCH 17/19] cpu workspace tracking --- .../cpu/nativecpu/CpuMemoryManager.java | 5 +++-- .../nativecpu/DirectShapeInfoProvider.java | 6 ++++-- .../cpu/nativecpu/workspace/CpuWorkspace.java | 11 +++++++++-- nd4j/nd4j-backends/nd4j-tests/pom.xml | 16 ++++++++-------- .../nd4j/linalg/memory/AccountingTests.java | 19 +++++++++++++++++++ 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java index adcd92f4d25b..3a168a240a40 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java @@ -19,6 +19,8 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -50,7 +52,6 @@ public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) { //log.info("Allocating {} bytes at MemoryManager", bytes); - if (initialize) Pointer.memset(ptr, 0, bytes); @@ -108,6 +109,6 @@ public Map getBandwidthUse() { @Override public long allocatedMemory(Integer deviceId) { - return Pointer.totalBytes(); + return Pointer.totalBytes() + AllocationsTracker.getInstance().bytesOnDevice(deviceId); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java index 6a58218324c7..ecaf830c5071 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java @@ -18,6 +18,8 @@ import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.AllocationsTracker; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.primitives.Pair; @@ -60,8 +62,8 @@ public Pair createShapeInformation(long[] shape, long[] stri Pair buffer = super.createShapeInformation(shape, stride, elementWiseStride, order, extras); longCache.put(descriptor, buffer); - bytes.addAndGet(buffer.getFirst().length() * 4 * 2); - + bytes.addAndGet(buffer.getFirst().length() * 8 * 2); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT,0, buffer.getFirst().length() * 8 * 2); return buffer; } else return longCache.get(descriptor); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java index 41c8798859af..a3f437087248 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java @@ -18,10 +18,13 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; +import org.nd4j.linalg.api.memory.AllocationsTracker; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import org.nd4j.linalg.api.memory.enums.AllocationKind; import org.nd4j.linalg.api.memory.enums.LocationPolicy; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.pointers.PagedPointer; @@ -68,6 +71,7 @@ protected void init() { log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get()); workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true))); + AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, 0, currentSize.get() + SAFETY_OFFSET); } } else if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { long flen = tempFile.length(); @@ -129,7 +133,7 @@ public synchronized void destroyWorkspace(boolean extended) { if (isDebug.get()) log.info("Destroying workspace..."); - currentSize.set(0); + val sizez = currentSize.getAndSet(0); hostOffset.set(0); deviceOffset.set(0); @@ -139,8 +143,11 @@ public synchronized void destroyWorkspace(boolean extended) { clearPinnedAllocations(extended); if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.RAM) { - if (workspace.getHostPointer() != null) + if (workspace.getHostPointer() != null) { NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); + + AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, 0, sizez); + } } else if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { if (workspace.getHostPointer() != null) NativeOpsHolder.getInstance().getDeviceNativeOps().munmapFile(null, mmap, tempFile.length()); diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index a52c45bd2dc5..220aafa69b35 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -103,20 +103,20 @@ ${project.version} - org.nd4j - nd4j-cuda-10.0 + nd4j-native ${project.version} + ch.qos.logback logback-classic diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 82e9a310f246..5c72800951cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -120,6 +120,25 @@ public void testWorkspaceAccounting_2() { assertTrue(after < middle1); } + @Test + public void testManualDeallocation_1() { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + val array = Nd4j.createFromArray(new byte[] {1, 2, 3}); + + val middle = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + array.close(); + + val after = Nd4j.getMemoryManager().allocatedMemory(deviceId); + + assertTrue(middle > before); + + // <= here just because possible cache activation + assertTrue(after <= middle); + } + @Test public void testTracker_1() { val tracker = new DeviceAllocationsTracker(); From 1720d445388fcccf0d11459f792e7eaf9eab3791 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Dec 2018 20:28:28 +0300 Subject: [PATCH 18/19] separate tracking for workspace allocations --- .../main/java/org/nd4j/jita/memory/CudaMemoryManager.java | 2 +- .../java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java | 2 +- .../nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java | 4 ++-- .../org/nd4j/linalg/api/memory/enums/AllocationKind.java | 5 +++++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index ecd1e6c575a9..2b0cb6cfe7e1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -267,6 +267,6 @@ public Map getBandwidthUse() { @Override public long allocatedMemory(Integer deviceId) { - return AllocationsTracker.getInstance().bytesOnDevice(deviceId); + return AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.GENERAL, deviceId) + AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.WORKSPACE, deviceId); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java index 3a168a240a40..8af56286d4ad 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuMemoryManager.java @@ -109,6 +109,6 @@ public Map getBandwidthUse() { @Override public long allocatedMemory(Integer deviceId) { - return Pointer.totalBytes() + AllocationsTracker.getInstance().bytesOnDevice(deviceId); + return Pointer.totalBytes() + AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.GENERAL, deviceId) + AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.WORKSPACE, deviceId); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java index a3f437087248..3e5eb1d49187 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java @@ -71,7 +71,7 @@ protected void init() { log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get()); workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true))); - AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, 0, currentSize.get() + SAFETY_OFFSET); + AllocationsTracker.getInstance().markAllocated(AllocationKind.WORKSPACE, 0, currentSize.get() + SAFETY_OFFSET); } } else if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { long flen = tempFile.length(); @@ -146,7 +146,7 @@ public synchronized void destroyWorkspace(boolean extended) { if (workspace.getHostPointer() != null) { NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, 0, sizez); + AllocationsTracker.getInstance().markReleased(AllocationKind.WORKSPACE, 0, sizez); } } else if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { if (workspace.getHostPointer() != null) diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java index abb6235896ba..a0326f1b7aab 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java @@ -30,4 +30,9 @@ public enum AllocationKind { * Allocations that will be never released, and reused during session */ CONSTANT, + + /** + * Allocations for workspaces + */ + WORKSPACE, } From bd60ee6640a303b7d14dcaac3ef4b83bbc721d90 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 19 Dec 2018 09:06:00 +0300 Subject: [PATCH 19/19] temporary ignored test pack --- .../src/test/java/org/nd4j/linalg/memory/AccountingTests.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 5c72800951cb..0add48687454 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -18,6 +18,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -38,6 +39,7 @@ * @author raver119@gmail.com */ @Slf4j +@Ignore @RunWith(Parameterized.class) public class AccountingTests extends BaseNd4jTest { public AccountingTests(Nd4jBackend backend) {