In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
import torchsummary

As it turns out, quite serendipitously, batch normalization conveys all three benefits: preprocessing, numerical stability, and regularization.

Fixing a trained model, you might think that we would prefer using the entire dataset to estimate the mean and variance. Once training is complete, why would we want the same image to be classified differently, depending on the batch in which it happens to reside? During training, such exact calculation is infeasible because the intermediate variables for all data examples change every time we update our model. However, once the model is trained, we can calculate the means and variances of each layer’s variables based on the entire dataset. Indeed this is standard practice for models employing batch normalization; thus batch normalization layers function differently in training mode (normalizing by minibatch statistics) than in prediction mode (normalizing by dataset statistics). In this form they closely resemble the behavior of dropout regularization of Section 5.6, where noise is only injected during training.

For convolutional layers, the key difference from batch normalization in fully connected layers is that we apply the operation on a per-channel basis across all locations: each channel gets a single mean and variance (and a single scale and shift), computed over every example in the minibatch and every spatial position. This is compatible with our assumption of translation invariance that led to convolutions: we assumed that the specific location of a pattern within an image was not critical for the purpose of understanding.

## **Data**
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz  
The loading cell below uses `download=False`, so these four files must already be present in `{root}\FashionMNIST\raw` (here `root='./data'`).

In [2]:
# Upscale every image to 224x224, then convert it to a tensor.
trans = transforms.Compose([
    transforms.Resize((224, 224)),  # upscale
    transforms.ToTensor(),
])

# Build the train split first, then the validation (test) split.
# The raw files must already exist under ./data because download=False.
data_train, data_val = (
    torchvision.datasets.FashionMNIST(
        root='./data', train=is_train, transform=trans, download=False
    )
    for is_train in (True, False)
)

In [3]:
# Display the training split's summary: number of datapoints, root, split and transforms.
data_train

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [4]:
# Display the validation (test) split's summary: number of datapoints, root, split and transforms.
data_val

Dataset FashionMNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [5]:
# Peek at the first training example; the image tensor is laid out as (channel, height, width).
image, label = data_train[0]
print(image.shape, label, sep='\n')

torch.Size([1, 224, 224])
9


## **LeNet with BN**

In [6]:
class BNLenet(nn.Module):
    """LeNet-style CNN with batch normalization after every conv/linear layer
    except the final classifier.

    Every layer with weights is lazy, so input channels and the flattened
    feature size are inferred on the first forward pass.

    Args:
        num_classes: number of output logits produced by the last linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        layers = []
        # Two convolutional stages: conv -> BN -> sigmoid -> 2x2 average pool.
        for channels in (6, 16):
            layers += [
                nn.LazyConv2d(out_channels=channels, kernel_size=5),
                nn.LazyBatchNorm2d(),
                nn.Sigmoid(),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ]
        layers.append(nn.Flatten())
        # Two hidden fully connected stages: linear -> BN -> sigmoid.
        for width in (120, 84):
            layers += [nn.LazyLinear(width), nn.LazyBatchNorm1d(), nn.Sigmoid()]
        # Classifier head: raw logits, no normalization or activation.
        layers.append(nn.LazyLinear(num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Return the class logits for the input batch ``X``."""
        return self.net(X)

In [7]:
# Layer-by-layer output shapes and parameter counts for a 1x224x224 input.
# device="cpu" is passed explicitly: torchsummary defaults to device="cuda" and,
# when a GPU is available, would build its dummy input on CUDA while this freshly
# constructed model lives on the CPU, causing a device-mismatch error.
torchsummary.summary(BNLenet(), input_size=(1, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 6, 220, 220]             156
       BatchNorm2d-2          [-1, 6, 220, 220]              12
           Sigmoid-3          [-1, 6, 220, 220]               0
         AvgPool2d-4          [-1, 6, 110, 110]               0
            Conv2d-5         [-1, 16, 106, 106]           2,416
       BatchNorm2d-6         [-1, 16, 106, 106]              32
           Sigmoid-7         [-1, 16, 106, 106]               0
         AvgPool2d-8           [-1, 16, 53, 53]               0
           Flatten-9                [-1, 44944]               0
           Linear-10                  [-1, 120]       5,393,400
      BatchNorm1d-11                  [-1, 120]             240
          Sigmoid-12                  [-1, 120]               0
           Linear-13                   [-1, 84]          10,164
      BatchNorm1d-14                   



## **Training**

In [8]:
# Minibatches of 128 examples; only the training split is reshuffled each epoch.
batch_size = 128
train_loader, val_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    for dataset, is_train in ((data_train, True), (data_val, False))
)

In [9]:
model = BNLenet()
# BNLenet is built entirely from lazy layers, whose parameters only get their
# shapes on the first forward pass. Run one dummy 1x1x224x224 batch now, so the
# optimizer created next receives fully materialized parameters.
# eval() + no_grad keeps this dry run from updating the BatchNorm running
# statistics and from building an autograd graph; train() restores training mode.
model.eval()
with torch.no_grad():
    model(torch.zeros(1, 1, 224, 224))
model.train()

In [10]:
# Plain minibatch SGD over all of the model's parameters.
optimizer = torch.optim.SGD(
    lr=0.01,
    params=list(model.parameters()),
)

In [11]:
def accuracy(y_hat, y):
    """Count the correctly classified examples in a batch.

    Despite the name, this returns the number of correct predictions, not a
    ratio; callers divide the accumulated count by the number of examples.

    Args:
        y_hat: scores of shape (B, q); the predicted class is the argmax over dim 1.
        y: integer class labels of shape (B,).

    Returns:
        A 0-dim float32 tensor holding the count of matches.
    """
    predicted = torch.argmax(y_hat, dim=1).to(y.dtype)  # (B,)
    hits = (predicted == y).float()  # (B,), 1.0 where correct
    return hits.sum()

In [None]:
%%time
for i in range(10):
    # --- Training pass: BN normalizes with minibatch statistics.
    model.train()
    train_loss, num_train_batches = 0, 0
    for b, (X, y) in enumerate(train_loader):
        y_hat = model(X)
        loss = F.cross_entropy(y_hat, y)
        optimizer.zero_grad()  # drop gradients left over from the previous step
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_train_batches = b + 1
        if not b % 10:
            print(f'epoch={i} | batch={b} | train_loss={train_loss/num_train_batches:.4f}')

    # --- Validation pass: eval() makes BN use its running statistics; no_grad skips autograd.
    model.eval()
    val_loss = val_acc = 0
    num_val_batches = total = 0
    with torch.no_grad():
        for X, y in val_loader:
            y_hat = model(X)
            loss = F.cross_entropy(y_hat, y)
            val_loss += loss.item()
            val_acc += accuracy(y_hat, y)  # accumulates a count of correct predictions
            num_val_batches += 1
            total += y.numel()

    print(f'epoch={i} | train_loss={train_loss/num_train_batches:.4f} | val_loss={val_loss/num_val_batches:.4f} | val_acc={val_acc/total:.4f}')