-
Notifications
You must be signed in to change notification settings - Fork 1
/
malloced.go
155 lines (136 loc) · 3.66 KB
/
malloced.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package nvidia
import (
"unsafe"
"github.com/dereklstinson/gocudnn/cudart"
"github.com/dereklstinson/gocudnn/cudart/crtutil"
"github.com/dereklstinson/gocudnn/gocu"
"github.com/dereklstinson/cutil"
)
// Malloced is a handle to a block of memory obtained through the CUDA
// runtime wrappers (see MallocHost and MallocGlobal), together with the
// size of that block.
type Malloced struct {
	// ptr is the start of the memory block; Ptr returns it and DPtr returns
	// its address. NOTE(review): presumably filled in by the cudart
	// allocation functions, which receive the *Malloced — confirm.
	ptr unsafe.Pointer
	// numbytes is the size of the block in bytes, as reported by SIB.
	numbytes uint
	// host is set to true by MallocHost; OffSet copies it into the views it
	// creates.
	host bool
}
// reader pairs a Malloced with a counter.
// NOTE(review): reader is not referenced anywhere in this file; counter is
// presumably a read position in bytes — confirm, or remove the type if it is
// unused elsewhere in the package.
type reader struct {
	m       *Malloced
	counter uint
}
// Worker runs a function on the caller's behalf. MallocHost and MallocGlobal
// pass their allocation call to Work and treat its return value as the
// allocation error.
// NOTE(review): judging by the allocator comments, implementations presumably
// run the function with a particular device made current and return the
// function's error — confirm against the implementations.
type Worker interface {
	Work(func() error) error
}
// defaultmemcopykind is the copy kind passed to every memcpy in this file.
// The value 4 corresponds to cudaMemcpyDefault in the CUDA runtime's
// cudaMemcpyKind enum, which lets the runtime infer the transfer direction
// from the pointer values (this requires unified virtual addressing).
const defaultmemcopykind = cudart.MemcpyKind(4) //enum of 4 is the default memcopy kind
// Ptr returns the raw pointer to the memory held by m.
// Calling it on a nil *Malloced is safe and yields a nil pointer.
func (m *Malloced) Ptr() unsafe.Pointer {
	var p unsafe.Pointer
	if m != nil {
		p = m.ptr
	}
	return p
}
// DPtr returns the address of m's internal pointer field (a pointer to the
// pointer). Writing through the result changes what Ptr reports.
// Calling it on a nil *Malloced is safe and yields nil.
func (m *Malloced) DPtr() *unsafe.Pointer {
	if m != nil {
		return &m.ptr
	}
	return nil
}
// OffSet returns a new Malloced that views the same memory as m, starting
// bybytes bytes past m's start and covering the remaining
// m.SIB()-bybytes bytes. The result shares m's memory (nothing is copied)
// and inherits m's host flag.
//
// It returns nil when m is nil, or when bybytes is not strictly less than
// m's size, because the offset would then leave no bytes to address.
func (m *Malloced) OffSet(bybytes uint) *Malloced {
	// Compare instead of subtracting: with unsigned operands,
	// m.numbytes-bybytes wraps around to a huge value when
	// bybytes > m.numbytes. That would yield a view that points past the
	// end of the allocation and claims an enormous size.
	if m == nil || bybytes >= m.numbytes {
		return nil
	}
	return &Malloced{
		// The uintptr round-trip is kept in a single expression, as the
		// unsafe.Pointer conversion rules require.
		ptr:      unsafe.Pointer(uintptr(m.ptr) + uintptr(bybytes)),
		numbytes: m.numbytes - bybytes,
		host:     m.host,
	}
}
// SIB reports the size of m in bytes.
// Calling it on a nil *Malloced is safe and reports 0.
func (m *Malloced) SIB() uint {
	var size uint
	if m != nil {
		size = m.numbytes
	}
	return size
}
// NewReadWriter wraps m in a *crtutil.ReadWriter and passes along m's full
// size in bytes.
// A nil s means copies are synchronous; a non-nil s means copies are
// asynchronous on that stream. This behavior is implemented by
// crtutil.NewReadWriter.
// m must be non-nil, because its size is read before the call.
func (m *Malloced) NewReadWriter(s gocu.Streamer) *crtutil.ReadWriter {
	size := m.numbytes
	rw := crtutil.NewReadWriter(m, size, s)
	return rw
}
// MallocHost allocates sizebytes of host memory for use by nvidia devices,
// via cudart.MallocManagedHost. The returned Malloced has its host flag set.
//
// If w is non-nil, the allocation runs inside w.Work. The worker presumably
// makes its device current, so be sure to set the device back afterwards
// if you want to use another one. If w is nil, the allocation runs on the
// calling goroutine.
//
// On error the returned *Malloced is nil.
func MallocHost(w Worker, sizebytes uint) (x *Malloced, err error) {
	x = &Malloced{numbytes: sizebytes, host: true}
	alloc := func() error {
		return cudart.MallocManagedHost(x, sizebytes)
	}
	// Both paths use the same host allocator and both return x on success.
	// The nil-worker path must not fall back to a global (device)
	// allocation, and must not discard the handle it allocated.
	if w == nil {
		err = alloc()
	} else {
		err = w.Work(alloc)
	}
	if err != nil {
		return nil, err
	}
	return x, nil
}
// copier performs memory copies with defaultmemcopykind. Its methods use
// cudart.MemcpyAsync on stream s when s is non-nil, and cudart.Memcpy
// otherwise.
type copier struct {
	// NOTE(review): async is not read by any method in this file; the
	// sync/async choice is made from s alone.
	async bool
	s     gocu.Streamer
}
// CopyHostToDevice copies sib bytes from src to dest using
// defaultmemcopykind, so the direction is not fixed by this method.
// With no stream the copy goes through cudart.Memcpy. Otherwise it is
// issued with cudart.MemcpyAsync on c's stream.
func (c copier) CopyHostToDevice(dest, src cutil.Pointer, sib uint) error {
	if c.s == nil {
		return cudart.Memcpy(dest, src, sib, defaultmemcopykind)
	}
	return cudart.MemcpyAsync(dest, src, sib, defaultmemcopykind, c.s)
}
// CopyDeviceToHost copies sib bytes from src to dest using
// defaultmemcopykind, so the direction is not fixed by this method.
// When c carries a stream the copy is issued with cudart.MemcpyAsync on it.
// Otherwise cudart.Memcpy is used.
func (c copier) CopyDeviceToHost(dest, src cutil.Pointer, sib uint) error {
	stream := c.s
	if stream == nil {
		return cudart.Memcpy(dest, src, sib, defaultmemcopykind)
	}
	return cudart.MemcpyAsync(dest, src, sib, defaultmemcopykind, stream)
}
// Sync calls Sync on c's stream and returns its error. Without a stream,
// the copier's copies went through cudart.Memcpy, so there is nothing to
// wait on and Sync returns nil.
func (c copier) Sync() error {
	if c.s == nil {
		return nil
	}
	return c.s.Sync()
}
// Memcpy copies sizeinbytes bytes from src to dest with cudart.Memcpy.
// It always passes defaultmemcopykind (CUDA's "default" kind), so the
// runtime infers the transfer direction from the pointers and callers do
// not have to specify it.
func Memcpy(dest, src cutil.Pointer, sizeinbytes uint) error {
	kind := defaultmemcopykind
	return cudart.Memcpy(dest, src, sizeinbytes, kind)
}
// SetAll fills m by calling cudart.Memset with val over all of m's bytes.
// m must be non-nil, because its size is read before the call.
// NOTE(review): if cudart.Memset wraps cudaMemset, only the low byte of val
// is written into every byte, so this is not an element-wise int32 fill.
// Confirm before relying on values other than 0.
func (m *Malloced) SetAll(val int32) error {
	n := m.numbytes
	return cudart.Memset(m, val, n)
}
// MallocGlobal allocates sizebytes of memory on the nvidia GPU via
// cudart.MallocManagedGlobal and returns it as a *Malloced.
//
// If w is non-nil, the allocation runs inside w.Work. The worker presumably
// makes its device current, so be sure to set the device back afterwards
// if you want to use another one. If w is nil, the allocation runs on the
// calling goroutine.
//
// On error the returned *Malloced is nil, whichever path was taken.
func MallocGlobal(w Worker, sizebytes uint) (x *Malloced, err error) {
	x = &Malloced{numbytes: sizebytes}
	alloc := func() error {
		return cudart.MallocManagedGlobal(x, sizebytes)
	}
	if w == nil {
		err = alloc()
	} else {
		err = w.Work(alloc)
	}
	if err != nil {
		// Never hand back a half-initialised handle alongside an error.
		return nil, err
	}
	return x, nil
}