<a href="https://colab.research.google.com/github/cluePrints/fastai-v3-notes/blob/master/fastai3_part2b_initializing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# multiply matrices bunch of times
# pay attention to variance
# google around kaiming_init_, kaiming_normal_

In [0]:
import torch
from torch import nn

In [143]:
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=5)
conv.weight.shape

torch.Size([1, 1, 5, 5])

In [144]:
def stats(w):
  with torch.no_grad():
    return w.mean(), w.std()

stats(conv.weight)

(tensor(0.0284), tensor(0.1052))

In [0]:
matrix = conv.weight[0,0,...]

In [146]:
stats(matrix)

(tensor(0.0284), tensor(0.1052))

In [147]:
def multiplications_until_matrix_vanishes(matrix):
  result = matrix
  with torch.no_grad():
    for i in range(0, 500):
      result = matrix @ result
      if result.std() == 0:
        return i
      
      if result.std() != result.std():
        return i
  return -1

multiplications_until_matrix_vanishes(matrix)

93

In [148]:
# TODO: try other strategies, e.g. normal, orthonormal
matrix_unif = torch.empty(dim,dim, requires_grad=False).uniform_(-0.01, 0.01)
multiplications_until_matrix_vanishes(matrix_unif)

48

In [149]:
1/math.sqrt(dim)

0.044194173824159216

In [150]:
stats(matrix_unif)

(tensor(-1.1071e-05), tensor(0.0058))

In [151]:
# kaiming init
import math
dim = 512
matrix_kaiming = torch.empty(dim,dim, requires_grad=False).uniform_(-1, 1)/math.sqrt(dim)
# Note to self: tolerated multiplications can be 100...max+, depending on the randomness
multiplications_until_matrix_vanishes(matrix_kaiming)

182

In [152]:
stats(matrix_kaiming)

(tensor(6.0093e-06), tensor(0.0255))

In [153]:
from torch.nn import init
matrix_torch_kaiming = torch.empty(dim,dim, requires_grad=False)
init.kaiming_uniform_(matrix_torch_kaiming)
multiplications_until_matrix_vanishes(matrix_torch_kaiming)

238

In [0]:
def mult_and_relu(matrix):
  result = matrix
  with torch.no_grad():
    for i in range(0, 500):
      result = matrix @ result
      result = result.clamp_min(0.)
      if result.std() == 0:
        return i
      
      if result.std() != result.std():
        return i

      
  return float('Inf')

In [155]:
mult_and_relu(matrix_unif)

41

In [156]:
mult_and_relu(matrix_kaiming)

108

In [157]:
mult_and_relu(matrix_torch_kaiming)

inf

In [158]:
relu_kaiming_adjustment = 2
matrix_kaiming_relu = torch.empty(dim,dim, requires_grad=False).uniform_(-1, 1)*(math.sqrt(relu_kaiming_adjustment/dim))
mult_and_relu(matrix_kaiming_relu)

202

In [159]:
rnd_repeats = 100
mean = 0.
squares_sum = 0.
for i in range(0, rnd_repeats):
  matrix = torch.randn(dim, dim)
  vector = torch.randn(dim)
  result = matrix @ vector
  squares_sum += result.pow(2).mean().item()
  mean += result.mean().item()

mean/rnd_repeats, squares_sum/rnd_repeats

(-0.03923895612359047, 506.34834381103514)

In [160]:
coeff = 1/math.sqrt(dim)
rnd_repeats = 100
mean = 0.
squares_sum = 0.
for i in range(0, rnd_repeats):
  matrix = torch.randn(512, 512) * coeff
  vector = torch.randn(512)
  result = matrix @ vector
  squares_sum += result.pow(2).mean().item()
  mean += result.mean().item()

mean/rnd_repeats, squares_sum/rnd_repeats

(0.0024990265956148505, 1.0058455669879913)

In [161]:
coeff_relu = math.sqrt(2/dim)
rnd_repeats = 100
mean = 0.
squares_sum = 0.
for i in range(0, rnd_repeats):
  matrix = torch.randn(512, 512) * coeff_relu
  vector = torch.randn(512)
  result = (matrix @ vector).clamp_min(0.)
  squares_sum += result.pow(2).mean().item()
  mean += result.mean().item()

mean/rnd_repeats, squares_sum/rnd_repeats

(0.5725732117891311, 1.0228411120176315)

In [0]:
# review Xavier init paper
# calculate std with a loop & see where Xavier coeff comes from
# add relu