-
-
Notifications
You must be signed in to change notification settings - Fork 776
/
generator.py
625 lines (514 loc) · 21.7 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
import atexit
import binascii
import functools
import operator
import os
import time
import numpy
import six
import cupy
from cupy import core
from cupy import cuda
from cupy.cuda import curand
from cupy.cuda import device
from cupy.random import _kernels
class RandomState(object):
    """Portable container of a pseudo-random number generator.

    An instance of this class holds the state of a random number generator.
    The state is available only on the device which has been current at the
    initialization of the instance.

    Functions of :mod:`cupy.random` use global instances of this class.
    Different instances are used for different devices. The global state for
    the current device can be obtained by the
    :func:`cupy.random.get_random_state` function.

    Args:
        seed (None or int): Seed of the random number generator. See the
            :meth:`~cupy.random.RandomState.seed` method for detail.
        method (int): Method of the random number generator. Following values
            are available::

               cupy.cuda.curand.CURAND_RNG_PSEUDO_DEFAULT
               cupy.cuda.curand.CURAND_RNG_XORWOW
               cupy.cuda.curand.CURAND_RNG_MRG32K3A
               cupy.cuda.curand.CURAND_RNG_MTGP32
               cupy.cuda.curand.CURAND_RNG_MT19937
               cupy.cuda.curand.CURAND_RNG_PHILOX4_32_10

    """

    def __init__(self, seed=None, method=curand.CURAND_RNG_PSEUDO_DEFAULT):
        self._generator = curand.createGenerator(method)
        self.seed(seed)

    def __del__(self):
        # When createGenerator raises an error, _generator is not initialized
        if hasattr(self, '_generator'):
            curand.destroyGenerator(self._generator)

    def _generate_normal(self, func, size, dtype, *args):
        """Sample from a cuRAND (log-)normal generator into a new array.

        Args:
            func: One of ``curand.generateNormal``,
                ``curand.generateNormalDouble``, ``curand.generateLogNormal``
                or ``curand.generateLogNormalDouble``.
            size: Output shape; normalized with ``core.get_size``.
            dtype: Output dtype; must match the precision of ``func``.
            *args: Distribution parameters forwarded to ``func``.

        Returns:
            cupy.ndarray: Array of shape ``size`` filled by ``func``.
        """
        # The curand functions above don't support an odd number of
        # elements, so for odd sizes one extra element is generated into a
        # flat buffer and dropped before reshaping.
        size = core.get_size(size)
        element_size = functools.reduce(operator.mul, size, 1)
        if element_size % 2 == 0:
            out = cupy.empty(size, dtype=dtype)
            func(self._generator, out.data.ptr, out.size, *args)
            return out
        else:
            out = cupy.empty((element_size + 1,), dtype=dtype)
            func(self._generator, out.data.ptr, out.size, *args)
            return out[:element_size].reshape(size)

    # NumPy compatible functions

    def binomial(self, n, p, size=None, dtype=int):
        """Returns an array of samples drawn from the binomial distribution.

        .. seealso::
            :func:`cupy.random.binomial` for full documentation,
            :meth:`numpy.random.RandomState.binomial`
        """
        n, p = cupy.asarray(n), cupy.asarray(p)
        if size is None:
            size = cupy.broadcast(n, p).shape
        y = cupy.zeros(shape=size, dtype=dtype)
        _kernels.binomial_kernel(n, p, self.rk_seed, y)
        # Advance the kernel seed by the number of samples drawn (presumably
        # so the next call does not reuse the same per-element seeds).
        # ``y.size`` is an exact integer for every shape, whereas
        # ``numpy.prod(())`` is the float ``1.0`` and would turn ``rk_seed``
        # into a floating-point value for scalar outputs.
        self.rk_seed += numpy.int64(y.size)
        return y

    def laplace(self, loc=0.0, scale=1.0, size=None, dtype=float):
        """Returns an array of samples drawn from the laplace distribution.

        .. seealso::
            :func:`cupy.random.laplace` for full documentation,
            :meth:`numpy.random.RandomState.laplace`
        """
        x = self.uniform(size=size, dtype=dtype)
        loc = cupy.asarray(loc, dtype)
        scale = cupy.asarray(scale, dtype)
        # ``x`` is both the input and the output: the uniform samples are
        # transformed in place.
        _kernels.laplace_kernel(x, loc, scale, x)
        return x

    def lognormal(self, mean=0.0, sigma=1.0, size=None, dtype=float):
        """Returns an array of samples drawn from a log normal distribution.

        .. seealso::
            :func:`cupy.random.lognormal` for full documentation,
            :meth:`numpy.random.RandomState.lognormal`
        """
        dtype = _check_and_get_dtype(dtype)
        if dtype.char == 'f':
            func = curand.generateLogNormal
        else:
            func = curand.generateLogNormalDouble
        return self._generate_normal(func, size, dtype, mean, sigma)

    def normal(self, loc=0.0, scale=1.0, size=None, dtype=float):
        """Returns an array of normally distributed samples.

        .. seealso::
            :func:`cupy.random.normal` for full documentation,
            :meth:`numpy.random.RandomState.normal`
        """
        dtype = _check_and_get_dtype(dtype)
        if dtype.char == 'f':
            func = curand.generateNormal
        else:
            func = curand.generateNormalDouble
        return self._generate_normal(func, size, dtype, loc, scale)

    def rand(self, *size, **kwarg):
        """Returns uniform random values over the interval ``[0, 1)``.

        .. seealso::
            :func:`cupy.random.rand` for full documentation,
            :meth:`numpy.random.RandomState.rand`
        """
        dtype = kwarg.pop('dtype', float)
        if kwarg:
            raise TypeError('rand() got unexpected keyword arguments %s'
                            % ', '.join(kwarg.keys()))
        return self.random_sample(size=size, dtype=dtype)

    def randn(self, *size, **kwarg):
        """Returns an array of standard normal random values.

        .. seealso::
            :func:`cupy.random.randn` for full documentation,
            :meth:`numpy.random.RandomState.randn`
        """
        dtype = kwarg.pop('dtype', float)
        if kwarg:
            raise TypeError('randn() got unexpected keyword arguments %s'
                            % ', '.join(kwarg.keys()))
        return self.normal(size=size, dtype=dtype)

    # In-place elementwise ``x = 1 - x``; used by random_sample below.
    _1m_kernel = core.ElementwiseKernel(
        '', 'T x', 'x = 1 - x', 'cupy_random_1_minus_x')

    def random_sample(self, size=None, dtype=float):
        """Returns an array of random values over the interval ``[0, 1)``.

        .. seealso::
            :func:`cupy.random.random_sample` for full documentation,
            :meth:`numpy.random.RandomState.random_sample`
        """
        dtype = _check_and_get_dtype(dtype)
        out = cupy.empty(size, dtype=dtype)
        if dtype.char == 'f':
            func = curand.generateUniform
        else:
            func = curand.generateUniformDouble
        func(self._generator, out.data.ptr, out.size)
        # cuRAND's uniform generators return values in (0, 1] (per the
        # cuRAND documentation); ``1 - x`` maps them onto [0, 1).
        RandomState._1m_kernel(out)
        return out

    def _interval(self, mx, size):
        """Sample integers independently and uniformly from ``[0, mx]``.

        Args:
            mx (int): Upper bound of the interval (inclusive). NumPy integer
                scalars are accepted as well as Python integers.
            size (None or int or tuple): Shape of the array or the scalar
                returned.

        Returns:
            int or cupy.ndarray: If ``None``, an :class:`cupy.ndarray` with
            shape ``()`` is returned.
            If ``int``, 1-D array of length size is returned.
            If ``tuple``, multi-dimensional array with shape
            ``size`` is returned.
            Currently, only 32 bit integers can be sampled.
            If 0 :math:`\\leq` ``mx`` :math:`\\leq` 0x7fffffff,
            a ``numpy.int32`` array is returned.
            If 0x80000000 :math:`\\leq` ``mx`` :math:`\\leq` 0xffffffff,
            a ``numpy.uint32`` array is returned.

        Raises:
            ValueError: If ``mx`` is negative or exceeds the uint32 range.
        """
        # Convert integral scalars (e.g. numpy.int64, which has no
        # ``bit_length``) to a Python int; non-integral values such as
        # floats raise TypeError here.
        mx = operator.index(mx)
        if size is None:
            return self._interval(mx, 1).reshape(())
        elif isinstance(size, (six.integer_types, numpy.integer)):
            size = (size, )

        if mx == 0:
            return cupy.zeros(size, dtype=numpy.int32)

        if mx < 0:
            raise ValueError(
                'mx must be non-negative (actual: {})'.format(mx))
        elif mx <= 0x7fffffff:
            dtype = numpy.int32
        elif mx <= 0xffffffff:
            dtype = numpy.uint32
        else:
            raise ValueError(
                'mx must be within uint32 range (actual: {})'.format(mx))

        # Rejection sampling: mask raw 32-bit draws down to the smallest
        # power-of-two range covering mx, then keep only values <= mx.
        mask = (1 << mx.bit_length()) - 1
        mask = cupy.array(mask, dtype=dtype)

        n = functools.reduce(operator.mul, size, 1)

        sample = cupy.empty((n,), dtype=dtype)
        n_rem = n  # The number of remaining elements to sample
        ret = None
        while n_rem > 0:
            curand.generate(
                self._generator, sample.data.ptr, sample.size)
            # Drop the samples that exceed the upper limit
            sample &= mask
            success = sample <= mx

            if ret is None:
                # If the sampling has finished in the first iteration,
                # just return the sample.
                if success.all():
                    n_rem = 0
                    ret = sample
                    break

                # Allocate the return array.
                ret = cupy.empty((n,), dtype=dtype)

            n_succ = min(n_rem, int(success.sum()))
            ret[n - n_rem:n - n_rem + n_succ] = sample[success][:n_succ]
            n_rem -= n_succ

        assert n_rem == 0

        return ret.reshape(size)

    def seed(self, seed=None):
        """Resets the state of the random number generator with a seed.

        If ``seed`` is ``None``, a 64-bit seed is read from
        :func:`os.urandom`; when that is not implemented, a time-based seed
        is used instead. A 32-bit copy of the seed is kept in ``rk_seed``
        for the custom kernels (e.g. :meth:`binomial`).

        .. seealso::
            :func:`cupy.random.seed` for full documentation,
            :meth:`numpy.random.RandomState.seed`
        """
        if seed is None:
            try:
                seed_str = binascii.hexlify(os.urandom(8))
                seed = numpy.uint64(int(seed_str, 16))
            except NotImplementedError:
                # ``time.clock`` was removed in Python 3.8; ``time.time``
                # is available on every Python version.
                seed = numpy.uint64(time.time() * 1000000)
        else:
            seed = numpy.asarray(seed).astype(numpy.uint64, casting='safe')

        curand.setPseudoRandomGeneratorSeed(self._generator, seed)
        curand.setGeneratorOffset(self._generator, 0)

        self.rk_seed = numpy.uint32(seed)

    def standard_normal(self, size=None, dtype=float):
        """Returns samples drawn from the standard normal distribution.

        .. seealso::
            :func:`cupy.random.standard_normal` for full documentation,
            :meth:`numpy.random.RandomState.standard_normal`
        """
        return self.normal(size=size, dtype=dtype)

    def tomaxint(self, size=None):
        """Draws integers between 0 and max integer inclusive.

        Args:
            size (int or tuple of ints): Output shape.

        Returns:
            cupy.ndarray: Drawn samples.

        .. seealso::
            :meth:`numpy.random.RandomState.tomaxint`
        """
        if size is None:
            size = ()
        sample = cupy.empty(size, dtype=cupy.int_)
        # cupy.random only uses int32 random generator; fill every 32-bit
        # word of each (possibly wider) element.
        size_in_int = sample.dtype.itemsize // 4
        curand.generate(
            self._generator, sample.data.ptr, sample.size * size_in_int)

        # Disable sign bit
        sample &= cupy.iinfo(cupy.int_).max
        return sample

    def uniform(self, low=0.0, high=1.0, size=None, dtype=float):
        """Returns an array of uniformly-distributed samples over an interval.

        .. seealso::
            :func:`cupy.random.uniform` for full documentation,
            :meth:`numpy.random.RandomState.uniform`
        """
        dtype = numpy.dtype(dtype)
        rand = self.random_sample(size=size, dtype=dtype)
        return dtype.type(low) + rand * dtype.type(high - low)

    def choice(self, a, size=None, replace=True, p=None):
        """Returns an array of random values from a given 1-D array.

        .. seealso::
            :func:`cupy.random.choice` for full document,
            :func:`numpy.random.choice`
        """
        if a is None:
            raise ValueError('a must be 1-dimensional or an integer')
        if isinstance(a, cupy.ndarray) and a.ndim == 0:
            raise NotImplementedError
        if isinstance(a, six.integer_types):
            a_size = a
            if a_size <= 0:
                raise ValueError('a must be greater than 0')
        else:
            a = cupy.array(a, copy=False)
            if a.ndim != 1:
                raise ValueError('a must be 1-dimensional or an integer')
            a_size = len(a)
            if a_size == 0:
                raise ValueError('a must be non-empty')

        if p is not None:
            p = cupy.array(p)
            if p.ndim != 1:
                raise ValueError('p must be 1-dimensional')
            if len(p) != a_size:
                raise ValueError('a and p must have same size')
            if not (p >= 0).all():
                raise ValueError('probabilities are not non-negative')
            p_sum = cupy.sum(p).get()
            if not numpy.allclose(p_sum, 1):
                raise ValueError('probabilities do not sum to 1')

        if size is None:
            raise NotImplementedError
        shape = size
        # ``int`` keeps the element count integral: ``numpy.prod(())`` is
        # the float ``1.0``, which is invalid as a shape or slice bound.
        size = int(numpy.prod(shape))

        if not replace and p is None:
            if a_size < size:
                raise ValueError(
                    'Cannot take a larger sample than population when '
                    '\'replace=False\'')
            if isinstance(a, six.integer_types):
                indices = cupy.arange(a, dtype='l')
            else:
                indices = a.copy()
            self.shuffle(indices)
            return indices[:size].reshape(shape)

        if not replace:
            raise NotImplementedError

        if p is not None:
            # Gumbel-max trick: argmax over log(p) plus Gumbel noise picks
            # each index with probability p.
            p = cupy.broadcast_to(p, (size, a_size))
            index = cupy.argmax(cupy.log(p) +
                                self.gumbel(size=(size, a_size)),
                                axis=1)
            if not isinstance(shape, six.integer_types):
                index = cupy.reshape(index, shape)
        else:
            index = self.randint(0, a_size, size=shape)
        # Align the dtype with NumPy
        index = index.astype(cupy.int64, copy=False)

        if isinstance(a, six.integer_types):
            return index

        if index.ndim == 0:
            return cupy.array(a[index], dtype=a.dtype)

        return a[index]

    def shuffle(self, a):
        """Returns a shuffled array.

        .. seealso::
            :func:`cupy.random.shuffle` for full document,
            :func:`numpy.random.shuffle`
        """
        if not isinstance(a, cupy.ndarray):
            raise TypeError('The array must be cupy.ndarray')

        if a.ndim == 0:
            raise TypeError('An array whose ndim is 0 is not supported')

        a[:] = a[self.permutation(len(a))]

    def permutation(self, num):
        """Returns a permuted range."""
        if not isinstance(num, six.integer_types):
            raise TypeError('The data type of argument "num" must be integer')

        sample = cupy.empty((num,), dtype=numpy.int32)
        curand.generate(self._generator, sample.data.ptr, num)
        if 128 < num <= 32 * 1024 * 1024:
            array = cupy.arange(num, dtype=numpy.int32)
            # apply sort of cache blocking
            block_size = 1 * 1024 * 1024
            # The block size above is a value determined from the L2 cache
            # size of GP100 (L2 cache size / size of int = 4MB / 4B = 1M). It
            # may be better to change the value base on the L2 cache size of
            # the GPU you use.
            # When num > block_size, cupy kernel: _cupy_permutation is to be
            # launched multiple times. However, it is observed that
            # performance will be degraded if the launch count is too many.
            # Therefore, the block size is adjusted so that launch count will
            # not exceed twelve. Note that this twelve is the value determined
            # from measurement on GP100.
            while num // block_size > 12:
                block_size *= 2
            # Build the kernel once instead of once per launch.
            kernel = _cupy_permutation()
            for j_start in range(0, num, block_size):
                j_end = j_start + block_size
                kernel(array, sample, j_start, j_end, size=num)
        else:
            # argsort is used both for small inputs (num <= 128) and when
            # num > 32M; in the latter case it is faster than the custom
            # kernel. See https://github.com/cupy/cupy/pull/603.
            array = cupy.argsort(sample)
        return array

    def gumbel(self, loc=0.0, scale=1.0, size=None, dtype=float):
        """Returns an array of samples drawn from a Gumbel distribution.

        .. seealso::
            :func:`cupy.random.gumbel` for full documentation,
            :meth:`numpy.random.RandomState.gumbel`
        """
        x = self.uniform(size=size, dtype=dtype)
        loc = cupy.asarray(loc, dtype)
        scale = cupy.asarray(scale, dtype)
        # We use `1 - x` as input of `log` method to prevent overflow.
        # It obeys numpy implementation.
        _kernels.gumbel_kernel(x, loc, scale, x)
        return x

    def randint(self, low, high=None, size=None, dtype='l'):
        """Returns a scalar or an array of integer values over ``[low, high)``.

        .. seealso::
            :func:`cupy.random.randint` for full documentation,
            :meth:`numpy.random.RandomState.randint`
        """
        if high is None:
            lo = 0
            hi = low
        else:
            lo = low
            hi = high

        if lo >= hi:
            raise ValueError('low >= high')
        info = cupy.iinfo(dtype)
        if lo < info.min:
            raise ValueError(
                'low is out of bounds for {}'.format(cupy.dtype(dtype).name))
        if hi > info.max + 1:
            raise ValueError(
                'high is out of bounds for {}'.format(cupy.dtype(dtype).name))

        # ``diff`` is the largest offset to draw; the 32-bit sampler in
        # ``_interval`` covers offsets up to 0xffffffff (== 2**32 - 1).
        diff = hi - lo - 1
        if diff > 0xffffffff:
            raise NotImplementedError(
                'Sampling from a range whose extent is larger than int32 '
                'range is currently not supported')
        x = self._interval(diff, size).astype(dtype, copy=False)
        cupy.add(x, lo, out=x)
        return x
def _cupy_permutation():
    """Build the elementwise kernel used by ``RandomState.permutation``.

    Each thread ``i`` derives a target index ``j`` from ``sample[i]`` and
    tries to swap ``array[i]`` with ``array[j]``. Only targets inside
    ``[j_start, _j_end)`` are handled by a given launch. Both slots are
    claimed by atomically exchanging them with ``-1``. On a conflict the
    thread steps ``j`` by an offset and retries, giving up after 256 tries.

    Returns:
        core.ElementwiseKernel: A fresh kernel named ``cupy_permutation``
        with no output parameters; it rewrites ``array`` in place.
    """
    in_params = ('raw int32 array, raw int32 sample, '
                 'int32 j_start, int32 _j_end')
    out_params = ''
    # CUDA source; kept verbatim.
    operation = '''
const int invalid = -1;
const int num = _ind.size();
int j = (sample[i] & 0x7fffffff) % num;
int j_end = _j_end;
if (j_end > num) j_end = num;
if (j == i || j < j_start || j >= j_end) continue;
// If a thread fails to do data swaping once, it changes j
// value using j_offset below and try data swaping again.
// This process is repeated until data swapping is succeeded.
// The j_offset is determined from the initial j
// (random number assigned to each thread) and the initial
// offset between j and i (ID of each thread).
// If a given number sequence in sample is really random,
// this j-update would not be necessary. This is work-around
// mainly to avoid potential eternal conflict when sample has
// rather synthetic number sequence.
int j_offset = ((2*j - i + num) % (num - 1)) + 1;
// A thread gives up to do data swapping if loop count exceed
// a threathod determined below. This is kind of safety
// mechanism to escape the eternal race condition, though I
// believe it never happens.
int loops = 256;
bool do_next = true;
while (do_next && loops > 0) {
// try to swap the contents of array[i] and array[j]
if (i != j) {
int val_j = atomicExch(&array[j], invalid);
if (val_j != invalid) {
int val_i = atomicExch(&array[i], invalid);
if (val_i != invalid) {
array[i] = val_j;
array[j] = val_i;
do_next = false;
// done
}
else {
// restore array[j]
array[j] = val_j;
}
}
}
j = (j + j_offset) % num;
loops--;
}
'''
    kernel_name = 'cupy_permutation'
    return core.ElementwiseKernel(in_params, out_params, operation,
                                  kernel_name)
def seed(seed=None):
    """Reseed the global random number generator of the current device.

    Only the generator that belongs to the current device is reset;
    generators of other devices keep their state.

    Args:
        seed (None or int): Seed for the random number generator. If ``None``,
            a seed is drawn from :func:`os.urandom` when available, with a
            time-based fallback otherwise. Note that this function does not
            support seeding by an integer array.
    """
    state = get_random_state()
    state.seed(seed)
# CuPy specific functions
# Maps device id -> RandomState; filled lazily by get_random_state and
# overwritten by set_random_state.
_random_states = dict()


@atexit.register
def reset_states():
    """Forget every cached per-device :class:`RandomState`.

    Registered with :mod:`atexit`, so it also runs at interpreter shutdown.
    The module-level mapping is rebound to a new empty dict, which drops this
    module's references to the cached states.
    """
    global _random_states
    _random_states = dict()
def get_random_state():
    """Return the random number generator state of the current device.

    The state is created on first use and cached per device id. A newly
    created state is seeded from the ``CUPY_SEED`` environment variable,
    falling back to ``CHAINER_SEED`` when ``CUPY_SEED`` is unset; if neither
    is set, :class:`RandomState` picks its own seed.

    Returns:
        RandomState: The state of the random number generator for the
        device.
    """
    device_id = cuda.Device().id
    state = _random_states.get(device_id, None)
    if state is not None:
        return state

    env_seed = os.getenv('CUPY_SEED')
    if env_seed is None:
        env_seed = os.getenv('CHAINER_SEED')
    if env_seed is not None:
        env_seed = numpy.uint64(int(env_seed))

    # setdefault keeps an entry stored for this device in the meantime
    # (e.g. by another thread) instead of overwriting it.
    return _random_states.setdefault(device_id, RandomState(env_seed))
def set_random_state(rs):
    """Sets the state of the random number generator for the current device.

    Args:
        rs (RandomState): Random state to set for the current device. It
            replaces any state previously stored for that device.

    Raises:
        TypeError: If ``rs`` is not a :class:`RandomState` instance.
    """
    if not isinstance(rs, RandomState):
        raise TypeError(
            'Random state must be an instance of RandomState. '
            'Actual: {}'.format(type(rs)))
    _random_states[device.get_device_id()] = rs
def _check_and_get_dtype(dtype):
    """Normalize ``dtype`` and require it to be float32 or float64.

    Args:
        dtype: Anything accepted by :class:`numpy.dtype`.

    Returns:
        numpy.dtype: The normalized dtype.

    Raises:
        TypeError: If the dtype is neither float32 nor float64.
    """
    resolved = numpy.dtype(dtype)
    if resolved.char in ('f', 'd'):
        return resolved
    raise TypeError('cupy.random only supports float32 and float64')