Skip to content

Commit

Permalink
[Core][Streams] New API to wrap native backend streams (#525)
Browse files Browse the repository at this point in the history
* [Core][Streams] New API to wrap native backend streams

* [Core][Streams] Small changes based on review feedback

* [OpenCL][Streams] Adding polyfill entry for clRetainCommandQueue

Co-authored-by: Kris Rowe <kris.rowe@anl.gov>
  • Loading branch information
noelchalmers and kris-rowe committed Dec 6, 2021
1 parent 5f5ec0f commit 1a4c374
Show file tree
Hide file tree
Showing 23 changed files with 111 additions and 16 deletions.
16 changes: 16 additions & 0 deletions include/occa/core/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,22 @@ namespace occa {
*/
stream createStream(const occa::json &props = occa::json());

/**
* @startDoc{wrapStream}
*
* Description:
* Wrap a native backend stream object inside a [[stream]] for the device.
* The simplest example would be on a `CUDA` device, where a pointer to a cuStream_t, created via cudaStreamCreate, is passed in.
*
* > Note that automatic garbage collection is not set for wrapped stream objects.
*
* Returns:
* The wrapped [[stream]]
*
* @endDoc
*/
stream wrapStream(void* ptr, const occa::json &props = occa::json());

/**
* @startDoc{getStream}
*
Expand Down
8 changes: 8 additions & 0 deletions src/core/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ namespace occa {
return modeDevice->createStream(streamProperties(props));
}

stream device::wrapStream(void* ptr, const occa::json &props) {
assertInitialized();

occa::json streamProps = streamProperties(props);

return modeDevice->wrapStream(ptr, streamProps);
}

stream device::getStream() {
assertInitialized();
return modeDevice->currentStream;
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/core/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace occa {

// |---[ Stream ]------------------
virtual modeStream_t* createStream(const occa::json &props) = 0;
virtual modeStream_t* wrapStream(void *ptr, const occa::json &props) = 0;

virtual streamTag tagStream() = 0;
virtual void waitFor(streamTag tag) = 0;
Expand Down
6 changes: 6 additions & 0 deletions src/occa/internal/modes/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ namespace occa {
return new stream(this, props, cuStream);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
OCCA_ERROR("A nullptr was passed to cuda::device::wrapStream",nullptr != ptr);
CUstream cuStream = *static_cast<CUstream*>(ptr);
return new stream(this, props, cuStream, true);
}

occa::streamTag device::tagStream() {
CUevent cuEvent = NULL;

Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace occa {

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::json &props);
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props);

virtual streamTag tagStream();
virtual void waitFor(streamTag tag);
Expand Down
15 changes: 9 additions & 6 deletions src/occa/internal/modes/cuda/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ namespace occa {
namespace cuda {
stream::stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
CUstream cuStream_) :
CUstream cuStream_, bool isWrapped_) :
modeStream_t(modeDevice_, properties_),
cuStream(cuStream_) {}
cuStream(cuStream_),
isWrapped(isWrapped_) {}

stream::~stream() {
OCCA_CUDA_DESTRUCTOR_ERROR(
"Device: freeStream",
cuStreamDestroy(cuStream)
);
if (!isWrapped) {
OCCA_CUDA_DESTRUCTOR_ERROR(
"Device: freeStream",
cuStreamDestroy(cuStream)
);
}
}
}
}
5 changes: 4 additions & 1 deletion src/occa/internal/modes/cuda/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ namespace occa {
public:
CUstream cuStream;

bool isWrapped;

stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
CUstream cuStream_);
CUstream cuStream_,
bool isWrapped_=false);

virtual ~stream();
};
Expand Down
6 changes: 6 additions & 0 deletions src/occa/internal/modes/dpcpp/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ namespace occa
return new occa::dpcpp::stream(this, props, q);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
OCCA_ERROR("A nullptr was passed to dpcpp::device::wrapStream",nullptr != ptr);
::sycl::queue q = *static_cast<::sycl::queue*>(ptr);
return new stream(this, props, q);
}

occa::streamTag device::tagStream()
{
//@note: This creates a host event which will return immediately.
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/dpcpp/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace occa

//---[ Stream ]-------------------
virtual modeStream_t *createStream(const occa::json &props) override;
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props) override;

virtual occa::streamTag tagStream() override;
virtual void waitFor(occa::streamTag tag) override;
Expand Down
6 changes: 6 additions & 0 deletions src/occa/internal/modes/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ namespace occa {
return new stream(this, props, hipStream);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
OCCA_ERROR("A nullptr was passed to hip::device::wrapStream",nullptr != ptr);
hipStream_t hipStream = *static_cast<hipStream_t*>(ptr);
return new stream(this, props, hipStream);
}

occa::streamTag device::tagStream() {
hipEvent_t hipEvent = NULL;

Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace occa {

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::json &props);
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props);

virtual streamTag tagStream();
virtual void waitFor(streamTag tag);
Expand Down
12 changes: 8 additions & 4 deletions src/occa/internal/modes/hip/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ namespace occa {
namespace hip {
stream::stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
hipStream_t hipStream_) :
hipStream_t hipStream_,
bool isWrapped_) :
modeStream_t(modeDevice_, properties_),
hipStream(hipStream_) {}
hipStream(hipStream_),
isWrapped(isWrapped_) {}

stream::~stream() {
OCCA_HIP_ERROR("Device: freeStream",
hipStreamDestroy(hipStream));
if (!isWrapped) {
OCCA_HIP_ERROR("Device: freeStream",
hipStreamDestroy(hipStream));
}
}
}
}
5 changes: 4 additions & 1 deletion src/occa/internal/modes/hip/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ namespace occa {
public:
hipStream_t hipStream;

bool isWrapped;

stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
hipStream_t hipStream_);
hipStream_t hipStream_,
bool isWrapped_=false);

virtual ~stream();
};
Expand Down
6 changes: 6 additions & 0 deletions src/occa/internal/modes/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ namespace occa {
return new stream(this, props, metalCommandQueue);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
OCCA_ERROR("A nullptr was passed to metal::device::wrapStream",nullptr != ptr);
api::metal::commandQueue_t q = *static_cast<api::metal::commandQueue_t*>(ptr);
return new stream(this, props, q, true);
}

occa::streamTag device::tagStream() {
metal::stream &stream = (
*((metal::stream*) (currentStream.getModeStream()))
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/metal/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace occa {

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::json &props);
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props);

virtual streamTag tagStream();
virtual void waitFor(streamTag tag);
Expand Down
10 changes: 7 additions & 3 deletions src/occa/internal/modes/metal/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ namespace occa {
namespace metal {
stream::stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
api::metal::commandQueue_t metalCommandQueue_) :
api::metal::commandQueue_t metalCommandQueue_,
bool isWrapped_) :
modeStream_t(modeDevice_, properties_),
metalCommandQueue(metalCommandQueue_) {}
metalCommandQueue(metalCommandQueue_),
isWrapped(isWrapped_) {}

stream::~stream() {
metalCommandQueue.free();
if (!isWrapped) {
metalCommandQueue.free();
}
}
}
}
5 changes: 4 additions & 1 deletion src/occa/internal/modes/metal/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ namespace occa {
public:
api::metal::commandQueue_t metalCommandQueue;

bool isWrapped;

stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
api::metal::commandQueue_t metalCommandQueue_);
api::metal::commandQueue_t metalCommandQueue_,
bool isWrapped_=false);

virtual ~stream();
};
Expand Down
10 changes: 10 additions & 0 deletions src/occa/internal/modes/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ namespace occa {
return new stream(this, props, commandQueue);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
OCCA_ERROR("A nullptr was passed to opencl::device::wrapStream",nullptr != ptr);

cl_command_queue commandQueue = *static_cast<cl_command_queue*>(ptr);
OCCA_OPENCL_ERROR("Device: Retaining Command Queue",
clRetainCommandQueue(commandQueue));

return new stream(this, props, commandQueue);
}

occa::streamTag device::tagStream() {
cl_event clEvent;

Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/opencl/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace occa {

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::json &props);
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props);

virtual streamTag tagStream();
virtual void waitFor(streamTag tag);
Expand Down
4 changes: 4 additions & 0 deletions src/occa/internal/modes/opencl/polyfill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ namespace occa {
return NULL;
}

inline cl_int clRetainCommandQueue(cl_command_queue command_queue) {
return OCCA_OPENCL_IS_NOT_ENABLED;
}

inline cl_int clReleaseCommandQueue(cl_command_queue command_queue) {
return OCCA_OPENCL_IS_NOT_ENABLED;
}
Expand Down
2 changes: 2 additions & 0 deletions src/occa/internal/modes/opencl/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace occa {
public:
cl_command_queue commandQueue;

bool isWrapped;

stream(modeDevice_t *modeDevice_,
const occa::json &properties_,
cl_command_queue commandQueue_);
Expand Down
4 changes: 4 additions & 0 deletions src/occa/internal/modes/serial/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ namespace occa {
return new stream(this, props);
}

modeStream_t* device::wrapStream(void* ptr, const occa::json &props) {
return new stream(this, props);
}

occa::streamTag device::tagStream() {
return new occa::serial::streamTag(this, sys::currentTime());
}
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/modes/serial/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace occa {

//---[ Stream ]-------------------
virtual modeStream_t* createStream(const occa::json &props);
virtual modeStream_t* wrapStream(void* ptr, const occa::json &props);

virtual streamTag tagStream();
virtual void waitFor(streamTag tag);
Expand Down

0 comments on commit 1a4c374

Please sign in to comment.