In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import torch.nn.init as init

# Seed Python's and PyTorch's random number generators with the same value.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with mean 0.1307 and standard deviation 0.3081 (single channel).
NORM_MEAN = (0.1307,)
NORM_STD = (0.3081,)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)



In [2]:
# Fetch MNIST (downloading into DATA_ROOT when it is not already there) and
# apply the shared preprocessing pipeline to both splits.
DATA_ROOT = './data'
BATCH_SIZE = 128

trainset = datasets.MNIST(
    root=DATA_ROOT, train=True, download=True, transform=transform
)
testset = datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=transform
)

# Wrap the datasets in loaders: the training batches are reshuffled on every
# pass, the test batches keep their original order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False
)

In [3]:
# Sanity check: number of examples in the training and test splits.
len(trainset), len(testset)

(60000, 10000)

In [4]:
class SimpleNN(nn.Module):
    """Small fully connected classifier that records its intermediate activations.

    Architecture: Flatten -> Linear(28*28, 256) -> Tanh -> Linear(256, 10).

    Attributes:
        activations: dict mapping ``'input'``, ``'fc1'``, ``'fc2'`` and
            ``'fc3'`` to the tensors produced at each stage of the most
            recent forward pass. It is empty until ``forward`` has been
            called. The tensors are stored as-is (not detached or copied).
    """

    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        # NOTE: despite the "fc" name this is the Tanh non-linearity, not a
        # linear layer. The name is kept so existing references keep working.
        self.fc2 = nn.Tanh()
        self.fc3 = nn.Linear(256, 10)
        # Create the attribute up front so it can be read (as an empty dict)
        # before the first forward pass instead of raising AttributeError.
        self.activations = {}

    def forward(self, x):
        """Run the network on ``x`` and record every intermediate tensor.

        Args:
            x: input batch; everything after the first dimension is flattened
                and must contain 28 * 28 values per example.

        Returns:
            The output of the final linear layer (10 values per example).
        """
        # Start from a fresh dict so results of an earlier call never linger.
        self.activations = {}

        x = self.flatten(x)
        self.activations['input'] = x

        x = self.fc1(x)
        self.activations['fc1'] = x

        x = self.fc2(x)
        self.activations['fc2'] = x

        x = self.fc3(x)
        self.activations['fc3'] = x

        return x


In [5]:
def analyze_layer_statistics(model):
    """Summarise the activations a model recorded on its latest forward pass.

    Args:
        model: object exposing an ``activations`` mapping of layer name to
            tensor.

    Returns:
        dict mapping each layer name to a dict with the keys ``'mean'``,
        ``'var'``, ``'min'`` and ``'max'``, each a Python float computed over
        all elements of that layer's tensor. ``'var'`` uses ``torch.var``
        with its default settings. An empty dict is returned when the model
        has no recorded activations.
    """
    # Tolerate a model that has not recorded anything yet.
    recorded = getattr(model, 'activations', None) or {}

    stats = {}
    # These reductions are purely diagnostic, so keep them out of autograd.
    with torch.no_grad():
        for layer_name, activations in recorded.items():
            values = activations.detach()
            stats[layer_name] = {
                'mean': torch.mean(values).item(),
                'var': torch.var(values).item(),
                'min': torch.min(values).item(),
                'max': torch.max(values).item()
            }

    return stats

In [6]:
# Instantiate the network, then the loss function and the Adam optimiser
# that updates all of the network's parameters.
learning_rate = 0.01
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
epochs = 20
losses = []  # mean training loss of each epoch, as plain Python floats
stats = []   # activation statistics captured at the end of each epoch
for epoch in range(epochs):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float: adding the loss tensor itself would keep
        # each batch's autograd history attached to the running total.
        running_loss += loss.item()
        total += images.shape[0]
        correct += (torch.argmax(output, dim=1) == labels).sum().item()
    epoch_loss = running_loss / len(trainloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    print(f'Accuracy = {(correct/total)*100}')
    losses.append(epoch_loss)
    # model.activations is replaced on every forward pass, so this snapshot
    # describes only the last batch of the epoch.
    stats.append(analyze_layer_statistics(model))
        

Epoch 1/20, Loss: 0.2962
Accuracy = 90.98500061035156
Epoch 2/20, Loss: 0.2355
Accuracy = 92.88333129882812
Epoch 3/20, Loss: 0.2208
Accuracy = 93.33000183105469
Epoch 4/20, Loss: 0.2121
Accuracy = 93.6500015258789
Epoch 5/20, Loss: 0.2138
Accuracy = 93.56333923339844
Epoch 6/20, Loss: 0.2077
Accuracy = 93.69999694824219
Epoch 7/20, Loss: 0.1949
Accuracy = 94.2066650390625
Epoch 8/20, Loss: 0.1864
Accuracy = 94.37999725341797
Epoch 9/20, Loss: 0.1937
Accuracy = 94.24832916259766
Epoch 10/20, Loss: 0.1980
Accuracy = 94.16666412353516
Epoch 11/20, Loss: 0.1884
Accuracy = 94.53499603271484
Epoch 12/20, Loss: 0.1775
Accuracy = 94.82666778564453
Epoch 13/20, Loss: 0.1731
Accuracy = 94.90167236328125
Epoch 14/20, Loss: 0.1683
Accuracy = 95.0566635131836
Epoch 15/20, Loss: 0.1662
Accuracy = 95.07833099365234
Epoch 16/20, Loss: 0.1736
Accuracy = 95.00666809082031
Epoch 17/20, Loss: 0.1665
Accuracy = 95.22000122070312
Epoch 18/20, Loss: 0.1633
Accuracy = 95.17166900634766
Epoch 19/20, Loss: 0.1

In [8]:
# Show the activation statistics collected during training: one entry per
# epoch, each mapping a layer name to its mean / var / min / max.
stats

[{'input': {'mean': -0.022650903090834618,
   'var': 0.9482970833778381,
   'min': -0.4242129623889923,
   'max': 2.821486711502075},
  'fc1': {'mean': -0.9084437489509583,
   'var': 586.65478515625,
   'min': -108.02886962890625,
   'max': 106.68702697753906},
  'fc2': {'mean': -0.049198683351278305,
   'var': 0.9744032621383667,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': 0.12721586227416992,
   'var': 17.290164947509766,
   'min': -9.230281829833984,
   'max': 14.586481094360352}},
 {'input': {'mean': -0.008559945039451122,
   'var': 0.9816018342971802,
   'min': -0.4242129623889923,
   'max': 2.821486711502075},
  'fc1': {'mean': -0.7454853653907776,
   'var': 1410.2916259765625,
   'min': -152.87429809570312,
   'max': 133.42295837402344},
  'fc2': {'mean': -0.02232307195663452,
   'var': 0.9862335920333862,
   'min': -1.0,
   'max': 1.0},
  'fc3': {'mean': -0.12379715591669083,
   'var': 20.760194778442383,
   'min': -10.868167877197266,
   'max': 16.228788375854492}},
 {'i

In [9]:
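# Disabled variant of the training loop: additionally prints the gradient
# norm of every parameter once every 100 batches.
# epochs = 20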
# losses = []
# for epoch in range(epochs):
#     model.train()
#     total = 0
#     correct = 0
#     running_loss = 0.0
#     for idx, (images, labels) in enumerate(trainloader):
#         optimizer.zero_grad()
#         output = model(images)
#         loss = criterion(output, labels)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss
#         total += images.shape[0]
#         correct += (torch.argmax(output,dim=1)==labels).sum()
#         if idx%100==0:
#             print(f'idx= {idx}')
#             for name, param in model.named_parameters():
#                 if param.grad is not None:
#                     print(f'{name} - Grad Norm: {param.grad.norm():.4f}')
#     print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(trainloader):.4f}')
#     print(f'Accuracy = {(correct/total)*100}')
#     losses.append(running_loss/len(trainloader)) 


In [10]:
# Accuracy of the model on the training set.
correct = 0
total = 0
model.eval()  # set evaluation mode once instead of on every batch
with torch.no_grad():  # evaluation needs no gradients, so skip graph building
    for images, labels in trainloader:
        y_pred = model(images)
        predicted = torch.argmax(y_pred, dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python int.
        correct += (predicted == labels).sum().item()


# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 95.84666442871094%


In [12]:
# Accuracy of the model on the held-out test set.
correct = 0
total = 0
model.eval()  # set evaluation mode once instead of on every batch
with torch.no_grad():  # evaluation needs no gradients, so skip graph building
    for images, labels in testloader:
        y_pred = model(images)
        predicted = torch.argmax(y_pred, dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python int.
        correct += (predicted == labels).sum().item()


# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
   

Accuracy: 95.08000183105469%
