## basics

- Spectral Normalization for Generative Adversarial Networks
    - https://arxiv.org/abs/1802.05957
- PyTorch provides two interfaces
    - old: `torch.nn.utils.spectral_norm`
    - new: `torch.nn.utils.parametrizations.spectral_norm`
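The two interfaces differ mainly in where they keep the raw, unnormalized weight. The sketch below shows both side by side. The attribute names (`weight_orig`, `weight_u`, `parametrizations.weight.original`) reflect my understanding of recent PyTorch releases, so check them against the version you have installed.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

torch.manual_seed(0)

# Old interface: implemented with a forward pre-hook.
# The raw parameter is stored as `weight_orig`; `weight_u` and `weight_v`
# are buffers used by the power iteration.
m_old = nn.utils.spectral_norm(nn.Linear(5, 4))
old_raw = m_old.weight_orig

# New interface: implemented as a parametrization of `weight`.
# The raw parameter lives under `parametrizations.weight.original`.
m_new = nn.utils.parametrizations.spectral_norm(nn.Linear(5, 4))
new_raw = m_new.parametrizations.weight.original

# In eval mode, reading `weight` does not update the power-iteration buffers.
m_new.eval()
top_sigma = torch.linalg.svdvals(m_new.weight.detach())[0].item()

print(type(old_raw).__name__, tuple(old_raw.shape))
print(parametrize.is_parametrized(m_new, "weight"), tuple(new_raw.shape))
print(top_sigma)  # expected to be close to 1, up to the power-iteration error
```

The largest singular value of the normalized weight is only approximately 1, because the layer divides by an estimate of the spectral norm rather than by the exact value.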

## pytorch api

In [15]:
import numpy as np
import torch 
from torch import nn

In [25]:
m = nn.Linear(5, 4)

In [27]:
W = m.weight.clone()
W

tensor([[ 0.1717,  0.2338, -0.4089, -0.0635,  0.3527],
        [ 0.0247, -0.3828,  0.0763,  0.2116,  0.3518],
        [ 0.3632,  0.0150,  0.1076, -0.1810,  0.3513],
        [-0.3899,  0.2467,  0.2515, -0.1216,  0.3808]],
       grad_fn=<CloneBackward0>)

In [28]:
m_sm = nn.utils.parametrizations.spectral_norm(m)
m_sm.weight

tensor([[ 0.2319,  0.3158, -0.5524, -0.0858,  0.4765],
        [ 0.0333, -0.5171,  0.1030,  0.2858,  0.4753],
        [ 0.4907,  0.0202,  0.1454, -0.2445,  0.4745],
        [-0.5267,  0.3333,  0.3397, -0.1643,  0.5145]], grad_fn=<DivBackward0>)

In [30]:
# np.linalg.svd returns the transposed right singular vectors (V^T), hence the name Vt
U, s, Vt = np.linalg.svd(W.detach().numpy())
s

array([0.74032176, 0.6395773 , 0.58293104, 0.37460697], dtype=float32)

In [31]:
# dividing by the largest singular value reproduces m_sm.weight (up to rounding)
W / s[0]

tensor([[ 0.2319,  0.3158, -0.5524, -0.0858,  0.4765],
        [ 0.0333, -0.5170,  0.1030,  0.2858,  0.4753],
        [ 0.4906,  0.0202,  0.1454, -0.2445,  0.4745],
        [-0.5267,  0.3333,  0.3397, -0.1643,  0.5144]], grad_fn=<DivBackward0>)
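The small differences in the last digit between `m_sm.weight` and `W/s[0]` are consistent with PyTorch estimating the largest singular value by power iteration instead of a full SVD, as its documentation describes. Below is a minimal NumPy sketch of that idea, reusing the rounded `W` printed above. The helper name `power_iteration_sigma` and the iteration count are illustrative assumptions, not PyTorch internals.

```python
import numpy as np

# W as printed above (rounded to 4 decimals)
W = np.array([[ 0.1717,  0.2338, -0.4089, -0.0635,  0.3527],
              [ 0.0247, -0.3828,  0.0763,  0.2116,  0.3518],
              [ 0.3632,  0.0150,  0.1076, -0.1810,  0.3513],
              [-0.3899,  0.2467,  0.2515, -0.1216,  0.3808]])

def power_iteration_sigma(W, n_iters=200, seed=0):
    """Estimate the largest singular value of W by alternating power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u               # update the right singular vector estimate
        v /= np.linalg.norm(v)
        u = W @ v                 # update the left singular vector estimate
        u /= np.linalg.norm(u)
    return float(u @ W @ v)       # sigma is approximately u^T W v

sigma_est = power_iteration_sigma(W)
sigma_svd = np.linalg.svd(W, compute_uv=False)[0]
print(sigma_est, sigma_svd)  # both should be close to the 0.7403 reported above
```

With only a few iterations the estimate is rougher, which is why a normalized weight can differ slightly from the exact `W/s[0]`.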

## Spectral norm of a matrix

$$
\|A\|_2 = \max_{x\neq 0}\frac{\|Ax\|_2}{\|x\|_2}=\sqrt{\lambda_{\max}(A^TA)}=\sigma_{\max}(A)
$$

- The spectral norm (also known as the induced 2-norm) is the maximum singular value of a matrix. Intuitively, it is the largest factor by which the matrix can stretch a vector.

- The maximum singular value of $A$ is the square root of the maximum eigenvalue of $A^TA$. If $A$ is symmetric/Hermitian, it equals the largest absolute value of the eigenvalues of $A$ itself.
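For a symmetric matrix the singular values are the absolute values of the eigenvalues, so the spectral norm is the largest eigenvalue in magnitude, which is not necessarily the largest eigenvalue. A small check with a hypothetical 2x2 matrix `S` (not part of the original notebook):

```python
import numpy as np

# Symmetric matrix whose most negative eigenvalue dominates in magnitude
S = np.array([[1.0,  2.0],
              [2.0, -3.0]])

eigvals = np.linalg.eigvalsh(S)      # eigenvalues are -1 - 2*sqrt(2) and -1 + 2*sqrt(2)
spec_norm = np.linalg.norm(S, 2)     # largest singular value

print(eigvals)
print(spec_norm, np.abs(eigvals).max(), eigvals.max())
```

Here the spectral norm matches `np.abs(eigvals).max()` and is larger than `eigvals.max()`.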

In [1]:
import numpy as np

In [4]:
A = np.random.randint(0, 5, (5, 4))
A

array([[1, 4, 4, 4],
       [3, 2, 3, 0],
       [0, 0, 2, 4],
       [1, 4, 0, 2],
       [1, 0, 4, 4]])

In [7]:
x = np.random.randn(4, 1)
x

array([[-0.98453805],
       [ 1.95325231],
       [ 0.62514591],
       [ 0.88666499]])

In [8]:
np.linalg.norm(x, 2)

2.4416166720742534

In [9]:
np.linalg.norm(A.dot(x), 2)

17.21674155029448

In [10]:
np.linalg.svd(A)

(array([[-0.65752954, -0.18334979, -0.15086348, -0.70308521, -0.13018891],
        [-0.31992344, -0.39459151,  0.76998239,  0.28509129, -0.26037782],
        [-0.35688849,  0.47670322, -0.26754883,  0.38739315, -0.65094455],
        [-0.30736573, -0.56885818, -0.51565347,  0.49822843,  0.26037782],
        [-0.49341748,  0.50973993,  0.21653282,  0.16152125,  0.65094455]]),
 array([10.47221133,  4.681808  ,  3.50148014,  1.07382518]),
 array([[-0.23090472, -0.42965405, -0.59942788, -0.63463898],
        [-0.304635  , -0.81122825,  0.22965326,  0.44313161],
        [ 0.53119337, -0.32160771,  0.58190445, -0.52515645],
        [ 0.7561178 , -0.23210905, -0.49933237,  0.35366431]]))
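The numbers above can be tied back to the definition. The ratio $\|Ax\|_2/\|x\|_2$ for the sampled `x` is about 7.05, below the largest singular value 10.47, and the bound is attained when `x` is the first right singular vector. A short sketch that reuses the `A` and `x` printed above (variable names such as `ratio_v1` are my own):

```python
import numpy as np

A = np.array([[1, 4, 4, 4],
              [3, 2, 3, 0],
              [0, 0, 2, 4],
              [1, 4, 0, 2],
              [1, 0, 4, 4]])
x = np.array([-0.98453805, 1.95325231, 0.62514591, 0.88666499])

U, s, Vt = np.linalg.svd(A)
sigma_max = s[0]

# Stretch factor for an arbitrary vector: bounded above by sigma_max
ratio = np.linalg.norm(A @ x) / np.linalg.norm(x)

# Stretch factor for the first right singular vector: equals sigma_max
v1 = Vt[0]
ratio_v1 = np.linalg.norm(A @ v1) / np.linalg.norm(v1)

# sqrt of the largest eigenvalue of A^T A also equals sigma_max
lam_max = np.linalg.eigvalsh(A.T @ A).max()

print(ratio, ratio_v1, np.sqrt(lam_max), sigma_max, np.linalg.norm(A, 2))
```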