<a href="https://colab.research.google.com/github/kyj098707/Deep-Learning-Paeper-Review-and-Code/blob/master/2_Batch_Normalization_Experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import torchvision.datasets
import torchvision.transforms as transforms
import random

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hyperparameters shared by both models so the comparison is like-for-like.
learning_rate = 1e-3   # step size handed to the optimizers
training_epochs = 15   # number of full passes over the training set
batch_size = 100       # samples per mini-batch for the data loaders

In [None]:
# Build the MNIST train and test splits (in that order) with identical
# settings: stored under 'MNIST_data/', downloaded if missing, and with
# images converted to tensors by ToTensor().
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root='MNIST_data/',
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

In [None]:
# Mini-batch iterators over the two splits. Only the training split is
# shuffled; both drop a trailing incomplete batch so every batch holds
# exactly `batch_size` samples.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        dataset=split,
        batch_size=batch_size,
        shuffle=reshuffle,
        drop_last=True,
    )
    for split, reshuffle in ((mnist_train, True), (mnist_test, False))
)

In [None]:
## Baseline: batch normalization NOT applied
# 784 -> 32 -> 32 -> 10 fully connected network with ReLU activations.
without_BN_model = nn.Sequential(
    nn.Linear(in_features=784, out_features=32, bias=True),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=32, bias=True),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=10, bias=True),
)
without_BN_model = without_BN_model.to(device)

## Variant: batch normalization applied
# Same layer sizes, with BatchNorm1d inserted between each hidden Linear
# layer and its ReLU.
with_BN_model = nn.Sequential(
    nn.Linear(in_features=784, out_features=32, bias=True),
    nn.BatchNorm1d(num_features=32),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=32, bias=True),
    nn.BatchNorm1d(num_features=32),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=10, bias=True),
)
with_BN_model = with_BN_model.to(device)

In [None]:
# One cross-entropy loss shared by both models.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Each model gets its own Adam optimizer, configured with the same learning
# rate, so the two networks are trained independently.
BN_optimizer = torch.optim.Adam(
    params=with_BN_model.parameters(),
    lr=learning_rate,
)
WOBN_optimizer = torch.optim.Adam(
    params=without_BN_model.parameters(),
    lr=learning_rate,
)

In [None]:
# Per-epoch history. Each entry is a two-element list
# [with_batchnorm, without_batchnorm] of 0-dim tensors.
train_losses = []
train_accs = []

valid_losses = []
valid_accs = []

train_total_batch = len(train_loader)
test_total_batch = len(test_loader)


def _evaluate(loader, total_batch):
    """Return mean per-batch loss and accuracy of both models over ``loader``.

    Intended to be called with gradients disabled and the models in
    evaluation mode; the caller is responsible for both.

    Args:
        loader: Iterable yielding ``(X, Y)`` batches; ``X`` is reshaped to
            rows of 28 * 28 values before being fed to the models.
        total_batch: Number of batches used as the averaging denominator.

    Returns:
        Tuple ``(bn_loss, wobn_loss, bn_acc, wobn_acc)``. Each value is the
        sum of per-batch means divided by ``total_batch``, which equals the
        per-sample mean only when every batch has the same size.
    """
    bn_loss, wobn_loss, bn_acc, wobn_acc = 0, 0, 0, 0
    for X, Y in loader:
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        bn_prediction = with_BN_model(X)
        bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y
        bn_loss += criterion(bn_prediction, Y)
        bn_acc += bn_correct_prediction.float().mean()

        wobn_prediction = without_BN_model(X)
        wobn_correct_prediction = torch.argmax(wobn_prediction, 1) == Y
        wobn_loss += criterion(wobn_prediction, Y)
        wobn_acc += wobn_correct_prediction.float().mean()

    return (bn_loss / total_batch, wobn_loss / total_batch,
            bn_acc / total_batch, wobn_acc / total_batch)


for epoch in range(training_epochs):
    # Training mode makes BatchNorm use batch statistics and update its
    # running estimates. The baseline has no mode-dependent layers, but it is
    # switched as well so both models are always handled the same way.
    with_BN_model.train()
    without_BN_model.train()

    # Both models see exactly the same mini-batches in the same order.
    for X, Y in train_loader:
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        BN_optimizer.zero_grad()
        bn_prediction = with_BN_model(X)
        bn_loss = criterion(bn_prediction, Y)
        bn_loss.backward()
        BN_optimizer.step()

        WOBN_optimizer.zero_grad()
        wobn_prediction = without_BN_model(X)
        wobn_loss = criterion(wobn_prediction, Y)
        wobn_loss.backward()
        WOBN_optimizer.step()

    with torch.no_grad():
        # Evaluation mode makes BatchNorm use its running statistics.
        with_BN_model.eval()
        without_BN_model.eval()

        # Evaluate both models on the training set
        bn_loss, wobn_loss, bn_acc, wobn_acc = _evaluate(train_loader, train_total_batch)

        # Save train losses/acc
        train_losses.append([bn_loss, wobn_loss])
        train_accs.append([bn_acc, wobn_acc])
        print(
            '[Epoch %d-TRAIN] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): wobn_loss:%.5f(wobn_acc:%.2f)' % (
            (epoch + 1), bn_loss.item(), bn_acc.item(), wobn_loss.item(), wobn_acc.item()))

        # Evaluate both models on the test set
        bn_loss, wobn_loss, bn_acc, wobn_acc = _evaluate(test_loader, test_total_batch)

        # Save valid losses/acc
        valid_losses.append([bn_loss, wobn_loss])
        valid_accs.append([bn_acc, wobn_acc])
        print(
            '[Epoch %d-VALID] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): wobn_loss:%.5f(wobn_acc:%.2f)' % (
                (epoch + 1), bn_loss.item(), bn_acc.item(), wobn_loss.item(), wobn_acc.item()))
        print()

print('Learning finished')

[Epoch 1-TRAIN] Batchnorm Loss(Acc): bn_loss:0.16300(bn_acc:0.96) vs No Batchnorm Loss(Acc): wobn_loss:0.26204(wobn_acc:0.92)
[Epoch 1-VALID] Batchnorm Loss(Acc): bn_loss:0.16806(bn_acc:0.95) vs No Batchnorm Loss(Acc): wobn_loss:0.25685(wobn_acc:0.93)

[Epoch 2-TRAIN] Batchnorm Loss(Acc): bn_loss:0.10865(bn_acc:0.97) vs No Batchnorm Loss(Acc): wobn_loss:0.19944(wobn_acc:0.94)
[Epoch 2-VALID] Batchnorm Loss(Acc): bn_loss:0.12802(bn_acc:0.96) vs No Batchnorm Loss(Acc): wobn_loss:0.20185(wobn_acc:0.94)

[Epoch 3-TRAIN] Batchnorm Loss(Acc): bn_loss:0.08525(bn_acc:0.98) vs No Batchnorm Loss(Acc): wobn_loss:0.16570(wobn_acc:0.95)
[Epoch 3-VALID] Batchnorm Loss(Acc): bn_loss:0.10932(bn_acc:0.97) vs No Batchnorm Loss(Acc): wobn_loss:0.17320(wobn_acc:0.95)

[Epoch 4-TRAIN] Batchnorm Loss(Acc): bn_loss:0.07119(bn_acc:0.98) vs No Batchnorm Loss(Acc): wobn_loss:0.14366(wobn_acc:0.96)
[Epoch 4-VALID] Batchnorm Loss(Acc): bn_loss:0.10382(bn_acc:0.97) vs No Batchnorm Loss(Acc): wobn_loss:0.15683(wobn