In [3]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

### Testing hand-coded Xavier normal initialization

In [70]:
# Depth experiment: push a random square batch through 1000 SELU layers whose
# weights are drawn by hand from N(0, 1/ndim), then look at how the activation
# statistics hold up.
ndim = 512
# For square layers (fan_in == fan_out == ndim) the Xavier-normal std
# sqrt(2 / (fan_in + fan_out)) reduces to sqrt(1 / ndim).
sigma = np.sqrt(1/ndim)
x = torch.tensor(np.random.normal(scale=sigma, size=(ndim, ndim)))
for _ in range(1000):
    # A fresh weight matrix per layer; NumPy draws float64, so x stays float64.
    w = torch.tensor(np.random.normal(scale=sigma, size=(ndim, ndim)))
    x = F.selu(torch.matmul(x, w))

# Per-row mean / std, plus the global statistics and the per-row extremes.
m, s = x.mean(dim=1), x.std(dim=1)
(
    x.mean(), x.std(),
    m.min(), m.max(),
    s.min(), s.max(),
)

(tensor(1.4480e-05, dtype=torch.float64),
 tensor(1.0055, dtype=torch.float64),
 tensor(-0.1182, dtype=torch.float64),
 tensor(0.1194, dtype=torch.float64),
 tensor(0.8829, dtype=torch.float64),
 tensor(1.1382, dtype=torch.float64))

### Testing kaiming normal with selu

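Kaiming initialization performs terribly with SELU, as expected: with its default gain of √2 (calibrated for ReLU), the activations blow up instead of settling near zero mean and unit variance.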

In [58]:
# Depth experiment with PyTorch's Kaiming-normal initializer (fan_out mode,
# default gain): 100 SELU layers of size 100x100.
x = torch.empty(100, 100)
nn.init.kaiming_normal_(x, mode='fan_out')  # in-place init of the input batch
for _ in range(100):
    # Fresh weights for every layer, initialised in place.
    w = torch.empty(100, 100)
    nn.init.kaiming_normal_(w, mode='fan_out')
    x = F.selu(torch.matmul(x, w))

# Per-row mean / std, plus the global statistics and the per-row extremes.
m, s = x.mean(dim=1), x.std(dim=1)
(
    x.mean(), x.std(),
    m.min(), m.max(),
    s.min(), s.max(),
)

(tensor(132.1720),
 tensor(226.7363),
 tensor(44.5111),
 tensor(288.9836),
 tensor(71.3681),
 tensor(450.4102))

### Testing xavier normal with selu

Xavier normal initialization with SELU appears to be far superior: the activations stay close to zero mean and unit variance even after 1000 layers.

In [75]:
# Depth experiment with PyTorch's Xavier-normal initializer: 1000 SELU layers
# of size 100x100.
x = torch.zeros(100, 100)
nn.init.xavier_normal_(x)  # in-place init of the input batch
for _ in range(1000):
    # Fresh weights for every layer, initialised in place.
    w = torch.zeros(100, 100)
    nn.init.xavier_normal_(w)
    x = F.selu(torch.matmul(x, w))

# Per-row mean / std, plus the global statistics and the per-row extremes.
m, s = x.mean(dim=1), x.std(dim=1)
(
    x.mean(), x.std(),
    m.min(), m.max(),
    s.min(), s.max(),
)

(tensor(0.0232),
 tensor(0.9790),
 tensor(-0.1307),
 tensor(0.1830),
 tensor(0.7782),
 tensor(1.1792))

### Testing the SELU paper's recommended initialization with a NumPy-based SELU function

In [62]:
import numpy as np

def selu_np(x):
    """Apply the SELU activation element-wise with NumPy.

    Returns ``scale * x`` where ``x >= 0`` and
    ``scale * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: NumPy array (or scalar) of real-valued inputs.

    Returns:
        The activations, with the same shape as ``x``.
    """
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    # np.where evaluates both branches for every element, so the exponential
    # branch is fed min(x, 0): large positive inputs would otherwise overflow
    # in exp() and raise a RuntimeWarning even though that value is discarded.
    # expm1 also avoids the cancellation in exp(x) - 1 for x close to zero.
    negative_branch = alpha * np.expm1(np.minimum(x, 0.0))
    return scale * np.where(x >= 0.0, x, negative_branch)

# Depth experiment in pure NumPy: a unit-normal 200x200 batch is pushed through
# 100 SELU layers whose weights are drawn from N(0, 1/200).
x = np.random.normal(size=(200, 200))
for _ in range(100):
    # Fresh weights for every layer.
    w = np.random.normal(scale=np.sqrt(1/200), size=(200, 200))
    x = selu_np(x @ w)

# Per-row mean / std, plus the global statistics and the per-row extremes.
m, s = x.mean(axis=1), x.std(axis=1)
(
    x.mean(), x.std(),
    m.min(), m.max(),
    s.min(), s.max(),
)


(-0.006943016470394205,
 0.9831865049501114,
 -0.17200409046748186,
 0.1426579564326867,
 0.753801801554161,
 1.2165376012023728)