In [39]:
import torch as t
import torch.nn as nn
import pytorch_lightning as pl

In [69]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        self.conv1 = nn.Conv2d(3,16, kernel_size=3, stride=1, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.relu  = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc       = nn.Linear(16**3, 10)
        self.count = 0 

    def forward(self, x):
        x = self.conv1(x)
        x += 100
        if (self.count % 20 == 0) :
            print(f'{self.count}. before BatchNorm2d : ', x.view(x.size(0),-1)[0][0:10], end='\t')
        x = self.bn1(x)
        if (self.count % 20 == 0) :
            print(f' -> {self.count}. after BatchNorm2d : ', x.view(x.size(0),-1)[0][0:10])
        x = self.relu(x)
        x = self.max_pool(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        self.count += 1
        return x


In [70]:
# Basic settings
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = t.optim.SGD(model.parameters(), lr=0.01)

# get some sample data with batch size of 128
pl.seed_everything(42)
inputs = t.randn(128, 3, 32, 32)
labels = t.randint(0,10,(128,))
print(labels[:])

# training loop
for epoch in range(100):
    # forward pass 
    outputs = model(inputs)
    loss    = criterion(outputs, labels)
    
    # backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1)%10 == 0 :
        print(f"Epoch [{epoch+1}/100], Loss : {loss.item():.4f}")

Global seed set to 42


tensor([0, 7, 7, 9, 7, 8, 4, 0, 2, 3, 8, 6, 5, 4, 5, 9, 7, 8, 1, 2, 3, 0, 9, 9,
        2, 7, 5, 5, 0, 1, 1, 6, 5, 9, 5, 5, 2, 1, 1, 3, 4, 1, 0, 2, 5, 7, 1, 5,
        2, 0, 4, 5, 9, 8, 7, 5, 0, 2, 6, 9, 6, 8, 3, 9, 3, 4, 7, 8, 6, 5, 8, 0,
        2, 2, 7, 3, 6, 4, 3, 5, 2, 7, 6, 3, 1, 6, 8, 1, 6, 6, 7, 1, 4, 5, 4, 5,
        9, 0, 8, 3, 3, 3, 9, 6, 9, 1, 7, 0, 5, 6, 9, 9, 3, 8, 9, 8, 6, 6, 5, 5,
        0, 0, 3, 7, 1, 2, 6, 1])
0. before BatchNorm2d :  tensor([100.1686, 100.4076, 100.6059, 100.5386,  99.9092, 100.7679, 100.2677,
        100.6640, 100.0627, 100.1144], grad_fn=<SliceBackward0>)	 -> 0. after BatchNorm2d :  tensor([-0.0264,  0.4722,  0.8859,  0.7456, -0.5677,  1.2240,  0.1802,  1.0072,
        -0.2474, -0.1395], grad_fn=<SliceBackward0>)
Epoch [10/100], Loss : 6.3741
Epoch [20/100], Loss : 3.6105
20. before BatchNorm2d :  tensor([100.1972, 100.4443, 100.5880, 100.5302,  99.9322, 100.7420, 100.2735,
        100.6495, 100.0908, 100.1107], grad_fn=<SliceBackward0>)	 -> 20. a