In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import torch.nn.init as init

# Seed PyTorch's and Python's RNGs so weight initialisation and batch
# shuffling are repeatable from a fresh kernel.
torch.manual_seed(42)
random.seed(42)

# Preprocessing applied to every image: convert to a tensor, then apply
# (x - 0.5) / 0.5 on the single grayscale channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)



In [2]:
# Fetch MNIST into ./data (downloading on first use): the training split
# first, then the held-out test split, both with the shared `transform`.
trainset, testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Wrap each split in a DataLoader yielding batches of 128; only the training
# split is reshuffled on every pass.
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=128, shuffle=reshuffle)
    for split, reshuffle in ((trainset, True), (testset, False))
)

In [3]:
# Sanity check: number of examples in (train, test).
tuple(map(len, (trainset, testset)))

(60000, 10000)

In [4]:
class SimpleNN(nn.Module):
    """One-hidden-layer MLP (784 -> 256 -> tanh -> 10) that records its activations.

    Every call to ``forward`` replaces ``self.activations`` with a dict holding
    the tensor seen at each stage of that call, keyed ``'input'``, ``'fc1'``,
    ``'fc2'`` and ``'fc3'`` (in that order). The stored tensors are detached
    from the autograd graph: they are meant for inspection only, and keeping
    graph-attached tensors around would hold the last batch's graph alive.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Tanh()  # NOTE: this is the non-linearity, not a linear layer
        self.fc3 = nn.Linear(256, 10)
        # Xavier-uniform weights; fc1 feeds the tanh, so it uses the tanh gain.
        init.xavier_uniform_(self.fc1.weight, gain=nn.init.calculate_gain('tanh'))
        init.xavier_uniform_(self.fc3.weight)
        # Defined at construction so that reading the activations before the
        # first forward pass yields an empty dict instead of AttributeError.
        self.activations = {}

    def forward(self, x):
        """Run the network on ``x`` and return the 10 output logits per sample.

        Side effect: ``self.activations`` is overwritten with detached copies
        of the flattened input, the fc1 pre-activation, the tanh output and
        the final logits from this call.
        """
        recorded = {}

        x = self.flatten(x)
        recorded['input'] = x.detach()

        x = self.fc1(x)
        recorded['fc1'] = x.detach()

        x = self.fc2(x)
        recorded['fc2'] = x.detach()

        x = self.fc3(x)
        recorded['fc3'] = x.detach()

        # Publish all four entries at once so readers never see a partial pass.
        self.activations = recorded
        return x


In [5]:
def analyze_layer_statistics(model):
    """Summarise the tensors currently stored in ``model.activations``.

    Args:
        model: any object exposing an ``activations`` mapping of
            layer name -> tensor.

    Returns:
        dict mapping each layer name to a dict with the Python-float entries
        ``'mean'``, ``'var'``, ``'min'`` and ``'max'``, each computed over all
        elements of that layer's tensor.
    """
    # (output key, whole-tensor reduction) pairs, in the order they are reported.
    reducers = (
        ('mean', torch.mean),
        ('var', torch.var),
        ('min', torch.min),
        ('max', torch.max),
    )
    return {
        name: {key: reduce(tensor).item() for key, reduce in reducers}
        for name, tensor in model.activations.items()
    }

In [6]:
# Fresh network, its classification objective, and an Adam optimiser over
# all of the network's parameters with step size 0.01.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

In [7]:
# Train for a fixed number of epochs, recording per-epoch loss and a snapshot
# of the activation statistics.
epochs = 20
losses = []  # mean per-batch training loss of each epoch, as plain floats
stats = []   # one analyze_layer_statistics(model) result per epoch
for epoch in range(epochs):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float. Adding the loss tensor itself would
        # carry autograd history across the whole epoch and leave `losses`
        # holding graph-attached tensors.
        running_loss += loss.item()
        total += images.shape[0]
        # Same reasoning: keep the running count as a Python int.
        correct += (torch.argmax(output, dim=1) == labels).sum().item()
    epoch_loss = running_loss / len(trainloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    print(f'Accuracy = {(correct/total)*100}')
    losses.append(epoch_loss)
    # model.activations is overwritten on every forward pass, so this snapshot
    # describes only the final batch of the epoch.
    stats.append(analyze_layer_statistics(model))


Epoch 1/20, Loss: 0.7544
Accuracy = 78.99500274658203
Epoch 2/20, Loss: 0.4993
Accuracy = 84.97166442871094
Epoch 3/20, Loss: 0.4939
Accuracy = 85.3316650390625
Epoch 4/20, Loss: 0.4953
Accuracy = 85.71666717529297
Epoch 5/20, Loss: 0.4712
Accuracy = 86.28333282470703
Epoch 6/20, Loss: 0.4636
Accuracy = 86.48999786376953
Epoch 7/20, Loss: 0.4530
Accuracy = 87.1066665649414
Epoch 8/20, Loss: 0.4488
Accuracy = 87.15666198730469
Epoch 9/20, Loss: 0.4626
Accuracy = 86.74832916259766
Epoch 10/20, Loss: 0.4434
Accuracy = 87.45500183105469
Epoch 11/20, Loss: 0.4208
Accuracy = 88.20499420166016
Epoch 12/20, Loss: 0.4249
Accuracy = 88.038330078125
Epoch 13/20, Loss: 0.4314
Accuracy = 87.6116714477539
Epoch 14/20, Loss: 0.4177
Accuracy = 88.19000244140625
Epoch 15/20, Loss: 0.4210
Accuracy = 88.19000244140625
Epoch 16/20, Loss: 0.4032
Accuracy = 88.61333465576172
Epoch 17/20, Loss: 0.4284
Accuracy = 87.76499938964844
Epoch 18/20, Loss: 0.4263
Accuracy = 88.00833129882812
Epoch 19/20, Loss: 0.408

In [8]:
# Score the trained model on the training set.
model.eval()  # set once up front; it does not need re-setting per batch
correct = 0
total = 0
# Scoring needs no gradients, so skip building an autograd graph per batch.
with torch.no_grad():
    for images, labels in trainloader:
        y_pred = model(images)
        predicted = torch.argmax(y_pred, dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python int.
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 89.53333282470703%


In [9]:
# Display the activation summaries appended by the training loop: one dict per
# epoch, keyed by layer name, each holding mean / var / min / max.
stats

[{'input': {'mean': -0.7240942120552063,
   'var': 0.40124067664146423,
   'min': -1.0,
   'max': 1.0},
  'fc1': {'mean': 1.050910234451294,
   'var': 967.3980102539062,
   'min': -79.095947265625,
   'max': 80.490966796875},
  'fc2': {'mean': 0.02612701617181301,
   'var': 0.997160017490387,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': 5.061877727508545,
   'var': 9.07724380493164,
   'min': -2.8012540340423584,
   'max': 14.708564758300781}},
 {'input': {'mean': -0.7244036197662354,
   'var': 0.39743906259536743,
   'min': -1.0,
   'max': 1.0},
  'fc1': {'mean': 0.9943134784698486,
   'var': 984.3416748046875,
   'min': -102.2292709350586,
   'max': 80.68025970458984},
  'fc2': {'mean': 0.028667941689491272,
   'var': 0.9977889657020569,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': 5.142202377319336,
   'var': 13.980228424072266,
   'min': -4.393506050109863,
   'max': 15.703813552856445}},
 {'input': {'mean': -0.7228939533233643,
   'var': 0.4012758433818817,
   'min': -1.

In [10]:
# Score the trained model on the held-out test set.
model.eval()  # set once up front; it does not need re-setting per batch
correct = 0
total = 0
# Scoring needs no gradients, so skip building an autograd graph per batch.
with torch.no_grad():
    for images, labels in testloader:
        y_pred = model(images)
        predicted = torch.argmax(y_pred, dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python int.
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 89.2300033569336%
