In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Does nn.Conv2d init work well?

In [167]:
#export
from exp.nb_02 import *
import jax.lax as lax
import jax.experimental.stax as stax
from jax.experimental.stax import Dense, Conv, Relu, Sigmoid, glorot, randn

def normalize(x, m, s):
  """Standardize `x`: subtract the mean `m`, then divide by the std `s`."""
  centered = x - m
  return centered / s

def kalming_normal(out_axis=0, in_axis=1, scale=onp.sqrt(2.)):
  """Build an initializer for normal, Kaiming-style fan-in-scaled weights.

  Args:
    out_axis: index of the output-feature axis in the weight shape; it is
      excluded from the receptive-field size.
    in_axis: index of the input-feature axis in the weight shape.
    scale: numerator of the standard deviation (defaults to sqrt(2)).

  Returns:
    A function `init(rng, shape)` returning a float32 array of `shape` equal
    to `std * normal(rng, shape)` with
    `std = scale / sqrt(fan_in / 2 * size)`, where `fan_in = shape[in_axis]`
    and `size` is the product of every dimension other than `in_axis` and
    `out_axis`.

  NOTE(review): with the default `scale=sqrt(2)` the `/ 2.` in the denominator
  yields std = 2 / sqrt(fan_in * size), i.e. sqrt(2) times the usual Kaiming
  value sqrt(2 / (fan_in * size)). The formula is left untouched because
  callers in this notebook compensate through `scale`; confirm before changing.
  """
  def init(rng, shape):
    fan_in = shape[in_axis]
    # Receptive-field size: all dimensions that are neither input nor output.
    size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
    std = scale / np.sqrt(fan_in / 2. * size)
    std = lax.convert_element_type(std, np.float32)
    return std * jax.random.normal(rng, shape, dtype=np.float32)
  return init

def GlorotUniformInitializer(out_axis=0, in_axis=1, scale=1.):
  """Build an initializer for uniform, Glorot-scaled weights.

  Args:
    out_axis: index of the output-feature axis in the weight shape.
    in_axis: index of the input-feature axis in the weight shape.
    scale: multiplier applied to the sampling limit.

  Returns:
    A function `init(rng, shape)` returning a float32 array of `shape` sampled
    uniformly between `-limit` and `limit`, where
    `limit = scale / sqrt((fan_in + fan_out) / 6 * size)` and `size` is the
    product of every dimension other than `in_axis` and `out_axis`.
  """
  def init(rng, shape):
    fan_in, fan_out = shape[in_axis], shape[out_axis]
    # Receptive-field size: all dimensions that are neither input nor output.
    size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
    # Half-width of the sampling interval (a bound, not a standard deviation).
    limit = scale / np.sqrt((fan_in + fan_out) / 6. * size)
    limit = lax.convert_element_type(limit, np.float32)
    return jax.random.uniform(rng, shape, minval=-limit, maxval=limit, dtype=np.float32)
  return init

def uniform(param=1.):
  """Build an initializer sampling float32 values uniformly between -param and param."""
  def init(rng, shape):
    bound = lax.convert_element_type(param, np.float32)
    lo, hi = -bound, bound
    return jax.random.uniform(rng, shape, dtype=np.float32, minval=lo, maxval=hi)
  return init


In [89]:
x_train,y_train,x_valid,y_valid = get_data()
# Mean and std are computed over the training inputs only.
train_mean,train_std = np.mean(x_train), np.std(x_train)
x_train = normalize(x_train, train_mean, train_std)
# The validation inputs are normalized with the *training* statistics, not their own.
x_valid = normalize(x_valid, train_mean, train_std)

In [90]:
x_train = x_train.reshape(-1,28,28,1)
x_valid = x_valid.reshape(-1,28,28,1)
x_train.shape,x_valid.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

In [91]:
n,*_ = x_train.shape
c = y_train.max()+1
nh = 32
n,c

(60000, DeviceArray(2., dtype=float32))

In [92]:
x = x_valid[:100]

In [46]:
x.shape

(100, 28, 28, 1)

In [47]:
def stats(x):
    """Return the pair `(mean, std)` of `x`, computed with `np.mean` and `np.std`."""
    mean = np.mean(x)
    std = np.std(x)
    return mean, std

In [70]:
# Sanity check on the sampler: for each of 100 seeds, draw 100 float32 normal
# samples and record that draw's (mean, std) pair.
suma = []
for i in range(100):
    rng = jax.random.PRNGKey(i)
    y = jax.random.normal(rng, (100,), dtype=np.float32)
    suma.append(stats(y))

In [74]:
s = [a[0] for a in suma]
sum(s)

DeviceArray(0.79094946, dtype=float32)

In [159]:
net_init, net_apply = stax.serial(Conv(nh, (5, 5), b_init=uniform(1.0/onp.sqrt(32.*5.))))

In [160]:
rng = jax.random.PRNGKey(3)
_, weights = net_init(rng, (-1, 28, 28, 1))
(weights[0][0].shape, weights[0][1].shape)

((5, 5, 1, 32), (32,))

In [161]:
stats(weights[0][0]),stats(weights[0][1])

((DeviceArray(-0.00036583, dtype=float32),
  DeviceArray(0.07276955, dtype=float32)),
 (DeviceArray(-0.00922053, dtype=float32),
  DeviceArray(0.04904579, dtype=float32)))

In [162]:
t = net_apply(weights, x)

In [163]:
stats(t)

(DeviceArray(-0.01046354, dtype=float32),
 DeviceArray(0.37558442, dtype=float32))

In [169]:
# 3x3 conv initialized with kalming_normal at its default scale=sqrt(2); by that
# function's formula the weight std is then 2 / sqrt(fan_in * size).
# out_axis=3 / in_axis=2 presumably match the (H, W, in, out) kernel layout seen
# in the weight shape printed earlier -- TODO confirm.
net_init, net_apply = stax.serial(Conv(nh, (3, 3), W_init=kalming_normal(3, 2), b_init=uniform(1.0/onp.sqrt(32.*3.))))
# Reuses the `rng` key created in an earlier cell.
_, weights = net_init(rng, (-1, 28, 28, 1))
stats(net_apply(weights, x))

(DeviceArray(-0.02456197, dtype=float32),
 DeviceArray(2.0832274, dtype=float32))

In [170]:
# Same layer, but scale=1/sqrt(2) cancels the `/ 2.` inside kalming_normal, so
# the weight std becomes 1 / sqrt(fan_in * size).
net_init, net_apply = stax.serial(Conv(nh, (3, 3), W_init=kalming_normal(3, 2, scale=1./onp.sqrt(2.)), b_init=uniform(1.0/onp.sqrt(32.*3.))))
# Reuses the `rng` key created in an earlier cell.
_, weights = net_init(rng, (-1, 28, 28, 1))
stats(net_apply(weights, x))

(DeviceArray(-0.01823236, dtype=float32),
 DeviceArray(1.0431904, dtype=float32))

In [None]:
def f1(x,a=0):
    """Run `x` through the global layer `l1`, then `F.leaky_relu` with slope argument `a`."""
    pre_activation = l1(x)
    return F.leaky_relu(pre_activation, a)

In [None]:
init.kaiming_normal_(l1.weight, a=0)
stats(f1(x))

(tensor(0.5547, grad_fn=<MeanBackward1>),
 tensor(1.0199, grad_fn=<StdBackward0>))

In [None]:
l1 = nn.Conv2d(1, nh, 5)
stats(f1(x))

(tensor(0.2219, grad_fn=<MeanBackward1>),
 tensor(0.3653, grad_fn=<StdBackward0>))

In [None]:
l1.weight.shape

torch.Size([32, 1, 5, 5])

In [None]:
# receptive field size
rec_fs = l1.weight[0,0].numel()
rec_fs

25

In [None]:
nf,ni,*_ = l1.weight.shape
nf,ni

(32, 1)

In [None]:
fan_in  = ni*rec_fs
fan_out = nf*rec_fs
fan_in,fan_out

(25, 800)

In [None]:
def gain(a):
    """Return sqrt(2 / (1 + a**2)), the gain for a leaky ReLU with slope `a`."""
    ratio = 2.0 / (1 + a**2)
    return math.sqrt(ratio)

In [None]:
gain(1),gain(0),gain(0.01),gain(0.1),gain(math.sqrt(5.))

(1.0,
 1.4142135623730951,
 1.4141428569978354,
 1.4071950894605838,
 0.5773502691896257)

In [None]:
torch.zeros(10000).uniform_(-1,1).std()

tensor(0.5788)

In [None]:
1/math.sqrt(3.)

0.5773502691896258

In [None]:
def kaiming2(x,a, use_fan_out=False):
    """Refill `x.data` in place with uniform values in +/- sqrt(3) * gain(a) / sqrt(fan).

    `fan` is the receptive-field size (element count of `x[0,0]`) times the
    first dimension of `x` when `use_fan_out` is true, otherwise times the
    second dimension.
    """
    n_out, n_in, *_ = x.shape
    receptive = x[0,0].shape.numel()
    if use_fan_out:
        fan = n_out*receptive
    else:
        fan = n_in*receptive
    target_std = gain(a) / math.sqrt(fan)
    # A uniform distribution on [-b, b] has std b / sqrt(3), hence the sqrt(3) factor.
    limit = math.sqrt(3.) * target_std
    x.data.uniform_(-limit, limit)

In [None]:
kaiming2(l1.weight, a=0);
stats(f1(x))

(tensor(0.5603, grad_fn=<MeanBackward1>),
 tensor(1.0921, grad_fn=<StdBackward0>))

In [None]:
kaiming2(l1.weight, a=math.sqrt(5.))
stats(f1(x))

(tensor(0.2186, grad_fn=<MeanBackward1>),
 tensor(0.3437, grad_fn=<StdBackward0>))

In [None]:
class Flatten(nn.Module):
    """Module that returns its input as a single 1-D view (all axes collapsed)."""
    def forward(self,x):
        flat = x.view(-1)
        return flat

In [None]:
# Small conv stack: four stride-2 convolutions taking channels 1 -> 8 -> 16 -> 32 -> 1,
# with a ReLU after each of the first three.
m = nn.Sequential(
    nn.Conv2d(1,8, 5,stride=2,padding=2), nn.ReLU(),
    nn.Conv2d(8,16,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(16,32,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(32,1,3,stride=2,padding=1),
    # Pool to a 1x1 map, then collapse everything to 1-D with Flatten (x.view(-1)).
    nn.AdaptiveAvgPool2d(1),
    Flatten(),
)

In [None]:
y = y_valid[:100].float()

In [None]:
t = m(x)
stats(t)

(tensor(0.0875, grad_fn=<MeanBackward1>),
 tensor(0.0065, grad_fn=<StdBackward0>))

In [None]:
l = mse(t,y)
l.backward()

In [None]:
stats(m[0].weight.grad)

(tensor(0.0054), tensor(0.0333))

In [None]:
init.kaiming_uniform_??

In [None]:
# Re-initialize every conv layer of `m`: weights via init.kaiming_uniform_ with
# its default arguments, biases zeroed.
# NOTE(review): the loop variable `l` rebinds the name used for the loss in
# other cells.
for l in m:
    if isinstance(l,nn.Conv2d):
        init.kaiming_uniform_(l.weight)
        l.bias.data.zero_()

In [None]:
t = m(x)
stats(t)

(tensor(-0.0352, grad_fn=<MeanBackward1>),
 tensor(0.4043, grad_fn=<StdBackward0>))

In [None]:
l = mse(t,y)
l.backward()
stats(m[0].weight.grad)

(tensor(0.0093), tensor(0.4231))

## Export

In [None]:
!./notebook2script.py 02a_why_sqrt5.ipynb