-
Notifications
You must be signed in to change notification settings - Fork 5
/
curand.go
245 lines (209 loc) · 8.51 KB
/
curand.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package curand
/*
#include <curand.h>
*/
import "C"
import (
"runtime"
"github.com/dereklstinson/gocudnn/gocu"
"github.com/dereklstinson/cutil"
)
// Generator is a cuRAND random number generator that writes its output to
// device memory.
type Generator struct {
	// w, when non-nil, is the worker every cuRAND call on this generator is
	// routed through.
	w *gocu.Worker
	// generator is the underlying cuRAND handle.
	generator C.curandGenerator_t
	// gentype is the generator algorithm this handle was created with.
	gentype RngType
	// gogc reports that the handle is released by a runtime finalizer, in
	// which case Destroy does nothing.
	gogc bool
}
// CreateGeneratorEx creates a Generator of type gentype whose cuRAND calls are
// all executed through the worker w, including creation of the handle itself.
// If w is nil it behaves exactly like CreateGenerator.
//
// It panics if cuRAND fails to create the generator. On success a finalizer is
// registered that destroys the handle once the Generator is unreachable.
func CreateGeneratorEx(w *gocu.Worker, gentype RngType) *Generator {
	if w == nil {
		return CreateGenerator(gentype)
	}
	g := &Generator{
		w:       w,
		gentype: gentype,
		gogc:    true,
	}
	create := func() error {
		status := C.curandCreateGenerator(&g.generator, g.gentype.c())
		if err := curandstatus(status).error("NewGenerator-create"); err != nil {
			return err
		}
		// Only attach the finalizer once there is a real handle to release.
		runtime.SetFinalizer(g, curandDestroyGenerator)
		return nil
	}
	if err := w.Work(create); err != nil {
		panic(err)
	}
	return g
}
// CreateGenerator creates a Generator of type gentype whose cuRAND calls are
// made directly on the calling goroutine (no worker).
//
// It panics if cuRAND fails to create the generator. On success a finalizer is
// registered that destroys the handle once the Generator is unreachable.
func CreateGenerator(gentype RngType) *Generator {
	var handle C.curandGenerator_t
	status := C.curandCreateGenerator(&handle, gentype.c())
	if err := curandstatus(status).error("NewGenerator-create"); err != nil {
		panic(err)
	}
	g := &Generator{gogc: true, gentype: gentype, generator: handle}
	runtime.SetFinalizer(g, curandDestroyGenerator)
	return g
}
// curandDestroyGenerator releases g's cuRAND handle. Both constructors install
// it as g's runtime finalizer; when run that way its error result is ignored.
func curandDestroyGenerator(g *Generator) error {
	status := C.curandDestroyGenerator(g.generator)
	return curandstatus(status).error("curandDestroyGenerator")
}
// Destroy explicitly releases the cuRAND handle. Generators built by
// CreateGenerator and CreateGeneratorEx have gogc set, so for them Destroy is
// a no-op that returns nil and the handle is left for the finalizer to free.
func (c *Generator) Destroy() error {
	if !c.gogc {
		return curandstatus(C.curandDestroyGenerator(c.generator)).error("(c *Generator) Destroy()")
	}
	return nil
}
// SetStream sets the CUDA stream the generator uses for subsequent generation
// calls. A nil stream selects the default (NULL) stream rather than panicking;
// curandSetStream accepts NULL for exactly that purpose. When the Generator
// was created with a worker, the call is made through it.
func (c *Generator) SetStream(stream gocu.Streamer) error {
	set := func() error {
		var s C.cudaStream_t // zero value is the NULL (default) stream
		if stream != nil {
			s = C.cudaStream_t(stream.Ptr())
		}
		return curandstatus(C.curandSetStream(c.generator, s)).error("(c *Generator) SetStream")
	}
	if c.w != nil {
		return c.w.Work(set)
	}
	return set()
}
// SetPsuedoSeed sets the seed of a pseudorandom cuRAND generator. The method
// keeps its original spelling for compatibility. When the Generator has a
// worker, the call is executed through it.
func (c *Generator) SetPsuedoSeed(seed uint64) error {
	apply := func() error {
		status := C.curandSetPseudoRandomGeneratorSeed(c.generator, C.ulonglong(seed))
		return curandstatus(status).error("(c *Generator) SetPsuedoSeed")
	}
	if c.w == nil {
		return apply()
	}
	return c.w.Work(apply)
}
// Uint fills mem with random 32-bit unsigned integers using curandGenerate.
//
// sizeinbytes is the size of mem in bytes. curandGenerate takes the number of
// 32-bit elements to produce, not a byte count. Passing the byte count
// directly would request four times too many values and write past the end of
// mem. Trailing bytes that do not form a whole element are left untouched.
//
// From the cuRAND documentation: curandGenerate() generates pseudo- or
// quasirandom bits for the XORWOW, MRG32k3a, MTGP32, MT19937, Philox_4x32_10
// and SOBOL32 generators. Each output element is a 32-bit unsigned int where
// all bits are random. It returns an error for SOBOL64 generators; use Uint64
// (curandGenerateLongLong) for those.
func (c *Generator) Uint(mem cutil.Mem, sizeinbytes uint) error {
	const elemSize = 4 // sizeof(unsigned int)
	n := C.size_t(sizeinbytes / elemSize)
	gen := func() error {
		return curandstatus(C.curandGenerate(c.generator, (*C.uint)(mem.Ptr()), n)).error("(c *Generator) Uint")
	}
	if c.w != nil {
		return c.w.Work(gen)
	}
	return gen()
}
// Uint64 fills mem with random 64-bit unsigned integers using
// curandGenerateLongLong.
//
// sizeinbytes is the size of mem in bytes. curandGenerateLongLong takes the
// number of 64-bit elements to produce, so the byte count is divided by 8.
// Trailing bytes that do not form a whole element are left untouched.
//
// Per the cuRAND documentation, curandGenerateLongLong() is intended for the
// 64-bit quasirandom generators (SOBOL64 / scrambled SOBOL64). Each output
// element is a 64-bit unsigned long long with all bits random, and cuRAND
// reports an error when it is used with a 32-bit generator.
func (c *Generator) Uint64(mem cutil.Mem, sizeinbytes uint) error {
	const elemSize = 8 // sizeof(unsigned long long)
	n := C.size_t(sizeinbytes / elemSize)
	gen := func() error {
		return curandstatus(C.curandGenerateLongLong(c.generator, (*C.ulonglong)(mem.Ptr()), n)).error("(c *Generator) Uint64")
	}
	if c.w != nil {
		return c.w.Work(gen)
	}
	return gen()
}
// UniformFloat32 fills mem with uniformly distributed float32 values using
// curandGenerateUniform.
//
// sizeinbytes is the size of mem in bytes. curandGenerateUniform takes the
// number of float elements to produce, so the byte count is divided by 4.
// Trailing bytes that do not form a whole element are left untouched.
//
// From the cuRAND documentation: values lie between 0.0 and 1.0, where 0.0 is
// excluded and 1.0 is included.
func (c *Generator) UniformFloat32(mem cutil.Mem, sizeinbytes uint) error {
	const elemSize = 4 // sizeof(float)
	n := C.size_t(sizeinbytes / elemSize)
	gen := func() error {
		return curandstatus(C.curandGenerateUniform(c.generator, (*C.float)(mem.Ptr()), n)).error("(c *Generator) UniformFloat32")
	}
	if c.w != nil {
		return c.w.Work(gen)
	}
	return gen()
}
// NormalFloat32 fills mem with normally distributed float32 values with the
// given mean and standard deviation, using curandGenerateNormal.
//
// sizeinbytes is the size of mem in bytes. curandGenerateNormal takes the
// number of float elements to produce, so the byte count is divided by 4.
// Trailing bytes that do not form a whole element are left untouched. Per the
// cuRAND documentation, pseudorandom generators require an even element count;
// otherwise cuRAND returns an error, which is passed back to the caller.
func (c *Generator) NormalFloat32(mem cutil.Mem, sizeinbytes uint, mean, std float32) error {
	const elemSize = 4 // sizeof(float)
	n := C.size_t(sizeinbytes / elemSize)
	gen := func() error {
		return curandstatus(C.curandGenerateNormal(c.generator, (*C.float)(mem.Ptr()), n, C.float(mean), C.float(std))).error("(c *Generator) NormalFloat32")
	}
	if c.w != nil {
		return c.w.Work(gen)
	}
	return gen()
}
/*
Generator Flags
*/
// RngType identifies a cuRAND generator algorithm. It is a Go-side alias of
// the C enum curandRngType_t.
type RngType C.curandRngType_t

// c converts rng back to the C enum type expected by cuRAND calls.
func (rng RngType) c() C.curandRngType_t { return C.curandRngType_t(rng) }
// Test sets rng to CURAND_RNG_TEST and returns the new value.
func (rng *RngType) Test() RngType {
	*rng = RngType(C.CURAND_RNG_TEST)
	return *rng
}

// PseudoDefault sets rng to CURAND_RNG_PSEUDO_DEFAULT and returns the new value.
func (rng *RngType) PseudoDefault() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_DEFAULT)
	return *rng
}

// PseudoXORWOW sets rng to CURAND_RNG_PSEUDO_XORWOW and returns the new value.
func (rng *RngType) PseudoXORWOW() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_XORWOW)
	return *rng
}

// PseudoMRG32K3A sets rng to CURAND_RNG_PSEUDO_MRG32K3A and returns the new value.
func (rng *RngType) PseudoMRG32K3A() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_MRG32K3A)
	return *rng
}

// PseudoMTGP32 sets rng to CURAND_RNG_PSEUDO_MTGP32 and returns the new value.
func (rng *RngType) PseudoMTGP32() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_MTGP32)
	return *rng
}

// PseudoMT19937 sets rng to CURAND_RNG_PSEUDO_MT19937 and returns the new value.
func (rng *RngType) PseudoMT19937() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_MT19937)
	return *rng
}

// PseudoPhilox43210 sets rng to CURAND_RNG_PSEUDO_PHILOX4_32_10 and returns
// the new value.
func (rng *RngType) PseudoPhilox43210() RngType {
	*rng = RngType(C.CURAND_RNG_PSEUDO_PHILOX4_32_10)
	return *rng
}

// QuasiDefault sets rng to CURAND_RNG_QUASI_DEFAULT and returns the new value.
func (rng *RngType) QuasiDefault() RngType {
	*rng = RngType(C.CURAND_RNG_QUASI_DEFAULT)
	return *rng
}

// QuasiSOBOL32 sets rng to CURAND_RNG_QUASI_SOBOL32 and returns the new value.
func (rng *RngType) QuasiSOBOL32() RngType {
	*rng = RngType(C.CURAND_RNG_QUASI_SOBOL32)
	return *rng
}

// QuasiScrambledSOBOL32 sets rng to CURAND_RNG_QUASI_SCRAMBLED_SOBOL32 and
// returns the new value.
func (rng *RngType) QuasiScrambledSOBOL32() RngType {
	*rng = RngType(C.CURAND_RNG_QUASI_SCRAMBLED_SOBOL32)
	return *rng
}

// QuasiSOBOL64 sets rng to CURAND_RNG_QUASI_SOBOL64 and returns the new value.
func (rng *RngType) QuasiSOBOL64() RngType {
	*rng = RngType(C.CURAND_RNG_QUASI_SOBOL64)
	return *rng
}

// QuasiScrambledSOBOL64 sets rng to CURAND_RNG_QUASI_SCRAMBLED_SOBOL64 and
// returns the new value.
func (rng *RngType) QuasiScrambledSOBOL64() RngType {
	*rng = RngType(C.CURAND_RNG_QUASI_SCRAMBLED_SOBOL64)
	return *rng
}
// String returns the name of rng prefixed with "RngType: ". Every generator
// type that has a setter method on RngType is recognized, including Test and
// PseudoMT19937, which were previously reported as unsupported. Any other
// value yields "RngType: Unsupported Type".
func (rng RngType) String() string {
	var s string
	// Compare against the enum constants directly rather than calling the
	// pointer setter methods on a scratch copy for every case.
	switch rng {
	case RngType(C.CURAND_RNG_TEST):
		s = "Test"
	case RngType(C.CURAND_RNG_PSEUDO_DEFAULT):
		s = "PseudoDefault"
	case RngType(C.CURAND_RNG_PSEUDO_MRG32K3A):
		s = "PseudoMRG32K3A"
	case RngType(C.CURAND_RNG_PSEUDO_MTGP32):
		s = "PseudoMTGP32"
	case RngType(C.CURAND_RNG_PSEUDO_MT19937):
		s = "PseudoMT19937"
	case RngType(C.CURAND_RNG_PSEUDO_PHILOX4_32_10):
		s = "PseudoPhilox43210"
	case RngType(C.CURAND_RNG_PSEUDO_XORWOW):
		s = "PseudoXORWOW"
	case RngType(C.CURAND_RNG_QUASI_DEFAULT):
		s = "QuasiDefault"
	case RngType(C.CURAND_RNG_QUASI_SOBOL32):
		s = "QuasiSOBOL32"
	case RngType(C.CURAND_RNG_QUASI_SOBOL64):
		s = "QuasiSOBOL64"
	case RngType(C.CURAND_RNG_QUASI_SCRAMBLED_SOBOL32):
		s = "QuasiScrambledSOBOL32"
	case RngType(C.CURAND_RNG_QUASI_SCRAMBLED_SOBOL64):
		s = "QuasiScrambledSOBOL64"
	default:
		s = "Unsupported Type"
	}
	return "RngType: " + s
}