# Conv2d Init

In [251]:
import torch
import torch.nn as nn

from torchvision import datasets
from torchvision import transforms

# ToTensor is attached as the per-item transform, but the cells below read the raw,
# unscaled pixel tensors via `.data` directly (note the 0-255 scale of their stats),
# so this transform is not applied on that path.
tfms = transforms.ToTensor()

# `download` must be a bool: the string "True" only worked because any non-empty
# string is truthy (download="False" would have downloaded as well).
ds_train = datasets.MNIST('./data/', download=True, train=True, transform=tfms)
ds_test = datasets.MNIST('./data/', download=True, train=False, transform=tfms)

In [28]:
ds_train.data.shape

torch.Size([60000, 28, 28])

In [37]:
x_train = ds_train.data.float()
x_test  = ds_test.data.float()

In [224]:
y_train = ds_train.targets
y_test  = ds_test.targets

In [38]:
def stats(x):
    """Return ``(mean, std)`` of all elements of tensor ``x``, each as a 0-dim tensor."""
    mean = x.mean()
    std = x.std()
    return mean, std

In [39]:
stats(x_train)

(tensor(33.3141), tensor(78.5675))

In [40]:
stats(x_test)

(tensor(33.7912), tensor(79.1725))

In [41]:
def normalize(x, m, s):
    """Standardize ``x``: subtract the mean ``m``, then divide by the std ``s``."""
    centered = x - m
    return centered / s

In [42]:
# Normalize both splits with the *training* statistics so train and test get the same transform.
# Named train_mean/train_std rather than m/s: `m` is rebound to the conv model further down,
# so re-running this cell afterwards would otherwise clobber the model.
train_mean, train_std = stats(x_train)
x_train = normalize(x_train, train_mean, train_std)
x_test  = normalize(x_test, train_mean, train_std)

In [44]:
stats(x_train)

(tensor(4.7499e-05), tensor(1.))

In [45]:
stats(x_test)

(tensor(0.0061), tensor(1.0077))

## Test conv

In [55]:
x = x_test[:100].unsqueeze(1)

In [56]:
x.shape

torch.Size([100, 1, 28, 28])

In [106]:
nh=32
l1 = nn.Conv2d(1, nh, kernel_size=5)

In [107]:
l1.weight.shape

torch.Size([32, 1, 5, 5])

In [108]:
stats(l1.weight)

(tensor(-0.0040, grad_fn=<MeanBackward0>),
 tensor(0.1173, grad_fn=<StdBackward0>))

In [50]:
stats(l1.bias)

(tensor(-0.0008, grad_fn=<MeanBackward0>),
 tensor(0.1152, grad_fn=<StdBackward0>))

In [58]:
t = l1(x)

In [59]:
stats(t)

(tensor(0.0021, grad_fn=<MeanBackward0>),
 tensor(0.6386, grad_fn=<StdBackward0>))

In [61]:
# ??nn.Conv2d

In [168]:
torch.nn.init.kaiming_normal_(l1.weight, a=1.)
stats(l1.weight)

(tensor(0.0027, grad_fn=<MeanBackward0>),
 tensor(0.2021, grad_fn=<StdBackward0>))

In [169]:
stats(l1(x))

(tensor(0.0163, grad_fn=<MeanBackward0>),
 tensor(1.1196, grad_fn=<StdBackward0>))

In [161]:
# ??nn.modules.conv._ConvNd

In [170]:
import torch.nn.functional as F

In [171]:
def f1(x, a=0):
    """Run ``x`` through the conv layer ``l1``, then a leaky ReLU with negative slope ``a``.

    ``l1`` is a notebook global looked up at call time, so rebinding ``l1`` in a later
    cell changes which layer this uses. ``a=0`` makes the activation a plain ReLU.
    """
    pre_act = l1(x)
    return F.leaky_relu(pre_act, negative_slope=a)

In [176]:
torch.nn.init.kaiming_normal_(l1.weight, a=0.)
stats(f1(x))
# Mean is no longer zero (it's about half, as we discussed last week): f1 applies
# leaky_relu with slope 0, i.e. a ReLU, which zeroes every negative activation.

(tensor(0.4643, grad_fn=<MeanBackward0>),
 tensor(0.8288, grad_fn=<StdBackward0>))

In [175]:
torch.nn.init.kaiming_normal_(l1.weight, a=1.)
stats(f1(x))

(tensor(0.3700, grad_fn=<MeanBackward0>),
 tensor(0.7091, grad_fn=<StdBackward0>))

### Reset to defaults again

In [177]:
l1 = nn.Conv2d(1, nh, kernel_size=5)
stats(l1(x))

(tensor(0.0069, grad_fn=<MeanBackward0>),
 tensor(0.6921, grad_fn=<StdBackward0>))

In [178]:
l1.weight.shape

torch.Size([32, 1, 5, 5])

In [179]:
l1.weight[0,0].shape

torch.Size([5, 5])

In [182]:
rct_fs = receptive_field_size = l1.weight[0,0].numel()

In [184]:
n_out, n_in, *_ = l1.weight.shape
n_out, n_in

(32, 1)

In [188]:
fan_in =   n_in * rct_fs
fan_out = n_out * rct_fs

fan_in, fan_out

(25, 800)

In [192]:
import math

In [193]:
def gain(a):
    """Kaiming gain for a leaky ReLU with negative slope ``a``: sqrt(2 / (1 + a**2)).

    Matches ``torch.nn.init.calculate_gain('leaky_relu', a)``; ``a=0`` gives the ReLU
    gain sqrt(2), and ``a=1`` (identity) gives 1.
    """
    slope_sq = a**2
    return math.sqrt(2. / (1. + slope_sq))

Kaiming init implementation details: https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138

In [196]:
gain(1), gain(0.1), gain(0.01), gain(0), gain(math.sqrt(0.5))

(1.0,
 1.4071950894605838,
 1.4141428569978354,
 1.4142135623730951,
 1.1547005383792515)

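But the default `_ConvNd` initialization uses `kaiming_uniform_` (with `a=math.sqrt(5)`), not `kaiming_normal_`. For a uniform distribution the std is not the bound, as the next cells show: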

In [202]:
# 10k samples from U(-1, 1); uniform_ overwrites every element, so the buffer's
# initial contents don't matter.
u = torch.empty(10000).uniform_(-1, 1)
stats(u)

(tensor(0.0046), tensor(0.5789))

In [203]:
# std is nearly equal to
1/math.sqrt(3.)

0.5773502691896258

In [213]:
def kaiming2(x, a, use_fan_out=False):
    """Re-initialize weight tensor ``x`` in place with Kaiming-uniform values.

    The layout is taken to be ``(out_channels, in_channels, *kernel)``, as for the
    ``nn.Conv2d`` weight inspected above. The fan is therefore the channel count
    times the receptive-field size (elements per ``x[0, 0]`` kernel).

    Args:
        x: weight tensor with at least 2 dims; overwritten in place.
        a: negative slope of the following leaky ReLU, passed to ``gain``.
        use_fan_out: if True, scale by out_channels * receptive field instead of
            in_channels * receptive field.

    Returns:
        ``x`` itself, so the call can be chained like the ``torch.nn.init`` functions.

    Raises:
        ValueError: if ``x`` has fewer than 2 dims (no fan can be computed).
    """
    if x.dim() < 2:
        raise ValueError(f"kaiming2 needs a tensor with >= 2 dims, got shape {tuple(x.shape)}")
    n_out, n_in = x.shape[:2]
    rct_fs = x[0, 0].numel()  # receptive-field size, e.g. 5*5 = 25 for a 5x5 kernel
    fan = (n_out if use_fan_out else n_in) * rct_fs
    std = gain(a) / math.sqrt(fan)
    # U(-b, b) has std b / sqrt(3), so scale the bound up by sqrt(3) to hit `std`.
    bound = math.sqrt(3.) * std
    # Mutate under no_grad instead of through `.data`: this is the supported way to
    # write into a leaf tensor that requires grad (what torch.nn.init does too).
    with torch.no_grad():
        x.uniform_(-bound, bound)
    return x

In [215]:
kaiming2(l1.weight, a=0)
stats(l1(x))

(tensor(-0.0409, grad_fn=<MeanBackward0>),
 tensor(1.5280, grad_fn=<StdBackward0>))

In [217]:
kaiming2(l1.weight, a=math.sqrt(5))
stats(l1(x))

(tensor(-0.0320, grad_fn=<MeanBackward0>),
 tensor(0.5977, grad_fn=<StdBackward0>))

In [219]:
class Flatten(nn.Module):
    """Flatten the entire input -- batch dimension included -- into one 1-D tensor.

    The model below ends in a single-channel ``AdaptiveAvgPool2d(1)``, so a
    ``(bs, 1, 1, 1)`` output becomes ``(bs,)``, matching the 1-D targets. With more
    than one element per sample, this would merge the samples together.
    """
    def forward(self, x):
        n_elems = x.numel()
        return x.view(n_elems)

In [235]:
# Small 4-layer conv net, used to compare output/gradient stats under the default init
# vs. kaiming_uniform_. Shapes are for the (100, 1, 28, 28) batch `x`, via
# out = floor((in + 2*padding - kernel) / stride) + 1.
m = nn.Sequential(
    nn.Conv2d(1,8, kernel_size=2, stride=2, padding=2), nn.ReLU(),  # -> (bs, 8, 16, 16)
    nn.Conv2d(8,16, kernel_size=3, stride=2, padding=2), nn.ReLU(),  # -> (bs, 16, 9, 9)
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=2), nn.ReLU(),  # -> (bs, 32, 6, 6)
    nn.Conv2d(32, 1, kernel_size=3, stride=2, padding=2),  # -> (bs, 1, 4, 4)
    nn.AdaptiveAvgPool2d(1),  # -> (bs, 1, 1, 1)
    Flatten()  # -> (bs,), one scalar prediction per image
)

In [236]:
y = y_test[:100].float()

In [237]:
t = m(x)
stats(t)

(tensor(0.0130, grad_fn=<MeanBackward0>),
 tensor(0.0029, grad_fn=<StdBackward0>))

In [238]:
mse = nn.MSELoss()

In [239]:
l = mse(t, y)
l.backward()

In [240]:
stats(m[0].weight.grad)

(tensor(0.0024), tensor(0.0143))

In [245]:
# torch.nn.init.kaiming_uniform_ already applies the sqrt(3) factor itself
# (bound = sqrt(3) * std), exactly as kaiming2 does by hand.
# ??torch.nn.init.kaiming_uniform_

In [241]:
# Re-initialize every conv layer with kaiming_uniform_ (default a=0, i.e. the ReLU gain)
# and zero its bias. The loop variable is `layer`, not `l`, so the loss `l` computed
# earlier is not silently replaced by a module.
for layer in m:
    if isinstance(layer, nn.Conv2d):
        torch.nn.init.kaiming_uniform_(layer.weight)
        torch.nn.init.zeros_(layer.bias)

In [242]:
t = m(x)
stats(t)

(tensor(-0.2717, grad_fn=<MeanBackward0>),
 tensor(0.1381, grad_fn=<StdBackward0>))

In [243]:
l = mse(t,y)
# backward() accumulates into .grad, and m[0].weight.grad still holds the gradient from
# the earlier backward pass under the default init. Clear it so the stats below reflect
# only the kaiming_uniform_ init.
m.zero_grad()
l.backward()
stats(m[0].weight.grad)

(tensor(-0.0570), tensor(0.3718))

## When do activations explode or vanish?

In [247]:
# Repeatedly multiply by an unscaled N(0, 1) matrix. Each step grows the activation std
# by roughly sqrt(512) until the values overflow and the std becomes NaN.
# Uses `act`, not `x`, so the image batch `x` used by the model cells is not overwritten.
act = torch.randn(512)
w = torch.randn(512,512)

for i in range(100):
    act = w @ act
    if torch.isnan(act.std()):
        print(i)
        break

28


In [250]:
# Same experiment with the product scaled by 0.01. Each step now shrinks the std
# (roughly sqrt(512) * 0.01, about 0.23) until every activation underflows to exactly 0.
# Uses `act`, not `x`, so the image batch `x` used by the model cells is not overwritten.
act = torch.randn(512)
w = torch.randn(512,512)

for i in range(100):
    act = w @ act * 0.01
    if act.std() == 0:
        print(i)
        break

70
