/
extension.go
41 lines (35 loc) · 1.13 KB
/
extension.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package cuda
import (
"fmt"
"github.com/pkg/errors"
"gorgonia.org/cu"
)
// This file contains code that allows callers to extend an Engine with
// additional CUDA modules and functions.

// LoadCUDAFunc loads a string representing a CUDA PTX file into the engine,
// giving it the universe of computing functions.
//
// Parameters:
//   - moduleName: name under which the loaded module is stored in e.m. It is
//     also the prefix of every registered function's key ("moduleName.fn").
//   - data: module contents, passed unchanged to cu.LoadData.
//   - funcs: names of functions to look up in the loaded module via
//     mod.Function. Each one is registered in e.f as "moduleName.name".
//
// The engine's CUDA context is made current (cu.SetCurrentContext) before the
// module is loaded. Engine state (e.f, e.m) is only modified after the module
// loads and every requested function resolves. On any error, e.f and e.m are
// left exactly as they were, and the wrapped error is returned.
//
// NOTE(review): loading a moduleName that is already present overwrites its
// e.m entry without unloading the previous module. Confirm whether callers
// rely on this.
func (e *Engine) LoadCUDAFunc(moduleName, data string, funcs []string) (err error) {
	if err = cu.SetCurrentContext(e.c.Context.CUDAContext()); err != nil {
		return errors.Wrapf(err, "Unable to set current context when loading module %q at device %v", moduleName, e.d)
	}

	var mod cu.Module
	if mod, err = cu.LoadData(data); err != nil {
		return errors.Wrapf(err, "Failed to load module %q data for Device %v context %x", moduleName, e.d, e.c)
	}

	// Resolve every requested function into a local map first. Previously the
	// engine's live map was mutated inside this loop, so a failure part-way
	// through left some "moduleName.*" entries registered without their
	// module being recorded in e.m.
	loaded := make(map[string]cu.Function, len(funcs))
	for _, name := range funcs {
		var fn cu.Function
		if fn, err = mod.Function(name); err != nil {
			// NOTE(review): mod is not referenced by the engine at this point
			// and is not unloaded here. Consider unloading it to avoid leaking
			// the module (TODO confirm the cu unload API).
			return errors.Wrapf(err, "Unable to get function %q in Device %v context %x", name, e.d, e.c)
		}
		loaded[fmt.Sprintf("%v.%v", moduleName, name)] = fn
	}

	// Commit: all lookups succeeded, so publish functions and module together.
	if e.f == nil {
		e.f = make(map[string]cu.Function, len(loaded))
	}
	for fqn, fn := range loaded {
		e.f[fqn] = fn
	}
	if e.m == nil {
		e.m = make(map[string]cu.Module)
	}
	e.m[moduleName] = mod
	return nil
}