In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Does nn.Conv2d init work well?

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=21)

In [1]:
#export
from exp.nb_02 import *

def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f,encoding='latin-1')
    return map(tensor, (x_train, y_train , x_valid, y_valid))

def normalize(x, m, s): return (x-m)/s

In [3]:
torch.nn.modules.conv._ConvNd.reset_parameters??

In [4]:
x_train, y_train, x_valid, y_valid = get_data()
train_mean, train_std = x_train.mean(), x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [8]:
#converts the images from 1-d vectors to 3d arrays (1x28x28)
x_train =  x_train.view(-1,1,28,28)
x_valid = x_train.view(-1,1,28,28)
x_train.shape, x_valid.shape

(torch.Size([50000, 1, 28, 28]), torch.Size([50000, 1, 28, 28]))

In [13]:
n, *_ = x_train.shape
c = y_train.max()+1
nh = 32
n,c

(50000, tensor(10))

In [34]:
l1 = nn.Conv2d(1,nh,5)

In [35]:
x = x_valid[:100]
x.shape

torch.Size([100, 1, 28, 28])

In [36]:
def stats(x):
    '''
    Returns mean and std
    '''
    return x.mean(), x.std()

In [37]:
stats(l1.weight), stats(l1.bias)

((tensor(-0.0027, grad_fn=<MeanBackward0>),
  tensor(0.1141, grad_fn=<StdBackward0>)),
 (tensor(-0.0180, grad_fn=<MeanBackward0>),
  tensor(0.1012, grad_fn=<StdBackward0>)))

In [38]:
t = l1(x)

In [39]:
stats(t)

(tensor(-0.0270, grad_fn=<MeanBackward0>),
 tensor(0.5709, grad_fn=<StdBackward0>))

In [40]:
init.kaiming_normal_(l1.weight, a=1)
stats(l1(x))

(tensor(-0.0474, grad_fn=<MeanBackward0>),
 tensor(1.0451, grad_fn=<StdBackward0>))

In [42]:
import torch.nn.functional as F

In [47]:
def f1(x,a=0):
    return F.leaky_relu(l1(x),a)

In [48]:
init.kaiming_normal_(l1.weight,a=0)
stats(f1(x))

(tensor(0.5629, grad_fn=<MeanBackward0>),
 tensor(0.9966, grad_fn=<StdBackward0>))

In [50]:
#without the kaiming init we have a much lower variance and the mean is far from 1
l1 = nn.Conv2d(1,nh,5)
stats(f1(x))

(tensor(0.1870, grad_fn=<MeanBackward0>),
 tensor(0.3824, grad_fn=<StdBackward0>))

In [51]:
l1.weight.shape

torch.Size([32, 1, 5, 5])

In [52]:
rec_fs = l1.weight[0,0].numel()
rec_fs

25

In [57]:
nf, ni, *_ = l1.weight.shape
nf, ni

(32, 1)

In [59]:
fan_in = ni*rec_fs
fan_out = nf*rec_fs

fan_in, fan_out

(25, 800)

In [63]:
#a is the amount of alpha in leakyrelu, its used for kaiming init
def gain(a): 
    return math.sqrt(2.0/(1+a**2))

In [64]:
gain(1), gain(0), gain(0.1), gain(math.sqrt(5))

(1.0, 1.4142135623730951, 1.4071950894605838, 0.5773502691896257)

pytorch uses kaiming uniform which actually has a std less than 1 since it pulls from a uniform distribution btwn -1 and 1, if we create a mock we will see it sits around 1/sqrt(3)
which is why a pytorch uses such a high gain parameter


we will see below how bad it works with kaiming nor

In [66]:
torch.zeros(10000).uniform_(-1,1).std(), 1/math.sqrt(3.)

(tensor(0.5774), 0.5773502691896258)

In [69]:
def kaiming2(x,a ,use_fan_out=False):
    nf, ni, *_ = x.shape
    rec_fs = x[0,0].shape.numel()
    fan = nf*rec_fs if use_fan_out else ni*rec_fs
    std = gain(a) / math.sqrt(fan)
    bound = math.sqrt(3.) * std
    x.data.uniform_(-bound, bound)

In [70]:
kaiming2(l1.weight, a=0)
stats(f1(x))

(tensor(0.5637, grad_fn=<MeanBackward0>),
 tensor(1.0513, grad_fn=<StdBackward0>))

In [72]:
kaiming2(l1.weight, a=math.sqrt(5.))
stats(f1(x))

(tensor(0.2148, grad_fn=<MeanBackward0>),
 tensor(0.3997, grad_fn=<StdBackward0>))

In [74]:
class Flatten(nn.Module):

    def forward(self, x):
        return x.view(-1)

In [76]:
m = nn.Sequential(
    nn.Conv2d(1,8,5,stride=2,padding=2), nn.ReLU(),
    nn.Conv2d(8,16,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(16,32,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(32,1,3,stride=2,padding=1),
    nn.AdaptiveAvgPool2d(1),
    Flatten()
)

In [77]:
y = y_valid[:100].float()

In [78]:
t = m(x)
stats(t)

(tensor(-0.0413, grad_fn=<MeanBackward0>),
 tensor(0.0077, grad_fn=<StdBackward0>))

In [81]:
l = mse(t,y)
l.backward()

In [82]:
stats(m[0].weight.grad)

(tensor(0.0102), tensor(0.0312))

Now lets try kaiming uniform on conv layers

In [85]:
init.kaiming_uniform_??

In [83]:
for l in m:
    if isinstance(l, nn.Conv2d):
        init.kaiming_uniform_(l.weight)
        l.bias.data.zero_()

In [84]:
#not great but better than before
t = m(x)
stats(t)

(tensor(-0.7845, grad_fn=<MeanBackward0>),
 tensor(0.4963, grad_fn=<StdBackward0>))

In [86]:
#now lets check backward pass
l = mse(t,y)
l.backward()
stats(m[0].weight.grad)

(tensor(0.2681), tensor(1.0013))

## Export

In [None]:
!./notebook2script.py 02a_why_sqrt5.ipynb