forked from mumax/3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.go
52 lines (44 loc) · 1.18 KB
/
module.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package cu
// This file implements loading of CUDA ptx modules

//#include <cuda.h>
//#include <stdlib.h>
import "C"

import (
	"unsafe"
)
// Module represents a CUDA CUmodule, a reference to executable device code.
// The driver's CUmodule handle is stored as its pointer value converted to uintptr.
type Module uintptr
// Loads a compute module from file
func ModuleLoad(fname string) Module {
//fmt.Fprintln(os.Stderr, "driver.ModuleLoad", fname)
var mod C.CUmodule
err := Result(C.cuModuleLoad(&mod, C.CString(fname)))
if err != SUCCESS {
panic(err)
}
return Module(uintptr(unsafe.Pointer(mod)))
}
// Loads a compute module from string
func ModuleLoadData(image string) Module {
var mod C.CUmodule
err := Result(C.cuModuleLoadData(&mod, unsafe.Pointer(C.CString(image))))
if err != SUCCESS {
panic(err)
}
return Module(uintptr(unsafe.Pointer(mod)))
}
// Returns a Function handle.
func ModuleGetFunction(module Module, name string) Function {
var function C.CUfunction
err := Result(C.cuModuleGetFunction(
&function,
C.CUmodule(unsafe.Pointer(uintptr(module))),
C.CString(name)))
if err != SUCCESS {
panic(err)
}
return Function(uintptr(unsafe.Pointer(function)))
}
// GetFunction returns the Function handle for the device function called
// name in module m. It is shorthand for ModuleGetFunction(m, name) and, like
// that function, panics if the lookup fails.
func (m Module) GetFunction(name string) Function {
	fn := ModuleGetFunction(m, name)
	return fn
}