Skip to content

Commit

Permalink
Merge pull request #6877 from toslunar/bp-6655-improve-spectralnormal…
Browse files Browse the repository at this point in the history
…ization-hook

[backport] Improve `link_hooks.SpectralNormalization`
  • Loading branch information
hvy committed Apr 15, 2019
2 parents 76a164b + dff3882 commit 0cf3542
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 84 deletions.
2 changes: 1 addition & 1 deletion chainer/link_hook.py
Expand Up @@ -22,7 +22,7 @@ def __repr__(self):


class _ForwardPostprocessCallbackArgs(object):
"""Callback data for LinkHook.forward_postrocess"""
"""Callback data for LinkHook.forward_postprocess"""

def __init__(self, link, forward_name, args, kwargs, out):
# type: ('chainer.link.Link', str, tp.Tuple[tp.Any, ...], tp.Dict[str, tp.Any], tp.Any) -> None # NOQA
Expand Down
56 changes: 43 additions & 13 deletions chainer/link_hooks/spectral_normalization.py
@@ -1,12 +1,17 @@
import numpy

import chainer
from chainer import backend
from chainer import configuration
import chainer.functions as F
from chainer import link_hook
import chainer.links as L
from chainer import variable
import chainerx
from chainerx import _fallback_workarounds as fallback


def l2normalize(xp, v, eps=1e-12):
def l2normalize(xp, v, eps):
"""Normalize a vector by its L2 norm.
Args:
Expand All @@ -18,11 +23,18 @@ def l2normalize(xp, v, eps=1e-12):
:class:`numpy.ndarray` or :class:`cupy.ndarray`
"""
return v / (xp.linalg.norm(v) + eps)
# TODO(crcrpar): Remove this when chainerx.linalg.norm becomes available.
if xp is chainerx:
# NOTE(crcrpar): `chainerx.power` is not available as of 2019/03/27.
# See https://github.com/chainer/chainer/pull/6522
norm = chainerx.sqrt(chainerx.sum(v * v))
else:
norm = xp.linalg.norm(v)
return v / (norm + eps)


def update_approximate_vectors(
weight_matrix, u, n_power_iteration=1, eps=1e-12):
weight_matrix, u, n_power_iteration, eps):
"""Update the first left and right singular vectors.
This function updates the first left singular vector `u` and
Expand Down Expand Up @@ -95,8 +107,10 @@ class SpectralNormalization(link_hook.LinkHook):
Args:
n_power_iteration (int): Number of power iterations.
The default value is 1.
eps (int): Numerical stability in norm calculation.
The default value is 1e-12.
eps (float): Small constant for numerical stability in norm calculation.
The default value is 1e-6 for compatibility with
mixed precision training. The value used in the author's
implementation is 1e-12.
use_gamma (bool): If ``True``, a weight scaling parameter gamma,
initialized with the initial weight's max singular value, is introduced.
factor (float, None): Scaling parameter to divide maximum singular
Expand Down Expand Up @@ -142,7 +156,7 @@ class SpectralNormalization(link_hook.LinkHook):

name = 'SpectralNormalization'

def __init__(self, n_power_iteration=1, eps=1e-12, use_gamma=False,
def __init__(self, n_power_iteration=1, eps=1e-6, use_gamma=False,
factor=None, weight_name='W', name=None):
assert n_power_iteration > 0
self.n_power_iteration = n_power_iteration
Expand Down Expand Up @@ -192,7 +206,7 @@ def forward_preprocess(self, cb_args):
# the unnormalized weight.
self.original_weight = weight
# note: `normalized_weight` is ~chainer.Variable
normalized_weight = self.normalize_weight(link, weight)
normalized_weight = self.normalize_weight(link)
setattr(link, self.weight_name, normalized_weight)

def forward_postprocess(self, cb_args):
Expand Down Expand Up @@ -225,20 +239,30 @@ def _prepare_parameters(self, link, input_variable=None):
if self.use_gamma:
# Initialize the scaling parameter with the max singular value.
weight_matrix = self.reshape_W(initialW.array)
_, s, _ = link.xp.linalg.svd(weight_matrix)
# TODO(crcrpar): Remove this when chainerx supports SVD.
if link.xp is chainerx:
xp, device, array = fallback._from_chx(weight_matrix)
if xp is numpy:
_, s, _ = numpy.linalg.svd(array)
else:
with chainer.using_device(device):
_, s, _ = xp.linalg.svd(array)
else:
_, s, _ = link.xp.linalg.svd(weight_matrix)
with link.init_scope():
link.gamma = variable.Parameter(s[0], ())
self._initialized = True

def normalize_weight(self, link, *args, **kwargs):
def normalize_weight(self, link):
"""Normalize target weight before every single forward computation."""
weight_name, vector_name = self.weight_name, self.vector_name
W = getattr(link, weight_name)
u = getattr(link, vector_name)
weight_matrix = self.reshape_W(W)
if not configuration.config.in_recomputing:
u, v = update_approximate_vectors(
weight_matrix, u, self.n_power_iteration, self.eps)
with chainer.using_device(link.device):
u, v = update_approximate_vectors(
weight_matrix, u, self.n_power_iteration, self.eps)
else:
v = self.v
sigma = calculate_max_singular_value(weight_matrix, u, v)
Expand All @@ -250,8 +274,14 @@ def normalize_weight(self, link, *args, **kwargs):
W = W / sigma
if not configuration.config.in_recomputing:
self.v = v
if configuration.config.train:
link.xp.copyto(getattr(link, vector_name), u)
with chainer.using_device(link.device):
if configuration.config.train:
if link.xp is chainerx:
# TODO(crcrpar): Remove this when
# chainerx supports `copyto`.
getattr(link, vector_name)[:] = u
else:
backend.copyto(getattr(link, vector_name), u)
return W

def reshape_W(self, W):
Expand Down

0 comments on commit 0cf3542

Please sign in to comment.