```python
import torch
import torch.nn.functional as F
```

## tl;dr: Kaiming init's fan_out mode is wrong for convolutions with stride > 1

In the example below, the paper predicts that the gradient variance at the input (`x.grad`) should be 1, but in reality it is 25 times smaller. The case is a bit unnatural (you probably wouldn't use a stride-5 convolution with kernel size 5), but it illustrates the point well:

```python
d, c, k = 4, 3, 5
w = torch.empty(d, c, k, k)
torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
w /= 2**.5 # divide by sqrt(2) gain, since we test on 1 layer w/out relu
var_w = 1/(k*k*d) # expected variance of w
print(w.var(), var_w)
x = torch.randn(1, c, 400, 400, requires_grad=True)
y = F.conv2d(x, w, stride=k)
# delta(y_l) is mean=0,std=1, so var(delta(x)) should be n_hat*var(w) = k^2*d*var(w)
y.backward(torch.randn_like(y))
x.grad.var(), var_w*d, var_w*k*k*d
```

```
tensor(0.0103) 0.01
(tensor(0.0415), 0.04, 1.0)
```
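The three numbers in this cell can be reproduced by hand. Below is a minimal sketch in plain Python (no torch). It assumes the paper's backward formula var(Δx) = n̂·var(w)·var(Δy) with var(Δy) = 1. The helper name `tldr_numbers` is hypothetical.

```python
import math

def tldr_numbers(d=4, k=5):
    """Hypothetical helper: expected variances for the stride=k example."""
    fan_out = d * k * k                              # fan used by mode='fan_out' for a (d, c, k, k) weight
    std_relu = math.sqrt(2.0) / math.sqrt(fan_out)   # kaiming_normal_ std with the ReLU gain sqrt(2)
    var_w = (std_relu / math.sqrt(2.0)) ** 2         # after dividing w by sqrt(2)
    paper = k * k * d * var_w                        # n_hat = k^2*d, the stride=1 assumption
    stride_k = d * var_w                             # n_hat = d when stride == k
    return var_w, paper, stride_k

var_w, paper, stride_k = tldr_numbers()
print(round(var_w, 4), round(paper, 4), round(stride_k, 4))  # 0.01 1.0 0.04
print(round(paper / stride_k))                               # 25, i.e. k*k
```

The ratio between the paper's prediction and the stride=k value is exactly $k^2$, which is where the "25 times smaller" comes from.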

With stride=1, on the other hand, everything works as expected and we get a variance of 1:

```python
x.grad.zero_()
y = F.conv2d(x, w, stride=1)
y.backward(torch.randn_like(y))
x.grad.var()
```

```
tensor(1.0119)
```

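The problem with stride=k is that in this (worst) case $\hat{n}$ equals $d$ instead of $k^2d$: each input pixel falls into exactly one window, so its gradient is a sum over only the $d$ output channels. If we account for this by using var(w) = 1/d, we get a gradient variance of 1: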

```python
var_w = 1/d # instead of 1/(k*k*d)
w.normal_(0, var_w**.5)
print(w.var(), var_w)
x = torch.randn(1, c, 400, 400, requires_grad=True)
y = F.conv2d(x, w, stride=k)
y.backward(torch.randn_like(y))
x.grad.var(), var_w*d
```

```
tensor(0.2535) 0.25
(tensor(1.0309), 1.0)
```

When 1 < stride < k, the result falls somewhere between the worst and the best cases.
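One way to interpolate between the two cases is the following extrapolation, which is ours and does not come from the paper. If we ignore border effects, each input pixel is covered by about $k/s$ windows per spatial dimension on average, so $\hat{n} \approx d\,(k/s)^2$. Here is a sketch of a fan_out-style std built on that estimate. `stride_aware_std` is a hypothetical name, and the formula is only meant for $1 \le s \le k$. Whether it actually helps training is not tested here.

```python
import math

def stride_aware_std(d, k, stride, gain=math.sqrt(2.0)):
    """Hypothetical fan_out-style std using n_hat ~= d * (k / stride)**2.

    Assumes square kernels, 1 <= stride <= k and images much larger than k
    (border pixels are ignored).
    """
    if not 1 <= stride <= k:
        raise ValueError("sketch only covers 1 <= stride <= k")
    n_hat = d * (k / stride) ** 2    # stride=1 -> k*k*d (Kaiming fan_out), stride=k -> d
    return gain / math.sqrt(n_hat)

# std for the d=4, k=5 example above at a few strides
for s in (1, 2, 3, 5):
    print(s, round(stride_aware_std(d=4, k=5, stride=s), 4))
```

With stride=1 this reduces to the usual `kaiming_normal_(mode='fan_out')` std. With stride=k=5, dividing by the ReLU gain gives std 0.5, i.e. the var(w) = 1/d = 0.25 used in the corrected cell. In torch the result could be applied with `w.normal_(0, std)`.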

## Details: when and why $\hat{n}$ is not $k^2d$

In essence, the problem is that during the backward pass the convolution stride and the image size also influence the gradient variance. The Kaiming derivation, however, implicitly assumes stride=1 and an image that is large compared to the kernel. The experiments below show that when those assumptions are violated, the Kaiming formula for the gradient variance doesn't hold.

```python
# calculate variance of y_L and delta(x1) for L conv layers
def conv_variance(L=1, imsize=100, ch=3, k=3, stride=1):
    # batch of 1 input square image, var(x1)==1
    x1 = torch.randn(1, ch, imsize, imsize, requires_grad=True)
    # for simplicity input channels c == output channels d == ch
    # (the same w is reused in every layer)
    w = torch.randn(ch, ch, k, k) # var(w)==1

    x = x1
    for _ in range(L):
        y = F.conv2d(x, w, stride=stride)
        x = F.relu(y)

    # instead of getting grads from loss func we use mean=0,std=1 grads directly
    # for last layer delta(x_L) calculation, i.e. var(delta(y_L))==1
    delta_y_L = torch.randn_like(y)
    y.backward(delta_y_L)
    return y.var().item(), x1.grad.var().item()

# run conv_variance 100 times and return mean results
def test(*args, **kws):
    N = 100
    yvar, dvar = 0., 0.
    for _ in range(N):
        yv, dv = conv_variance(*args, **kws)
        yvar += yv
        dvar += dv
    return {'var(y_L)':yvar/N, 'var(delta(x1))':dvar/N}
```

In `conv_variance()` the kernel size $k$ and the numbers of input/output channels $c$, $d$ are the same for all layers (so $n = k^2c = \hat{n} = k^2d$), and $var(x_1) = 1$, $var(\Delta y_L) = 1$, $var(w) = 1$. According to the paper, the final-layer output variance and the first-layer gradient variance must therefore be equal:

$$
var(y_L) = n\,var(w)\,var(x_1) \prod_{l=2}^{L}\frac{1}{2}n\,var(w) = n \prod_{l=2}^{L}\frac{1}{2}n = \frac{n^L}{2^{L-1}}
$$

$$
var(\Delta x_1) = \hat{n}\,var(w)\,var(\Delta y_L) \prod_{l=1}^{L-1}\frac{1}{2}\hat{n}\,var(w) = \hat{n} \prod_{l=1}^{L-1}\frac{1}{2}\hat{n} = \frac{\hat{n}^L}{2^{L-1}} = \frac{n^L}{2^{L-1}}
$$

For $L=1$: $var(y_L) = var(\Delta x_1) = n = k^2c$.
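To see how quickly this prediction grows with depth, here is a tiny helper that evaluates $n^L/2^{L-1}$. The name `kaiming_predicted_var` is hypothetical. It gives the reference values used in the cell comments below (n=27, n=12, and $27^3/4$ for L=3).

```python
def kaiming_predicted_var(n, L):
    """n**L / 2**(L-1): predicted var(y_L) and var(delta(x1)) when var(w)=1."""
    return n ** L / 2 ** (L - 1)

print(kaiming_predicted_var(27, 1))  # 27.0
print(kaiming_predicted_var(12, 1))  # 12.0
print(kaiming_predicted_var(27, 3))  # 4920.75
```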

```python
test(L=1, imsize=100, ch=3, k=3) # n=27
```

```
{'var(y_L)': 27.046004705429077, 'var(delta(x1))': 25.973986530303954}
```

The variances above are close, but the gradient variance is consistently a bit smaller. That's because only interior input pixels receive $k^2d$ gradient contributions. Pixels near the border are covered by fewer output windows, so they receive fewer contributions and their gradient variance is less than $k^2d$. This pulls the overall gradient variance below $k^2d$.
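The border effect can be quantified. With stride 1 and no padding, a row of length $H$ is covered by $H-k+1$ windows of size $k$, so an average pixel is covered by $(H-k+1)k/H$ windows per dimension. Our own back-of-the-envelope estimate, which is not from the paper, is therefore $\hat{n}_{eff} \approx d\,((H-k+1)k/H)^2$. `effective_n_hat` is a hypothetical helper:

```python
def effective_n_hat(d, k, imsize):
    """Average number of gradient terms per input pixel for a stride-1,
    unpadded k x k convolution on an imsize x imsize image with d output channels."""
    out = imsize - k + 1               # output size per spatial dimension
    mean_cover = out * k / imsize      # windows covering an average pixel, per dimension
    return d * mean_cover ** 2         # 2D coverage is the product of the two dimensions

for imsize, k in [(100, 3), (1000, 3), (1000, 2), (7, 3), (3, 3)]:
    print(imsize, k, round(effective_n_hat(d=3, k=k, imsize=imsize), 2))
```

This gives about 25.9, 26.9, 12.0, 13.8 and 3.0 for the five stride-1 configurations tested in this section. Those values are close to the measured var(delta(x1)) values, since var(w) = var(Δy) = 1.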

With a larger image relative to the kernel, border pixels become a negligible fraction, so the var(y_L) and var(delta(x1)) values are even closer:

```python
test(L=1, imsize=1000, ch=3, k=3) # n=27
```

```
{'var(y_L)': 26.860750179290772, 'var(delta(x1))': 26.754636344909667}
```

They are closer still with a smaller kernel:

```python
test(L=1, imsize=1000, ch=3, k=2) # n=12
```

```
{'var(y_L)': 12.026659684181213, 'var(delta(x1))': 11.999356188774108}
```

But when the kernel size is comparable to the image size, only a few input pixels near the image center get $k^2d$ gradient contributions, so the real $\hat{n}$ lies between $d$ and $k^2d$:

```python
test(L=1, imsize=7, ch=3, k=3) # n=27
```

```
{'var(y_L)': 27.256405773162843, 'var(delta(x1))': 13.555154795646667}
```
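A small pure-Python sketch makes this visible. It counts, for every input pixel, how many $k \times k$ output windows (stride 1, no padding) include it. For imsize=7 and k=3, only the central 3x3 block reaches the full $k^2 = 9$. The helper `coverage_map` is hypothetical.

```python
def coverage_map(imsize, k):
    """Count how many k x k windows (stride 1, no padding) cover each pixel of an imsize x imsize image."""
    out = imsize - k + 1
    # per-dimension count: how many window starts o satisfy o <= i < o + k
    row = [sum(1 for o in range(out) if o <= i < o + k) for i in range(imsize)]
    # windows factorize over rows and columns, so the 2D count is a product
    return [[r * c for c in row] for r in row]

for line in coverage_map(imsize=7, k=3):
    print(line)
```

The mean of this map is $(15/7)^2 \approx 4.6$. Multiplied by $d = 3$, that gives about 13.8, in line with the measured 13.56.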

In the extreme case where the image size equals the kernel size, each input pixel is used by exactly one output pixel. Its gradient is then a single sum over the $d$ output channels with variance $d$, hence the real $\hat{n}=d$:

```python
test(L=1, imsize=3, ch=3, k=3) # n=27, real n_hat=3
```

```
{'var(y_L)': 27.152843496501447, 'var(delta(x1))': 3.3247227012366056}
```

The same happens when the stride equals the kernel size, since each input pixel is then used only once:

```python
test(L=1, imsize=100, ch=3, k=3, stride=3) # n=27, real n_hat=3
```

```
{'var(y_L)': 26.94449602127075, 'var(delta(x1))': 2.9340525698661803}
```

When the stride is greater than 1 but smaller than the kernel size, input pixels are used more than once but, on average, fewer than $k^2$ times. They therefore get more than $d$ but fewer than $k^2d$ gradient contributions:

```python
# n=4*4*3=48, real n_hat lies between 3 (stride=4) and about 48 (stride=1)
print(test(L=1, imsize=100, ch=3, k=4, stride=4))
print(test(L=1, imsize=100, ch=3, k=4, stride=3))
print(test(L=1, imsize=100, ch=3, k=4, stride=2))
print(test(L=1, imsize=100, ch=3, k=4, stride=1))
```

```
{'var(y_L)': 47.99311561584473, 'var(delta(x1))': 3.0068949675559997}
{'var(y_L)': 47.557743053436276, 'var(delta(x1))': 5.1832145404815675}
{'var(y_L)': 47.56187725067139, 'var(delta(x1))': 11.376869382858276}
{'var(y_L)': 47.71765846252441, 'var(delta(x1))': 44.87692974090576}
```
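The same counting extends to strides. Per dimension, the number of windows covering pixel $i$ depends on the stride, and the 2D count is the product of the row and column counts. The sketch below prints the start of the per-row pattern and the implied $\hat{n} = d \cdot \text{mean(count)}$ for k=4 on a 100-pixel image. The estimate is ours, not the paper's, and `stride_coverage` is a hypothetical helper.

```python
def stride_coverage(imsize, k, stride):
    """Per-dimension count of windows (no padding) covering each pixel."""
    out = (imsize - k) // stride + 1
    return [sum(1 for o in range(out) if o * stride <= i < o * stride + k)
            for i in range(imsize)]

d, k, imsize = 3, 4, 100
for s in (4, 3, 2, 1):
    row = stride_coverage(imsize, k, s)
    mean = sum(row) / imsize           # average windows per pixel, one dimension
    n_hat = d * mean ** 2              # 2D count is row count * column count
    print(s, row[:8], round(n_hat, 2))
```

The implied $\hat{n}$ values come out at about 3.0, 5.2, 11.5 and 45.2. The measured values are 3.0, 5.2, 11.4 and 44.9. With stride 3 the pattern is irregular (`[1, 1, 1, 2, 1, 1, 2, 1]`), which is why only some pixels are used more than once.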


With several layers all of the above still holds, but the results get much noisier:

```python
# predicted var = n**L/2**(L-1) = 27**3/4 ≈ 4920; with stride the real value lies roughly in (3**3/4, 27**3/4) ≈ (7, 4920)
L = 3
print(test(L, imsize=500, ch=3, k=3, stride=3))
print(test(L, imsize=500, ch=3, k=3, stride=2))
print(test(L, imsize=500, ch=3, k=3, stride=1))
```

```
{'var(y_L)': 4315.193284301758, 'var(delta(x1))': 6.674200213253498}
{'var(y_L)': 4921.608842010498, 'var(delta(x1))': 79.8527949142456}
{'var(y_L)': 5323.199635009765, 'var(delta(x1))': 6091.3666088867185}
```
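The per-layer counting can be combined with the paper's backward recursion. This is our own extension, not something the paper states. Each layer contributes its own effective $\hat{n}_l$, computed from that layer's input size, and ReLU contributes a factor $1/2$ between layers. `multilayer_grad_var` is a hypothetical helper. It assumes var(w) = 1, var(Δy_L) = 1 and no padding.

```python
def multilayer_grad_var(L, imsize, d, k, stride):
    """Predicted var(delta(x1)) for L unpadded conv+ReLU layers with var(w)=1 and
    var(delta(y_L))=1, using a per-layer n_hat that accounts for stride and borders."""
    pred, size = 1.0, imsize
    for _ in range(L):
        out = (size - k) // stride + 1     # spatial output size of this layer
        mean_cover = out * k / size        # average windows per input pixel, one dimension
        pred *= d * mean_cover ** 2        # n_hat_l * var(w), with var(w) = 1
        size = out                         # next layer's input size
    return pred * 0.5 ** (L - 1)           # ReLU halves the gradient variance between layers

for s in (3, 2, 1):
    print(s, round(multilayer_grad_var(L=3, imsize=500, d=3, k=3, stride=s), 1))
```

For the three cells above this gives roughly 6.4, 73 and 4800. The measured values deviate from these more than in the single-layer case, which fits the observation that multi-layer runs are much noisier. Only rough agreement should be expected.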
