-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
orthogonal.py
66 lines (54 loc) · 2.76 KB
/
orthogonal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy
from chainer import backend
from chainer import initializer
from chainer import utils
# Original code forked from MIT licensed keras project
# https://github.com/fchollet/keras/blob/master/keras/initializations.py
class Orthogonal(initializer.Initializer):

    """Initializes an array with a scaled (semi-)orthogonal matrix.

    The array to be initialized is viewed as a matrix whose rows are indexed
    by its first axis; when its ``ndim`` is more than 2, all remaining axes
    are flattened into the columns.

    The procedure is:

    1. Draw a matrix of that flattened shape whose elements are independent
       samples from the standard Gaussian distribution.
    2. Apply QR decomposition to the transpose of that matrix.
    3. Flip the signs of the columns of Q so that the diagonal of the
       triangular factor R is non-negative, which makes the decomposition
       (almost surely) unique (see e.g. Edelman & Rao,
       https://web.eecs.umich.edu/~rajnrao/Acta05rmt.pdf).
    4. Write the transposed Q, multiplied by the constant ``scale``, into
       the array.

    A 0-dimensional array is simply set to ``scale`` or ``-scale``, chosen
    at random.

    The number of vectors in the orthogonal system (the first element of the
    shape of the array) must be equal to or smaller than the dimension of
    each vector (the product of the remaining elements of the shape);
    otherwise :class:`ValueError` is raised. An empty array with at least
    one axis is rejected with :class:`ValueError` as well.

    Attributes:
        scale (float): A constant to be multiplied by.
        dtype: Data type specifier. When it is not ``None``, the array
            passed to the initializer is asserted to have this dtype.

    Reference: Saxe et al., https://arxiv.org/abs/1312.6120

    """

    def __init__(self, scale=1.1, dtype=None):
        self.scale = scale
        super(Orthogonal, self).__init__(dtype)

    # TODO(Kenta Oono)
    # How do we treat overcomplete base-system case?
    def __call__(self, array):
        """Overwrites ``array`` in place with a scaled orthogonal system.

        Args:
            array: Array to be initialized. Its contents are replaced.

        Raises:
            ValueError: If ``array`` has at least one axis but no elements,
                or if its first axis is longer than the product of its
                remaining axes.

        """
        expected_dtype = self.dtype
        if expected_dtype is not None:
            assert array.dtype == expected_dtype
        device = backend.get_device_from_array(array)

        if not array.shape:
            # 0-dim case: a random sign (-1 or +1) times the scale.
            sign = 2 * numpy.random.randint(2) - 1
            array[...] = self.scale * sign
            return

        if not array.size:
            raise ValueError('Array to be initialized must be non-empty.')

        n_vectors = array.shape[0]
        # utils.size_of_shape is used instead of numpy.prod because
        # numpy.prod returns float value when the argument is empty.
        n_dims = utils.size_of_shape(array.shape[1:])
        if n_vectors > n_dims:
            raise ValueError('Cannot make orthogonal system because'
                             ' # of vectors ({}) is larger than'
                             ' that of dimensions ({})'.format(
                                 n_vectors, n_dims))

        gaussian = numpy.random.normal(size=(n_vectors, n_dims))
        # The decomposition is always computed with NumPy because
        # cupy.linalg.qr requires cusolver in CUDA 8+.
        basis, upper = numpy.linalg.qr(gaussian.T)
        # Scale every column of Q and give it the sign of the matching
        # diagonal entry of R, i.e. normalize R to a non-negative diagonal.
        signed_scale = numpy.copysign(self.scale, numpy.diag(upper))
        basis *= signed_scale
        # Convert the result with the array module of the target's device
        # before copying it into the array.
        array[...] = device.xp.asarray(basis.T.reshape(array.shape))