[WIP] CUDA Context changes (#7434)
- CUDA Context pool management tweaks to address #7431
- fix for PAIRWISE_BOOL operations extraArguments dtype
- fixes #7384
raver119 committed Apr 3, 2019
1 parent eac1e58 commit 085cee3
Showing 18 changed files with 260 additions and 98 deletions.
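
The Spark worker changes below all follow the same pattern: run the minibatch work, commit any outstanding device work, hand the thread's CUDA context back to the pool via the new MemoryManager.releaseCurrentContext() call, then return the result. A minimal sketch of that pattern, assuming a hypothetical worker class — only Nd4j.getExecutioner().commit() and Nd4j.getMemoryManager().releaseCurrentContext() are taken from this commit, everything else is illustrative:

import org.nd4j.linalg.factory.Nd4j;

public class TrainingTaskSketch {
    // Hypothetical worker task; the class and buildFinalResult() are illustrative only.
    public Object process(boolean isLast) {
        Object result = null;
        try {
            // ... forward/backward passes and parameter updates would happen here ...
            Nd4j.getExecutioner().commit();      // wait for queued device work to finish
            if (isLast)
                result = buildFinalResult();     // hypothetical stand-in for getFinalResult(...)
        } finally {
            // hand the thread-bound CUDA context back to the pool (no-op on CPU backends)
            Nd4j.getMemoryManager().releaseCurrentContext();
        }
        return result;
    }

    private Object buildFinalResult() {
        return new Object();                     // placeholder result
    }
}

Note that the diff itself releases the context inline before each return rather than in a finally block; the try/finally form above is only to keep the sketch compact.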
@@ -517,16 +517,21 @@ public INDArray getUpdaterParameters() {

//Get threshold algorithm instances from each thread, and average them - they may have state that needs
// to be averaged and persisted, to avoid starting threshold adaption from scratch
EncodingHandler mh = (EncodingHandler) accum.getHandler();
ThresholdAlgorithm taAveraged = mh.getAverageThresholdAlgorithm();
val mh = (EncodingHandler) accum.getHandler();
val taAveraged = mh.getAverageThresholdAlgorithm();

// FIXME: fill stats here
return SharedTrainingResult.builder().aggregationsCount(1).scoreSum(originalModel.score())
val result = SharedTrainingResult.builder().aggregationsCount(1).scoreSum(originalModel.score())
.updaterStateArray(updaterState).listenerMetaData(new ArrayList<>())
.listenerStaticInfo(new ArrayList<>()).listenerUpdates(new ArrayList<>())
.minibatchesPerExecutor(Collections.singletonMap(SparkUtils.getSparkExecutorId(), iteratorDataSetCount.get().get()))
.thresholdAlgorithm(taAveraged)
.build();

// releasing Context here
Nd4j.getMemoryManager().releaseCurrentContext();

return result;
} else {
// blocking call right here, all non-master threads will be blocked here
try {
@@ -16,6 +16,7 @@

package org.deeplearning4j.spark.impl.paramavg;

import lombok.val;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
@@ -206,8 +207,17 @@ public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, MultiL

Nd4j.getExecutioner().commit();

if (isLast)
return getFinalResult(network);
if (isLast) {
val result = getFinalResult(network);

// releasing Context here
Nd4j.getMemoryManager().releaseCurrentContext();

return result;
}

// releasing Context here
Nd4j.getMemoryManager().releaseCurrentContext();

return null;
}
@@ -244,8 +254,17 @@ public ParameterAveragingTrainingResult processMinibatch(MultiDataSet dataSet, C

Nd4j.getExecutioner().commit();

if (isLast)
return getFinalResult(graph);
if (isLast) {
val result = getFinalResult(graph);

// releasing Context here
Nd4j.getMemoryManager().releaseCurrentContext();

return result;
}

// releasing Context here
Nd4j.getMemoryManager().releaseCurrentContext();

return null;
}
@@ -213,4 +213,9 @@ public MemoryWorkspace scopeOutOfWorkspaces() {
return new DummyWorkspace().notifyScopeEntered();//workspace.tagOutOfScopeUse();
}
}

@Override
public void releaseCurrentContext() {
// no-op
}
}
@@ -192,4 +192,9 @@ public interface MemoryManager {
* @return
*/
long allocatedMemory(Integer deviceId);

/**
* This method releases the Context bound to the calling thread (if the current backend uses one)
*/
void releaseCurrentContext();
}
@@ -19,11 +19,23 @@
import org.nd4j.linalg.jcublas.context.CudaContext;

/**
*
* This interface describes a pool of CudaContext objects used to execute kernels
* @author raver119@gmail.com
*/
public interface ContextPool {
/**
* This method returns a CudaContext for the given device
* @param deviceId id of the device to acquire a context for
* @return a CudaContext usable on that device
*/
CudaContext acquireContextForDevice(Integer deviceId);

@Deprecated
ContextPack acquireContextPackForDevice(Integer deviceId);

/**
* This method returns a CudaContext to the pool for reuse
* @param context context to hand back to the pool
*/
void releaseContext(CudaContext context);
}
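
As a usage note, here is a minimal sketch of the acquire/use/release lifecycle this interface now describes. The runOnDevice helper is hypothetical, the import package for ContextPool is assumed to match its location in this repository, and the actual kernel launches are elided:

import org.nd4j.jita.allocator.context.ContextPool;   // package assumed
import org.nd4j.linalg.jcublas.context.CudaContext;

public class ContextPoolUsageSketch {
    // Hypothetical helper: acquire a context for a device, use it, then return it to the pool.
    public static void runOnDevice(ContextPool pool, Integer deviceId) {
        CudaContext context = pool.acquireContextForDevice(deviceId);
        try {
            // ... enqueue kernels on the context's streams here ...
            context.syncOldStream();       // wait for work queued on the main stream
        } finally {
            pool.releaseContext(context);  // hand the context back for reuse by other threads
        }
    }
}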
@@ -182,6 +182,11 @@ public CudaContext acquireContextForDevice(Integer deviceId) {
return contextsPool.get(threadId);
}

@Override
public void releaseContext(CudaContext context) {
// no-op
}

protected CudaContext createNewStream(Integer deviceId) {
log.trace("Creating new stream for thread: [{}], device: [{}]...", Thread.currentThread().getId(), deviceId);
//JCuda.cudaSetDevice(deviceId);
@@ -301,7 +306,7 @@ protected void getDeviceBuffers(CudaContext context, int deviceId) {

context.syncOldStream();

Pointer allocationPointer = nativeOps.mallocDevice(1024 * 1024, new CudaPointer(deviceId), 0);
Pointer allocationPointer = nativeOps.mallocDevice(16384 * sizeOf, new CudaPointer(deviceId), 0);
if (allocationPointer == null)
throw new IllegalStateException("Can't allocate [DEVICE] allocation buffer memory!");

@@ -313,11 +318,11 @@ protected void getDeviceBuffers(CudaContext context, int deviceId) {
context.setBufferAllocation(allocationPointer);
context.setBufferReduction(reductionPointer);

Pointer specialPointer = nativeOps.mallocDevice(1024 * 1024 * sizeOf, new CudaPointer(deviceId), 0);
Pointer specialPointer = nativeOps.mallocDevice(16384 * sizeOf, new CudaPointer(deviceId), 0);
if (specialPointer == null)
throw new IllegalStateException("Can't allocate [DEVICE] special buffer memory!");

nativeOps.memsetAsync(specialPointer, 0, 65536 * sizeOf, 0, context.getOldStream());
nativeOps.memsetAsync(specialPointer, 0, 16384 * sizeOf, 0, context.getOldStream());

context.setBufferSpecial(specialPointer);
}
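
For a rough sense of the shrinkage above, here is a back-of-the-envelope comparison of the per-context device buffers, assuming sizeOf is 8 bytes (the real value comes from surrounding code not shown in this hunk):

public class BufferSizeSketch {
    public static void main(String[] args) {
        long sizeOf = 8;                                                           // assumption, not taken from the diff
        System.out.println("old special buffer:    " + (1024L * 1024 * sizeOf));  // ~8 MiB per context
        System.out.println("new special buffer:    " + (16384L * sizeOf));        // 128 KiB per context
        System.out.println("old allocation buffer: " + (1024L * 1024));           // 1 MiB, did not scale with sizeOf
        System.out.println("new allocation buffer: " + (16384L * sizeOf));        // 128 KiB, matches the memsetAsync size
    }
}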
@@ -18,6 +18,8 @@

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import lombok.var;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
Expand All @@ -32,6 +34,7 @@
import org.nd4j.nativeblas.NativeOpsHolder;

import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -52,7 +55,8 @@ public class LimitedContextPool extends BasicContextPool {

// pool of used pools
protected Map<Long, CudaContext> acquired = new ConcurrentHashMap<>();
protected AtomicInteger currentPoolSize = new AtomicInteger(0);
//protected AtomicInteger currentPoolSize = new AtomicInteger(0);
protected List<AtomicInteger> devicePoolSizes = new ArrayList<>();
protected Map<Integer, ResourceGarbageCollectorThread> collectors = new HashMap<>();
protected Map<Integer, ReferenceQueue<Thread>> queueMap = new HashMap<>();

Expand All @@ -61,24 +65,23 @@ public LimitedContextPool() {
int perDevicePool = CudaEnvironment.getInstance().getConfiguration().getPoolSize();

for (int i = 0; i < 4; i++) {
ReferenceQueue<Thread> queue = new ReferenceQueue<>();
ResourceGarbageCollectorThread collector = new ResourceGarbageCollectorThread(i, queue);
val queue = new ReferenceQueue<Thread>();
val collector = new ResourceGarbageCollectorThread(i, queue);
collector.start();

collectors.put(i, collector);
queueMap.put(i, queue);
}

fillPoolWithResources(perDevicePool, false);
currentPoolSize.set(perDevicePool);
}

protected void addResourcesToPool(int numResources) {
int device = AtomicAllocator.getInstance().getDeviceId();

cublasHandle_t handle = createNewCublasHandle();
val handle = createNewCublasHandle();
for (int cnt = 0; cnt < numResources; cnt++) {
CudaContext context = createNewStream(device);
val context = createNewStream(device);
context.initOldStream();
getDeviceBuffers(context, device);
context.setHandle(handle);
@@ -102,11 +105,12 @@ protected synchronized void fillPoolWithResources(int numResources, boolean rest
for (Integer device : devices) {
nativeOps.setDevice(new CudaPointer(device));
pool.put(device, new LinkedBlockingQueue<CudaContext>());
devicePoolSizes.add(new AtomicInteger(numResources));

cublasHandle_t handle = createNewCublasHandle();
cusolverDnHandle_t solverHandle = createNewSolverHandle();
val handle = createNewCublasHandle();
val solverHandle = createNewSolverHandle();
for (int cnt = 0; cnt < numResources; cnt++) {
CudaContext context = createNewStream(device);
val context = createNewStream(device);
context.initOldStream();
getDeviceBuffers(context, device);
context.setHandle(handle);
@@ -116,6 +120,8 @@ protected synchronized void fillPoolWithResources(int numResources, boolean rest

pool.get(device).add(context);
}


}

if (restoreDevice) {
@@ -125,8 +131,8 @@ protected synchronized void fillPoolWithResources(int numResources, boolean rest

@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
long threadIdx = Thread.currentThread().getId();
CudaContext context = acquired.get(threadIdx);
val threadIdx = Thread.currentThread().getId();
var context = acquired.get(threadIdx);
if (context != null && deviceId == context.getDeviceId()) {
return context;
}
@@ -138,14 +144,12 @@ public CudaContext acquireContextForDevice(Integer deviceId) {
int col = RandomUtils.nextInt(0, collectors.size());
collectors.get(col);

GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col),
context, deviceId.intValue());
val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue());
context.attachReference(reference);
//Garba reference = new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);
//point.attachReference(reference);

acquired.put(threadIdx, context);
context.setDeviceId(deviceId);
context.setThreadId(threadIdx);
return context;
} else {

@@ -158,27 +162,30 @@ public CudaContext acquireContextForDevice(Integer deviceId) {
int col = RandomUtils.nextInt(0, collectors.size());
collectors.get(col);

GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(),
queueMap.get(col), context, deviceId.intValue());
val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue());
context.attachReference(reference);

acquired.put(threadIdx, context);
context.setDeviceId(deviceId);
context.setThreadId(threadIdx);
} else {
if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize()
* 3) {
addResourcesToPool(16);

// there's possible race condition, but we don't really care
currentPoolSize.addAndGet(16);
} else {
log.warn("Can't allocate new context, sleeping...");

Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(500);
} catch (Exception e) {
//
val currentPoolSize = devicePoolSizes.get(deviceId);
synchronized (currentPoolSize) {
if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize()) {
addResourcesToPool(16);

// there's a possible race condition here, but we don't really care
currentPoolSize.addAndGet(16);
log.warn("Initial pool size: {}; Current pool size: {}", CudaEnvironment.getInstance().getConfiguration().getPoolSize(), currentPoolSize.get());
} else {
log.warn("Can't allocate new context, sleeping...");

Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(500);
} catch (Exception e) {
//
}
}
}
}
@@ -192,6 +199,7 @@ public CudaContext acquireContextForDevice(Integer deviceId) {
}

@Override
@Deprecated
public ContextPack acquireContextPackForDevice(Integer deviceId) {
return new ContextPack(acquireContextForDevice(deviceId));
}
@@ -201,6 +209,16 @@ public CudaContext getContextForDevice(Integer deviceId) {
return acquireContextForDevice(deviceId);
}

@Override
public void releaseContext(CudaContext context) {
val threadIdx = context.getThreadId();
val deviceId = context.getDeviceId();

context.setThreadId(-1);
acquired.remove(threadIdx);
pool.get(deviceId).add(context);
}

private class ResourceGarbageCollectorThread extends Thread implements Runnable {
private final ReferenceQueue<Thread> queue;

@@ -216,8 +234,12 @@ public void run() {
GarbageResourceReference reference = (GarbageResourceReference) queue.poll();
if (reference != null) {
CudaContext context = reference.getContext();
Long threadId = reference.getThreadId();
int deviceId = reference.getDeviceId();
val threadId = reference.getThreadId();
val deviceId = reference.getDeviceId();

// there's a chance context was already released
if (context.getThreadId() != threadId)
continue;

pool.get(deviceId).add(context);
acquired.remove(threadId);
@@ -269,4 +269,22 @@ public Map<Integer, Long> getBandwidthUse() {
public long allocatedMemory(Integer deviceId) {
return AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.GENERAL, deviceId) + AllocationsTracker.getInstance().bytesOnDevice(AllocationKind.WORKSPACE, deviceId);
}

@Override
public void releaseCurrentContext() {
// getting the context bound to this thread
val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

if (context == null)
    return;

// we don't want any remnants of queued work below this line
context.syncOldStream();
context.syncSpecialStream();

val pool = AtomicAllocator.getInstance().getContextPool();

// push it back to the pool
pool.releaseContext(context);
}
}