In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

## Xavier initialization

* normal distribution

$$
\mathbf{W} \sim \mathcal{N}(0,std^2(\mathbf{W}))\\
std(\mathbf{W}) = \sqrt{\frac{2}{n_{in}+n_{out}}}
$$

* uniform distribution

$$
\mathbf{W} \sim \mathcal{U}(-\sqrt{\frac{6}{n_{in}+n_{out}}}, \sqrt{\frac{6}{n_{in}+n_{out}}})
$$

In [2]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# for reproducibility: seed every random number generator used below
torch.manual_seed(777)
np.random.seed(777)
if device == 'cuda':
    # also seed the generators of all CUDA devices
    torch.cuda.manual_seed_all(777)

In [3]:
# Training configuration shared by both experiments in this notebook.
lr = 1e-3              # learning rate passed to the optimizer
training_epochs = 15   # number of full passes over the training loader
batch_size = 100       # samples per mini-batch in both data loaders

In [4]:
# MNIST train / test splits; ToTensor turns each image into a tensor.
# download=True fetches the files into `root` only when they are not already
# present, so the notebook also runs on a machine without a local copy
# (without it, the constructor raises when the data is missing).
dataset_train = torchvision.datasets.MNIST(root='', train=True, download=True,
                                           transform=transforms.ToTensor())
dataset_test = torchvision.datasets.MNIST(root='', train=False, download=True,
                                          transform=transforms.ToTensor())

In [5]:
# Training batches: reshuffled every epoch; the trailing partial batch (if any)
# is dropped so every optimizer step sees exactly `batch_size` samples.
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation batches: every test sample must be scored exactly once, so nothing
# is dropped, and the order is irrelevant, so no shuffling is needed.
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, drop_last=False)

In [6]:
# Fully connected classifier for flattened 28x28 images:
# 784 -> 256 -> 256 -> 10, with a ReLU after each hidden layer.
# The last layer has no activation, so the network outputs raw scores.
model = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
model = model.to(device)

In [7]:
# Xavier initialization

def xavier(m):
    """Re-draw the weight of a linear layer with Xavier (Glorot) uniform init.

    Intended to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and the bias of a linear layer keeps
    whatever values it already had.

    Args:
        m: any module; only ``nn.Linear`` instances are modified (in place).

    Returns:
        None.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight.data)

# Run `xavier` on every submodule of `model`; as the cell's last expression,
# the returned model is also displayed.
model.apply(xavier)

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=10, bias=True)
)

In [8]:
# Train `model` with Adam and cross-entropy, reporting per-epoch loss and
# accuracy on both the training and the test loader.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

print('Start training...')

for epoch in range(1, training_epochs+1):

    # ---- training pass ----
    model.train()
    running_loss = 0.0       # sum of per-sample losses seen this epoch
    running_correct = 0      # number of correctly classified samples
    n_train = 0              # number of samples seen this epoch

    for X, y in dataloader_train:
        # flatten each 28x28 image into a 784-vector
        X, y = X.view(-1,28*28).to(device), y.to(device)
        y_pred = model(X)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is the batch mean, so weight it by the batch size; the epoch
        # average then stays correct even when batches differ in size.
        n_batch = y.size(0)
        running_loss += loss.item() * n_batch
        # tensor .sum().item() yields a Python int (the builtin sum() would
        # iterate the tensor element by element and accumulate a tensor)
        running_correct += (y_pred.argmax(dim=1) == y).sum().item()
        n_train += n_batch

    # ---- evaluation pass ----
    model.eval()
    running_test_loss = 0.0
    running_test_correct = 0
    n_test = 0

    with torch.no_grad():
        for X, y in dataloader_test:
            X, y = X.view(-1,28*28).to(device), y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)

            n_batch = y.size(0)
            running_test_loss += loss.item() * n_batch
            running_test_correct += (y_pred.argmax(dim=1) == y).sum().item()
            n_test += n_batch

    # Accuracy is reported in percent of samples actually seen. Dividing the
    # correct count by the number of batches only gives a percentage when the
    # batch size happens to be exactly 100.
    print(f'epoch {epoch}'
          f'    training loss:{running_loss/n_train:.3f}'
          f'    training acc:{100*running_correct/n_train:.3f}'
          f'    test loss:{running_test_loss/n_test:.3f}'
          f'    test acc:{100*running_test_correct/n_test:.3f}')

print('Finished Training')

Start training...
epoch 1    training loss:0.243    training acc:92.885    test loss:0.107    test acc:96.710
epoch 2    training loss:0.091    training acc:97.200    test loss:0.095    test acc:96.940
epoch 3    training loss:0.061    training acc:98.060    test loss:0.078    test acc:97.700
epoch 4    training loss:0.043    training acc:98.603    test loss:0.094    test acc:97.240
epoch 5    training loss:0.031    training acc:99.010    test loss:0.072    test acc:97.850
epoch 6    training loss:0.025    training acc:99.157    test loss:0.078    test acc:97.700
epoch 7    training loss:0.022    training acc:99.265    test loss:0.087    test acc:97.730
epoch 8    training loss:0.018    training acc:99.393    test loss:0.075    test acc:98.050
epoch 9    training loss:0.017    training acc:99.398    test loss:0.081    test acc:97.910
epoch 10    training loss:0.014    training acc:99.522    test loss:0.091    test acc:97.750
epoch 11    training loss:0.012    training acc:99.608    tes

## He initialization

* normal distribution

$$
\mathbf{W} \sim \mathcal{N}(0,std^2(\mathbf{W}))\\
std(\mathbf{W}) = \sqrt{\frac{2}{n_{in}}}
$$

* uniform distribution

$$
\mathbf{W} \sim \mathcal{U}(-\sqrt{\frac{6}{n_{in}}}, \sqrt{\frac{6}{n_{in}}})
$$

In [9]:
# Fresh, untrained copy of the same 784 -> 256 -> 256 -> 10 ReLU network,
# so this experiment does not start from already-trained weights.
model = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
model = model.to(device)

In [10]:
# He initialization

def he(m):
    """Re-draw the weight of a linear layer with He (Kaiming) uniform init.

    Intended to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and the bias of a linear layer keeps
    whatever values it already had. ``kaiming_uniform_`` is called with its
    default arguments.

    Args:
        m: any module; only ``nn.Linear`` instances are modified (in place).

    Returns:
        None.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight.data)

# Run `he` on every submodule of `model`; as the cell's last expression,
# the returned model is also displayed.
model.apply(he)

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=10, bias=True)
)

In [11]:
# Train `model` with Adam and cross-entropy, reporting per-epoch loss and
# accuracy on both the training and the test loader.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

print('Start training...')

for epoch in range(1, training_epochs+1):

    # ---- training pass ----
    model.train()
    running_loss = 0.0       # sum of per-sample losses seen this epoch
    running_correct = 0      # number of correctly classified samples
    n_train = 0              # number of samples seen this epoch

    for X, y in dataloader_train:
        # flatten each 28x28 image into a 784-vector
        X, y = X.view(-1,28*28).to(device), y.to(device)
        y_pred = model(X)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is the batch mean, so weight it by the batch size; the epoch
        # average then stays correct even when batches differ in size.
        n_batch = y.size(0)
        running_loss += loss.item() * n_batch
        # tensor .sum().item() yields a Python int (the builtin sum() would
        # iterate the tensor element by element and accumulate a tensor)
        running_correct += (y_pred.argmax(dim=1) == y).sum().item()
        n_train += n_batch

    # ---- evaluation pass ----
    model.eval()
    running_test_loss = 0.0
    running_test_correct = 0
    n_test = 0

    with torch.no_grad():
        for X, y in dataloader_test:
            X, y = X.view(-1,28*28).to(device), y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)

            n_batch = y.size(0)
            running_test_loss += loss.item() * n_batch
            running_test_correct += (y_pred.argmax(dim=1) == y).sum().item()
            n_test += n_batch

    # Accuracy is reported in percent of samples actually seen. Dividing the
    # correct count by the number of batches only gives a percentage when the
    # batch size happens to be exactly 100.
    print(f'epoch {epoch}'
          f'    training loss:{running_loss/n_train:.3f}'
          f'    training acc:{100*running_correct/n_train:.3f}'
          f'    test loss:{running_test_loss/n_test:.3f}'
          f'    test acc:{100*running_test_correct/n_test:.3f}')

print('Finished Training')

Start training...
epoch 1    training loss:0.235    training acc:93.008    test loss:0.109    test acc:96.640
epoch 2    training loss:0.089    training acc:97.347    test loss:0.081    test acc:97.570
epoch 3    training loss:0.059    training acc:98.210    test loss:0.078    test acc:97.510
epoch 4    training loss:0.041    training acc:98.733    test loss:0.068    test acc:97.910
epoch 5    training loss:0.031    training acc:98.988    test loss:0.066    test acc:97.920
epoch 6    training loss:0.023    training acc:99.245    test loss:0.064    test acc:98.010
epoch 7    training loss:0.022    training acc:99.235    test loss:0.081    test acc:97.710
epoch 8    training loss:0.016    training acc:99.483    test loss:0.077    test acc:97.940
epoch 9    training loss:0.016    training acc:99.445    test loss:0.071    test acc:98.130
epoch 10    training loss:0.015    training acc:99.468    test loss:0.076    test acc:98.070
epoch 11    training loss:0.012    training acc:99.610    tes

* Open question: why does PyTorch use a different default initialization?