# The forward and backward passes 

In [4]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
#export
from exp.nb_01 import *

def get_data():
    """Fetch the gzipped MNIST pickle and return its train/valid splits as tensors.

    The file at ``MNIST_URL`` is obtained via ``datasets.download_data`` and
    unpickled. The third split stored in the pickle is discarded.

    Returns:
        A tuple ``(x_train, y_train, x_valid, y_valid)``, each converted with ``tensor``.
    """
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        # NOTE(review): pickle.load can execute arbitrary code -- only load this file from a trusted source.
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    # Materialize as a tuple: a bare `map` object is a one-shot iterator, so holding the
    # result in a variable and iterating it twice would silently yield nothing the second time.
    return tuple(map(tensor, (x_train, y_train, x_valid, y_valid)))

def normalize(x, mean, std):
    """Standardize ``x`` by subtracting ``mean`` and then dividing by ``std``."""
    centered = x - mean
    return centered / std

In [6]:
x_train, y_train, x_valid, y_valid = get_data()

In [7]:
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

(tensor(0.1304), tensor(0.3073))

In [8]:
# Normalize both splits with the training-set statistics computed in the previous cell.
# NOTE: this rebinds x_train / x_valid, so re-running this cell on its own would
# normalize already-normalized data a second time.
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [9]:
# Sanity-check the normalized training set. Fresh names are used so that train_mean /
# train_std keep holding the statistics that x_train and x_valid were normalized with.
norm_mean, norm_std = x_train.mean(), x_train.std()
norm_mean, norm_std

(tensor(-6.2598e-06), tensor(1.))

In [10]:
#export
def test_near_zero(x, tol=1e-3):
    assert x.abs()<tol, f"Near zero: {x}"

In [11]:
test_near_zero(x_train.mean())
test_near_zero(x_train.std() - 1)

In [12]:
# n = number of training examples, m = number of input features per example (columns of x_train).
n, m = x_train.shape
# c = number of classes, inferred as the largest label + 1 (assumes labels run 0..c-1).
# Note this is a 0-d tensor, not a Python int.
c = y_train.max() + 1
n, m, c

(50000, 784, tensor(10))

## Model with one hidden layer *from the foundations*

### Basic architecture

In [13]:
nh = 50

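The next cell scales the random weights by $1/\sqrt{m}$, where $m$ is the number of inputs (the fan-in). This is a simplified LeCun/Xavier-style initialization (often loosely called a simplified Kaiming/He initialization) that keeps the outputs of a purely linear layer at roughly unit variance. The ReLU-aware Kaiming/He variant, which scales by $\sqrt{2/m}$, follows further below. See the [ebook by Michael Nielsen](http://neuralnetworksanddeeplearning.com/chap3.html) for a detailed explanation of weight initialization.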

In [14]:
# Seed the RNG so the random weights -- and every statistic and near-zero check below --
# are reproducible under Restart Kernel -> Run All.
torch.manual_seed(42)
# Scale by 1/sqrt(fan_in) so that, with normalized inputs, the linear layer's outputs
# have roughly unit standard deviation.
w1 = torch.randn(m, nh)/math.sqrt(m)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1)/math.sqrt(nh)
b2 = torch.zeros(1)

In [15]:
test_near_zero(w1.mean())
test_near_zero(w1.std() - 1/math.sqrt(m))

In [16]:
def lin(x, w, b):
    """Affine layer: matrix-multiply ``x`` by ``w``, then add the bias ``b``."""
    product = x @ w
    return product + b

In [17]:
t = lin(x_valid, w1, b1)

In [18]:
t.mean(), t.std()

(tensor(0.0074), tensor(0.9375))

**Thanks to our initialization the mean and std are approximately equal to 0, 1!**

In [19]:
def relu(x):
    """Rectified linear unit: every negative entry of ``x`` becomes 0."""
    return x.clamp(min=0.)

In [20]:
t = relu(lin(x_valid, w1, b1))

In [21]:
t.mean(), t.std()

(tensor(0.3725), tensor(0.5536))

**Unfortunately the mean and std are no longer 0 and 1. That makes sense, however: the ReLU sets every negative activation to 0, which shifts the mean upwards and shrinks the spread.**

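**We have to change our initialization to account for the rectifiers. This is where He (Kaiming) init comes into play: instead of $1/\sqrt{m}$ it scales the weights by $\sqrt{2/m}$, compensating for the ReLU zeroing out roughly half of the pre-activations.**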

In [22]:
w1 = torch.randn(m, nh)*math.sqrt(2/m)

In [23]:
w1.mean(), w1.std() - math.sqrt(2/m)

(tensor(0.0004), tensor(-2.5149e-05))

In [24]:
t = relu(lin(x_valid, w1, b1))

In [25]:
t.mean(), t.std()

(tensor(0.5351), tensor(0.7962))

In [26]:
#export
from torch.nn import init

In [27]:
# kaiming_normal_ overwrites w1 in place, so the zeros are only a placeholder.
# It reads fan_in from dimension 1 and fan_out from dimension 0. Our w1 is laid out
# (m, nh) = (inputs, outputs), so mode='fan_out' picks up m, the actual number of inputs.
w1 = torch.zeros(m, nh)
init.kaiming_normal_(w1, mode='fan_out')
t = relu(lin(x_valid, w1, b1))

In [28]:
t.mean(), t.std()

(tensor(0.5450), tensor(0.8177))

In [29]:
init.kaiming_normal_??

From the documentation:

> mode: either `'fan_in'` (default) or `'fan_out'`. Choosing `'fan_in'` preserves the magnitude of the variance of the weights in the forward pass. Choosing `'fan_out'` preserves the magnitudes in the backwards pass.

Why did we use `fan_out` here instead of `fan_in`, even though our manual version above divided by the fan-in `m`?

In [30]:
w1.shape

torch.Size([784, 50])

In [31]:
torch.nn.Linear(m, nh).weight.shape

torch.Size([50, 784])

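**PyTorch's `nn.Linear` stores its weight transposed relative to ours: as `(out_features, in_features)` = `(50, 784)`, while our `w1` is `(784, 50)`. `kaiming_normal_` takes the fan-in from dimension 1 and the fan-out from dimension 0. For our layout, `mode='fan_out'` therefore picks up dimension 0, which is `m` = 784, the actual number of inputs. This is why we use `fan_out` instead of `fan_in`.**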

In [32]:
torch.nn.Linear.forward??

In [33]:
torch.nn.functional.linear??

**The standard deviation is now much closer to 1. Can we also bring the mean down to 0?**

In [41]:
def relu(x, shift=0.5):
    """ReLU followed by a constant downward shift: ``max(x, 0) - shift``.

    Redefines (shadows) the plain ``relu`` defined earlier; ``model`` below picks up
    this version. Subtracting ``shift`` (default 0.5, the value used in this notebook)
    pulls the mean of the hidden activations back toward 0. Pass ``shift=0.`` for a
    plain ReLU.
    """
    return x.clamp_min(0.) - shift

In [60]:
w1 = torch.randn(m, nh)*math.sqrt(2./m)
t1 = relu(lin(x_valid, w1, b1))
t1.mean(), t1.std()

(tensor(0.0070), tensor(0.7867))

**Closer to mean 0, std 1!**

In [45]:
def model(x):
    """Forward pass of the one-hidden-layer network: lin -> relu -> lin.

    Reads the module-level ``w1, b1, w2, b2`` and ``relu`` at call time, so the
    result depends on whatever those names are bound to when the model is called.
    """
    hidden = relu(lin(x, w1, b1))
    return lin(hidden, w2, b2)

In [46]:
%timeit -n 10 _ = model(x_valid)

3.71 ms ± 257 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [49]:
assert model(x_valid).shape == torch.Size([x_valid.shape[0], 1])