## Xavier Method
Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution. The resulting tensor will have values sampled from $ \mathcal{N} (0,std^2) $ where
$$
gain\times \sqrt{\frac{2}{fan\_in-fan\_out}} 
$$

In [1]:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
import torch.nn as nn

batch_size=256

d2l.use_svg_display()
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)

train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=4)
test_iter = data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=4)

In [2]:
import torch.nn as nn

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(28*28,256),
                    nn.ReLU(),
                    nn.Linear(256,10),
                    nn.Softmax(dim=1))

def init_layer(m):
    if type(m) == nn.Linear:
        nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0)

loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(),lr=0.05)

net.apply(init_layer)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
  (4): Softmax(dim=1)
)

In [3]:
from torch.nn.functional import one_hot
epochs = 45

for epoch in range(epochs):
    for X,y in train_iter:
        l = loss(net(X),y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    with torch.no_grad():
        if (epoch+1)%5 == 0:
            print(f" epoch{epoch+1} test_acc {d2l.evaluate_accuracy_gpu(net,iter(test_iter))}")

 epoch5 test_acc 0.7677
 epoch10 test_acc 0.7866
 epoch15 test_acc 0.7971
 epoch20 test_acc 0.7999
 epoch25 test_acc 0.8029
 epoch30 test_acc 0.8065
 epoch35 test_acc 0.8083
 epoch40 test_acc 0.8093
 epoch45 test_acc 0.81


This function will mitigating gradient explossion and gradient disappearance