In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import torch.nn.init as init

random.seed(42)
torch.manual_seed(42)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])



In [2]:
# Download and load the training and test sets
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create DataLoader for batching the data
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

In [3]:
len(trainset), len(testset)

(60000, 10000)

In [4]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Tanh()  
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        self.activations = {}  
        
        x = self.flatten(x)
        self.activations['input'] = x
        
        x = self.fc1(x)
        self.activations['fc1'] = x
        
        x = self.fc2(x)
        self.activations['fc2'] = x
        
        x = self.fc3(x)
        self.activations['fc3'] = x
        
        return x


In [5]:
def analyze_layer_statistics(model):
    
    
    stats = {}
    for layer_name, activations in model.activations.items():
        stats[layer_name] = {
            'mean': torch.mean(activations).item(),
            'var': torch.var(activations).item(),
            'min': torch.min(activations).item(),
            'max': torch.max(activations).item()
        }
    
    return stats

In [6]:
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.01)

In [7]:
epochs = 20
losses = []
stats = []
for epoch in range(epochs):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad() 
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss
        total += images.shape[0]
        correct += (torch.argmax(output,dim=1)==labels).sum()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(trainloader):.4f}')
    print(f'Accuracy = {(correct/total)*100}')
    losses.append(running_loss/len(trainloader)) 
    stats.append(analyze_layer_statistics(model))


Epoch 1/20, Loss: 0.6898
Accuracy = 79.22333526611328
Epoch 2/20, Loss: 0.5069
Accuracy = 85.09500122070312
Epoch 3/20, Loss: 0.4851
Accuracy = 85.75666809082031
Epoch 4/20, Loss: 0.4681
Accuracy = 86.30166625976562
Epoch 5/20, Loss: 0.4597
Accuracy = 86.75
Epoch 6/20, Loss: 0.4909
Accuracy = 85.83833312988281
Epoch 7/20, Loss: 0.5200
Accuracy = 85.26333618164062
Epoch 8/20, Loss: 0.4962
Accuracy = 85.413330078125
Epoch 9/20, Loss: 0.4590
Accuracy = 86.98500061035156
Epoch 10/20, Loss: 0.4539
Accuracy = 86.87833404541016
Epoch 11/20, Loss: 0.4353
Accuracy = 87.35333251953125
Epoch 12/20, Loss: 0.4409
Accuracy = 87.42166900634766
Epoch 13/20, Loss: 0.4257
Accuracy = 87.79167175292969
Epoch 14/20, Loss: 0.4070
Accuracy = 88.41500091552734
Epoch 15/20, Loss: 0.4169
Accuracy = 87.96666717529297
Epoch 16/20, Loss: 0.4125
Accuracy = 88.21499633789062
Epoch 17/20, Loss: 0.4427
Accuracy = 87.19833374023438
Epoch 18/20, Loss: 0.4120
Accuracy = 88.27166748046875
Epoch 19/20, Loss: 0.4098
Accurac

In [8]:
correct = 0
total = 0
for images,labels in trainloader:
    model.eval()
    y_pred = model(images)
    loss = criterion(y_pred, labels)
    predicted = torch.argmax(y_pred,dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum()


# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 87.19833374023438%


In [9]:
stats

[{'input': {'mean': -0.7525574564933777,
   'var': 0.36007067561149597,
   'min': -1.0,
   'max': 1.0},
  'fc1': {'mean': -2.0509326457977295,
   'var': 802.3048095703125,
   'min': -48.18707275390625,
   'max': 57.743446350097656},
  'fc2': {'mean': -0.06774401664733887,
   'var': 0.9928655624389648,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': 2.3586699962615967,
   'var': 9.788019180297852,
   'min': -5.2798051834106445,
   'max': 12.800895690917969}},
 {'input': {'mean': -0.7438746094703674,
   'var': 0.3727165460586548,
   'min': -1.0,
   'max': 1.0},
  'fc1': {'mean': -2.273608684539795,
   'var': 819.388427734375,
   'min': -104.11796569824219,
   'max': 55.68777847290039},
  'fc2': {'mean': -0.07153373211622238,
   'var': 0.9925441741943359,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': 2.1780524253845215,
   'var': 13.376968383789062,
   'min': -6.2618088722229,
   'max': 12.282782554626465}},
 {'input': {'mean': -0.7414805293083191,
   'var': 0.37564781308174133,
   

In [10]:
correct = 0
total = 0
for images,labels in testloader:
    model.eval()
    y_pred = model(images)
    loss = criterion(y_pred, labels)
    predicted = torch.argmax(y_pred,dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum()


# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 87.43000030517578%
