Skip to content

Commit

Permalink
Optimize CAI
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Sep 28, 2019
1 parent 84932c0 commit 7bdaa81
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions lib/model/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class ConvolutionAware(initializers.Initializer):
seed: A Python integer. Used to seed the random generator.
# References
Armen Aghajanyan, https://arxiv.org/abs/1702.06295
# Adapted and fixed from:
# Adapted, fixed and optimized from:
https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/initializers/convaware.py
"""

Expand Down Expand Up @@ -153,40 +153,30 @@ def __call__(self, shape, dtype=None):
return K.variable(self.orthogonal(shape), dtype=dtype)

kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape
init = []
for _ in range(filters_size):
basis = self._create_basis(
stack_size, np.prod(kernel_fourier_shape), dtype)
basis = basis.reshape((stack_size,) + kernel_fourier_shape)

filters = [correct_ifft(x, kernel_shape) +
np.random.normal(0, self.eps_std, kernel_shape) for
x in basis]

init.append(filters)

# Format of array is now: filters, stack, row, column
init = np.array(init)
basis = self._create_basis(filters_size, stack_size, np.prod(kernel_fourier_shape), dtype)
basis = basis.reshape((filters_size, stack_size,) + kernel_fourier_shape)
randoms = np.random.normal(0, self.eps_std, basis.shape[:-2] + kernel_shape)
init = correct_ifft(basis, kernel_shape) + randoms
init = self._scale_filters(init, variance)
return K.variable(init.transpose(transpose_dimensions), dtype=dtype, name="conv_aware")

def _create_basis(self, filters, size, dtype):
def _create_basis(self, filters_size, filters, size, dtype):
if size == 1:
return np.random.normal(0.0, self.eps_std, (filters, size))

return np.random.normal(0.0, self.eps_std, (filters_size, filters, size))
nbb = filters // size + 1
lst = []
for _ in range(nbb):
var_a = np.random.normal(0.0, 1.0, (size, size))
var_a = self._symmetrize(var_a)
var_u, _, _ = np.linalg.svd(var_a)
lst.extend(var_u.T.tolist())
var_p = np.array(lst[:filters], dtype=dtype)
var_a = np.random.normal(0.0, 1.0, (filters_size, nbb, size, size))
var_a = self._symmetrize(var_a)
var_u = np.linalg.svd(var_a)[0].transpose(0, 1, 3, 2)
var_p = np.reshape(var_u, (filters_size, nbb * size, size))[:, :filters, :].astype(dtype)
return var_p

@staticmethod
def _symmetrize(var_a):
return var_a + var_a.T - np.diag(var_a.diagonal())
var_b = np.transpose(var_a, axes=(0, 1, 3, 2))
diag = var_a.diagonal(axis1=2, axis2=3)
var_c = np.array([[np.diag(arr) for arr in batch] for batch in diag])
return var_a + var_b - var_c

@staticmethod
def _scale_filters(filters, variance):
Expand Down

0 comments on commit 7bdaa81

Please sign in to comment.