Permalink
Browse files

Able to call kernels with any number of arguments.

  • Loading branch information...
1 parent 8376803 commit 93bfcbcb3c288683a4bd03e2910c4f9617cc366a Jitu Das committed Mar 26, 2012
Showing with 91 additions and 12 deletions.
  1. +46 −7 src/function.cpp
  2. +5 −4 src/function.hpp
  3. +26 −0 src/mem.cpp
  4. +3 −0 src/mem.hpp
  5. +11 −1 test/test.js
View
@@ -1,3 +1,6 @@
+#include <node_buffer.h>
+#include <cstring>
+#include <cstdio>
#include "function.hpp"
#include "mem.hpp"
@@ -7,14 +10,16 @@ Persistent<FunctionTemplate> NodeCuda::Function::constructor_template;
// Registers the CudaFunction class template and the module-level
// addToParamBuffer helper on the addon's export object `target`.
void NodeCuda::Function::Initialize(Handle<Object> target) {
HandleScope scope;
-
+
// Build the template from Function::New and keep it in a persistent handle
// so it stays usable after this HandleScope is gone.
Local<FunctionTemplate> t = FunctionTemplate::New(NodeCuda::Function::New);
constructor_template = Persistent<FunctionTemplate>::New(t);
// One internal field is reserved on each instance.
constructor_template->InstanceTemplate()->SetInternalFieldCount(1);
constructor_template->SetClassName(String::NewSymbol("CudaFunction"));
+
// addToParamBuffer is set on `target` itself; launch is set on the prototype.
+ NODE_SET_METHOD(target, "addToParamBuffer", NodeCuda::Function::AddToParamBuffer);
NODE_SET_PROTOTYPE_METHOD(constructor_template, "launch", NodeCuda::Function::LaunchKernel);
-
+
// Function objects can only be created by cuModuleGetFunction
}
@@ -35,19 +40,53 @@ Handle<Value> NodeCuda::Function::LaunchKernel(const Arguments& args) {
unsigned int gridDimX = gridDim->Get(0)->Uint32Value();
unsigned int gridDimY = gridDim->Get(1)->Uint32Value();
unsigned int gridDimZ = gridDim->Get(2)->Uint32Value();
-
+
Local<Array> blockDim = Local<Array>::Cast(args[1]);
unsigned int blockDimX = blockDim->Get(0)->Uint32Value();
unsigned int blockDimY = blockDim->Get(1)->Uint32Value();
unsigned int blockDimZ = blockDim->Get(2)->Uint32Value();
-
- Mem *mem = ObjectWrap::Unwrap<Mem>(args[2]->ToObject());
- void *cuArgs[] = { &mem->m_devicePtr };
+ Local<Object> buf = args[2]->ToObject();
+ char *pbuffer = Buffer::Data(buf);
+ size_t bufferSize = args[3]->IntegerValue();
+
+ void *cuExtra[] = {
+ CU_LAUNCH_PARAM_BUFFER_POINTER, pbuffer,
+ CU_LAUNCH_PARAM_BUFFER_SIZE, &bufferSize,
+ CU_LAUNCH_PARAM_END
+ };
CUresult error = cuLaunchKernel(pfunction->m_function,
gridDimX, gridDimY, gridDimZ,
blockDimX, blockDimY, blockDimZ,
- 0, 0, cuArgs, NULL);
+ 0, 0, NULL, cuExtra);
+
return scope.Close(Number::New(error));
}
+
+// From the NVIDIA CUDA C Programming Guide
+#define ALIGN_UP(offset, alignment) \
+ (((offset) + (alignment) - 1) & ~((alignment) - 1))
+
+/**
+ * Appends the bytes of one argument to a parameter buffer.
+ *
+ * JS signature: addToParamBuffer(dst, offset, src, alignment) -> newOffset
+ *   args[0] dst       - Buffer that receives the bytes
+ *   args[1] offset    - number of bytes of dst already in use
+ *   args[2] src       - Buffer whose entire contents are appended
+ *   args[3] alignment - alignment for the appended bytes; must be a
+ *                       non-zero power of two (ALIGN_UP relies on this)
+ *
+ * The offset is first rounded up to `alignment`, then src is copied to
+ * dst at that position. Returns the offset just past the copied bytes.
+ *
+ * Throws TypeError if dst or src is not a Buffer, and RangeError if the
+ * offset is negative, the alignment is invalid, or the data would not
+ * fit in dst. Nothing is written to dst in any of these cases.
+ */
+Handle<Value> NodeCuda::Function::AddToParamBuffer(const Arguments& args) {
+  HandleScope scope;
+
+  // Buffer::Data/Length on a non-Buffer object would read garbage.
+  if (!Buffer::HasInstance(args[0]) || !Buffer::HasInstance(args[2])) {
+    return ThrowException(Exception::TypeError(
+        String::New("addToParamBuffer: destination and source must be Buffers")));
+  }
+
+  Local<Object> dstbuf = args[0]->ToObject();
+  char *dst = Buffer::Data(dstbuf);
+  const size_t dstlen = Buffer::Length(dstbuf);
+
+  Local<Object> srcbuf = args[2]->ToObject();
+  const char *src = Buffer::Data(srcbuf);
+  const size_t srclen = Buffer::Length(srcbuf);
+
+  // IntegerValue() is signed; reject negatives before converting to size_t
+  // so they cannot wrap around to huge offsets.
+  const int64_t offsetArg = args[1]->IntegerValue();
+  const int64_t alignmentArg = args[3]->IntegerValue();
+  if (offsetArg < 0) {
+    return ThrowException(Exception::RangeError(
+        String::New("addToParamBuffer: offset must not be negative")));
+  }
+  if (alignmentArg <= 0 || (alignmentArg & (alignmentArg - 1)) != 0) {
+    return ThrowException(Exception::RangeError(
+        String::New("addToParamBuffer: alignment must be a power of two")));
+  }
+
+  const size_t alignment = static_cast<size_t>(alignmentArg);
+  size_t bufferSize = static_cast<size_t>(offsetArg);
+
+  // Check the unaligned offset first so that ALIGN_UP below cannot overflow
+  // for any destination Buffer that can actually exist.
+  if (bufferSize > dstlen) {
+    return ThrowException(Exception::RangeError(
+        String::New("addToParamBuffer: offset is past the end of the destination")));
+  }
+  bufferSize = ALIGN_UP(bufferSize, alignment);
+
+  // Written as two comparisons so that bufferSize + srclen cannot overflow.
+  if (srclen > dstlen || bufferSize > dstlen - srclen) {
+    return ThrowException(Exception::RangeError(
+        String::New("addToParamBuffer: argument does not fit in the destination")));
+  }
+
+  std::memcpy(dst + bufferSize, src, srclen);
+  bufferSize += srclen;
+
+  return scope.Close(Number::New(bufferSize));
+}
+
View
@@ -15,16 +15,17 @@ class Function : public ObjectWrap {
static Persistent<FunctionTemplate> constructor_template;
static Handle<Value> LaunchKernel(const Arguments& args);
-
+ static Handle<Value> AddToParamBuffer(const Arguments& args);
+
Function() : ObjectWrap(), m_function(0) {}
-
+
~Function() {}
private:
static Handle<Value> New(const Arguments& args);
-
+
CUfunction m_function;
-
+
friend Handle<Value> Module::GetFunction(const Arguments&);
};
View
@@ -1,3 +1,4 @@
+#include <cstring>
#include <node_buffer.h>
#include "mem.hpp"
@@ -16,9 +17,12 @@ void Mem::Initialize(Handle<Object> target) {
// Mem objects can only be created by allocation functions
NODE_SET_METHOD(target, "memAlloc", Mem::Alloc);
NODE_SET_METHOD(target, "memAllocPitch", Mem::AllocPitch);
+
+ constructor_template->InstanceTemplate()->SetAccessor(String::New("devicePtr"), Mem::GetDevicePtr);
NODE_SET_PROTOTYPE_METHOD(constructor_template, "free", Mem::Free);
NODE_SET_PROTOTYPE_METHOD(constructor_template, "copyHtoD", Mem::CopyHtoD);
+ NODE_SET_PROTOTYPE_METHOD(constructor_template, "copyDtoH", Mem::CopyDtoH);
}
Handle<Value> Mem::New(const Arguments& args) {
@@ -84,3 +88,25 @@ Handle<Value> Mem::CopyHtoD(const Arguments& args) {
return scope.Close(Number::New(error));
}
+/**
+ * Copies from this Mem's device pointer into a host Buffer.
+ *
+ * JS signature: mem.copyDtoH(buffer) -> CUresult code
+ *   args[0] buffer - destination Buffer; its full length is the number of
+ *                    bytes requested from cuMemcpyDtoH
+ *
+ * The CUresult of cuMemcpyDtoH is returned to JS as a Number.
+ */
+Handle<Value> Mem::CopyDtoH(const Arguments& args) {
+  HandleScope scope;
+  Mem *self = ObjectWrap::Unwrap<Mem>(args.This());
+
+  // Destination is described entirely by the Buffer passed from JS.
+  Local<Object> hostBuffer = args[0]->ToObject();
+  char *hostData = Buffer::Data(hostBuffer);
+  size_t byteCount = Buffer::Length(hostBuffer);
+
+  CUresult status = cuMemcpyDtoH(hostData, self->m_devicePtr, byteCount);
+  return scope.Close(Number::New(status));
+}
+
+/**
+ * Accessor getter for the `devicePtr` property.
+ *
+ * Returns a new Buffer of sizeof(m_devicePtr) bytes holding a byte-for-byte
+ * copy of this Mem's m_devicePtr value.
+ */
+Handle<Value> Mem::GetDevicePtr(Local<String> property, const AccessorInfo &info) {
+  HandleScope scope;
+  Mem *self = ObjectWrap::Unwrap<Mem>(info.Holder());
+
+  // Allocate a Buffer exactly as large as the pointer value.
+  Buffer *wrapper = Buffer::New(sizeof(self->m_devicePtr));
+  char *rawBytes = Buffer::Data(wrapper->handle_);
+
+  memcpy(rawBytes, &self->m_devicePtr, sizeof(self->m_devicePtr));
+  return scope.Close(wrapper->handle_);
+}
View
@@ -18,7 +18,10 @@ class Mem : public ObjectWrap {
static Handle<Value> AllocPitch(const Arguments& args);
static Handle<Value> Free(const Arguments& args);
static Handle<Value> CopyHtoD(const Arguments& args);
+ static Handle<Value> CopyDtoH(const Arguments& args);
+ static Handle<Value> GetDevicePtr(Local<String> property, const AccessorInfo &info);
+
Mem() : ObjectWrap(), m_devicePtr(0) {}
~Mem() {}
View
@@ -54,11 +54,21 @@ console.log("Loaded module:", cuModule);
var cuFunction = cuModule.getFunction("helloWorld");
console.log("Got function:", cuFunction);
+var paramBuffer = new Buffer(256);
+var paramBufferSize = 0;
+var argBuffer = new Buffer(8);
+paramBufferSize = cu.addToParamBuffer(paramBuffer, paramBufferSize, cuMem.devicePtr, 8);
+
//cuLaunchKernel
-var error = cuFunction.launch([3,1,1], [2,2,2], cuMem, 100);
+//var error = cuFunction.launch([3,1,1], [2,2,2], cuMem);
+var error = cuFunction.launch([3,1,1], [2,2,2], paramBuffer, paramBufferSize);
console.log("Launched kernel:", error);
+// cuMemcpyDtoH
+var error = cuMem.copyDtoH(buf);
+console.log("Copied buffer to host:", error);
+
//cuCtxSynchronize
var error = cuCtx.synchronize();
console.log("Context synchronize with error code: " + error);

0 comments on commit 93bfcbc

Please sign in to comment.