In [1]:
import torch
import torch.nn as nn

# Fix the RNG state so every torch.randn call below yields the same tensors on each run.
# Both calls use `seed` so changing the value in one place actually reseeds everything.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # documented as safe (silently ignored) when CUDA is unavailable

### NLP Example

* Sample Embedding 생성

In [2]:
# Toy embeddings: `batch` sentences, each with `sentence_len` tokens of `embed_dim` features.
batch = 2
sentence_len = 5
embed_dim = 5
batch_embed = torch.randn((batch, sentence_len, embed_dim))
print(batch_embed)

tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784],
         [-1.2345, -0.0431, -1.6047, -0.7521,  1.6487],
         [-0.3925, -1.4036, -0.7279, -0.5594, -0.7688],
         [ 0.7624,  1.6423, -0.1596, -0.4974,  0.4396],
         [-0.7581,  1.0783,  0.8008,  1.6806,  1.2791]],

        [[ 1.2964,  0.6105,  1.3347, -0.2316,  0.0418],
         [-0.2516,  0.8599, -1.3847, -0.8712,  0.0780],
         [ 0.5258, -0.4880,  1.1914, -0.8140, -0.7360],
         [-0.8371, -0.9224, -0.0635,  0.6756, -0.0978],
         [ 1.8446, -1.1845,  1.3835, -1.2024,  0.7078]]])


* nn.LayerNorm 함수를 사용한 Normalize
    * embedding dimension(마지막 차원)을 기준으로 $\mu, \sigma^2$를 이용하여 Normalize


In [3]:
# Normalize each length-`embed_dim` vector over the last dimension of batch_embed.
layer_norm = nn.LayerNorm(normalized_shape=embed_dim)
print(layer_norm(batch_embed))

tensor([[[ 9.5596e-01,  6.4450e-01,  2.2894e-01, -1.9008e+00,  7.1452e-02],
         [-7.2907e-01,  3.0826e-01, -1.0513e+00, -3.0907e-01,  1.7812e+00],
         [ 1.1002e+00, -1.8430e+00,  1.2390e-01,  6.1422e-01,  4.6816e-03],
         [ 4.3522e-01,  1.6136e+00, -7.9961e-01, -1.2520e+00,  2.8364e-03],
         [-1.8792e+00,  3.1295e-01, -1.8318e-02,  1.0319e+00,  5.5265e-01]],

        [[ 1.0773e+00,  1.7915e-04,  1.1375e+00, -1.3222e+00, -8.9286e-01],
         [ 8.0589e-02,  1.5173e+00, -1.3841e+00, -7.2040e-01,  5.0664e-01],
         [ 7.4713e-01, -5.3673e-01,  1.5900e+00, -9.4959e-01, -8.5080e-01],
         [-1.0051e+00, -1.1509e+00,  3.1715e-01,  1.5804e+00,  2.5848e-01],
         [ 1.1994e+00, -1.1678e+00,  8.3913e-01, -1.1818e+00,  3.1105e-01]]],
       grad_fn=<NativeLayerNormBackward0>)


* Layer Normalization 구현
* $y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$
* $\gamma, \beta$는 학습 가능한 parameters (weight, bias)
* var 계산 시 unbiased=False로 설정하지 않으면 Bessel’s correction이 적용되어 표본 크기로 n이 아닌 n-1을 사용하게 된다. nn.LayerNorm은 n으로 나누는 (biased) 분산을 사용하므로 unbiased=False로 설정해야 결과가 일치한다.

In [None]:
# Per-vector statistics over the last (embedding) dimension. keepdim=True keeps a trailing
# size-1 axis so the results broadcast back against batch_embed.
eg_mean = batch_embed.mean(dim=-1, keepdim=True)  # E[x]
# Population variance (divide by n, no Bessel's correction).
eg_var = batch_embed.var(dim=-1, unbiased=False, keepdim=True)  # Var[x]
print('mean:\n ', eg_mean)
print(' ')
print('var:\n', eg_var)

mean:
  tensor([[[ 0.5776],
         [-0.3971],
         [-0.7704],
         [ 0.4375],
         [ 0.8161]],

        [[ 0.6104],
         [-0.3139],
         [-0.0642],
         [-0.2490],
         [ 0.3098]]])
 
var:
 tensor([[[1.9924],
         [1.3193],
         [0.1180],
         [0.5575],
         [0.7018]],

        [[0.4055],
         [0.5985],
         [0.6235],
         [0.3423],
         [1.6374]]])


In [None]:
# 위의 결과와 동일
eg_x_hat = (batch_embed-eg_mean)/torch.sqrt(eg_var + layer_norm.eps)
print(layer_norm.state_dict())
print(layer_norm.weight * eg_x_hat + layer_norm.bias)

OrderedDict({'weight': tensor([1., 1., 1., 1., 1.]), 'bias': tensor([0., 0., 0., 0., 0.])})
tensor([[[ 9.5596e-01,  6.4450e-01,  2.2894e-01, -1.9008e+00,  7.1452e-02],
         [-7.2907e-01,  3.0826e-01, -1.0513e+00, -3.0907e-01,  1.7812e+00],
         [ 1.1002e+00, -1.8430e+00,  1.2390e-01,  6.1422e-01,  4.6816e-03],
         [ 4.3521e-01,  1.6136e+00, -7.9961e-01, -1.2520e+00,  2.8363e-03],
         [-1.8792e+00,  3.1295e-01, -1.8318e-02,  1.0319e+00,  5.5265e-01]],

        [[ 1.0773e+00,  1.7905e-04,  1.1375e+00, -1.3222e+00, -8.9286e-01],
         [ 8.0589e-02,  1.5173e+00, -1.3841e+00, -7.2040e-01,  5.0664e-01],
         [ 7.4713e-01, -5.3673e-01,  1.5900e+00, -9.4959e-01, -8.5080e-01],
         [-1.0051e+00, -1.1509e+00,  3.1715e-01,  1.5804e+00,  2.5848e-01],
         [ 1.1994e+00, -1.1678e+00,  8.3913e-01, -1.1818e+00,  3.1105e-01]]],
       grad_fn=<AddBackward0>)


### Image Example

In [None]:
# Toy image batch laid out as (batch, channel, height, width).
batch = 2
channel = 3
height = 5
width = 5
batch_img = torch.randn((batch, channel, height, width))
print(batch_img)

tensor([[[[ 2.2181e+00,  5.2317e-01,  3.4665e-01, -1.9733e-01, -1.0546e+00],
          [ 1.2780e+00,  1.4534e-01,  2.3105e-01,  5.6622e-02,  4.2630e-01],
          [ 5.7501e-01, -6.4172e-01, -2.2064e+00, -7.5080e-01,  2.8140e+00],
          [ 3.5979e-01, -1.3407e+00, -5.8537e-01,  5.3619e-01,  5.2462e-01],
          [ 1.1412e+00,  5.1644e-02,  7.2811e-01, -7.1064e-01, -1.0495e+00]],

         [[ 6.0390e-01, -1.7223e+00, -8.2777e-01,  1.3347e+00,  4.8354e-01],
          [-1.9756e-01,  1.2683e+00,  7.8459e-01,  2.8647e-02,  6.4076e-01],
          [ 5.8325e-01,  1.0669e+00, -4.5015e-01, -6.7875e-01,  5.7432e-01],
          [ 4.0476e-01,  1.7847e-01,  2.6491e-01,  1.2732e+00, -1.3109e-03],
          [-3.0360e-01, -9.8644e-01,  1.2330e-01, -5.9915e-01,  4.7706e-01]],

         [[ 7.2618e-01,  9.1152e-02, -3.8907e-01,  5.2792e-01,  1.0311e+00],
          [-7.0477e-01,  1.3254e-01,  7.6424e-01,  1.0950e+00,  3.3989e-01],
          [ 7.1997e-01,  4.1141e-01, -5.7332e-01,  5.0686e-01, -1.4364e+

* nn.LayerNorm 함수를 사용한 Normalize
    * image의 마지막 세 차원 (channel, height, width)을 기준으로 $\mu, \sigma^2$를 이용하여 Normalize
        * channel → height → width 순으로 단계별로 통계를 구하는 것이 아니라, 각 image의 C×H×W개(여기서는 3×5×5 = 75개) 원소 전체에 대해 $\mu, \sigma^2$를 한 번에 산출
        * 즉 `dim=(-3, -2, -1)`로 mean / var(unbiased=False)를 계산한 것과 같다
        * 결과적으로 각 image마다 하나의 $\mu, \sigma^2$ 값 산출

In [None]:
# Normalize each image jointly over its last three dimensions (channel, height, width).
layer_norm = nn.LayerNorm(normalized_shape=(channel, height, width))
print(layer_norm(batch_img))

tensor([[[[ 2.3414e+00,  4.2670e-01,  2.2730e-01, -3.8722e-01, -1.3556e+00],
          [ 1.2794e+00, -1.1660e-04,  9.6711e-02, -1.0034e-01,  3.1727e-01],
          [ 4.8526e-01, -8.8923e-01, -2.6568e+00, -1.0125e+00,  3.0146e+00],
          [ 2.4214e-01, -1.6788e+00, -8.2557e-01,  4.4141e-01,  4.2835e-01],
          [ 1.1249e+00, -1.0596e-01,  6.5822e-01, -9.6709e-01, -1.3498e+00]],

         [[ 5.1790e-01, -2.1099e+00, -1.0994e+00,  1.3435e+00,  3.8194e-01],
          [-3.8748e-01,  1.2685e+00,  7.2202e-01, -1.3194e-01,  5.5954e-01],
          [ 4.9457e-01,  1.0410e+00, -6.7282e-01, -9.3106e-01,  4.8449e-01],
          [ 2.9294e-01,  3.7307e-02,  1.3496e-01,  1.2740e+00, -1.6578e-01],
          [-5.0727e-01, -1.2786e+00, -2.5013e-02, -8.4114e-01,  3.7461e-01]],

         [[ 6.5604e-01, -6.1328e-02, -6.0381e-01,  4.3207e-01,  1.0005e+00],
          [-9.6045e-01, -1.4579e-02,  6.9904e-01,  1.0727e+00,  2.1966e-01],
          [ 6.4902e-01,  3.0045e-01, -8.1196e-01,  4.0829e-01, -1.7870e+