In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.utils.data import DataLoader

In [2]:
BATCH_SIZE = 64  # mini-batch size shared by the train and test loaders

# 1 Data transform: ToTensor, then Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# 2 Training set: reshuffled every epoch
trainset = torchvision.datasets.MNIST(root="./data", train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

# 3 Test set: evaluation gains nothing from shuffling, so use a fixed, deterministic order
testset = torchvision.datasets.MNIST(root="./data", train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [22]:
class NeuralNetwork(nn.Module):
    """Four-layer fully connected classifier for flattened images.

    Architecture: ``in_features -> h1 -> h2 -> h3 -> num_classes``, with a ReLU
    after each hidden layer and raw logits at the output. The layers are
    always exposed as ``fc1`` .. ``fc4`` because other code accesses them by
    those attribute names.

    Args:
        hidden_sizes: widths ``(h1, h2, h3)`` of the three hidden layers.
            The default ``(512, 256, 128)`` matches the original model. Smaller
            widths let you build a model whose layer sizes match a pruned one.
        in_features: number of values per flattened input sample
            (28 * 28 = 784 by default).
        num_classes: number of output logits (10 by default).

    Raises:
        ValueError: if ``hidden_sizes`` does not contain exactly three widths.
    """

    def __init__(self, hidden_sizes=(512, 256, 128), in_features=28 * 28, num_classes=10):
        super().__init__()
        if len(hidden_sizes) != 3:
            raise ValueError(f"expected 3 hidden sizes, got {len(hidden_sizes)}")
        h1, h2, h3 = hidden_sizes

        self.fc1 = nn.Linear(in_features, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        Every dimension after the first is flattened, so both
        ``(batch, 1, 28, 28)`` and ``(batch, 784)`` inputs are accepted.
        """
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

In [28]:
# One zero-initialised accumulator per hidden layer, sized to match
# fc1 (512), fc2 (256) and fc3 (128) outputs, in that order.
hidden_widths = (512, 256, 128)
all_activation = [torch.zeros(width).to(device) for width in hidden_widths]


In [29]:
SEED = 0
# Seed the global torch RNG before building the model so the weight
# initialisation is reproducible across kernel restarts. Full bit-exactness on
# GPU is not guaranteed by seeding alone.
torch.manual_seed(SEED)

net = NeuralNetwork().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)
epoch_size = 5  # number of training epochs

In [30]:
net.train()

for epoch in range(epoch_size):
    epoch_loss = 0.0          # sum of per-sample losses over the epoch
    correct_predictions = 0
    total_predictions = 0
    for X, y in trainloader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        preds = net(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()

        # .item() converts the loss to a plain float. Summing the tensor itself
        # would chain every batch's autograd history together for the epoch.
        # The loss is a per-batch mean, so weight it by batch size: the last
        # batch can be smaller than the others.
        n_samples = y.size(0)
        epoch_loss += loss.item() * n_samples

        # Accuracy of the logits computed before this step's weight update.
        predicted = preds.argmax(dim=1)
        correct_predictions += (predicted == y).sum().item()
        total_predictions += n_samples

    accuracy = correct_predictions / total_predictions
    average_loss = epoch_loss / total_predictions

    print(f"Epoch {epoch+1} Loss: {average_loss:.4f} Accuracy: {accuracy:.4f}")
    


Epoch 1 Loss: 0.3178 Accuracy: 0.8993
Epoch 2 Loss: 0.1442 Accuracy: 0.9555
Epoch 3 Loss: 0.1085 Accuracy: 0.9661
Epoch 4 Loss: 0.0913 Accuracy: 0.9715
Epoch 5 Loss: 0.0775 Accuracy: 0.9757


In [43]:
# Mean post-ReLU activation of every hidden neuron over the whole training set.
# The accumulators are rebuilt here, sized from the trained layers, so that
# re-running this cell recomputes the averages rather than adding onto a
# previous run's results.
all_activation = [torch.zeros(layer.weight.size(0), device=device)
                  for layer in (net.fc1, net.fc2, net.fc3)]

was_training = net.training
net.eval()
n_seen = 0
# Inference only. Without no_grad, the in-place sums would become autograd
# nodes holding every batch's saved activations in memory.
with torch.no_grad():
    for inputs, _ in trainloader:  # for every batch
        inputs = inputs.to(device)

        # Recompute the hidden layers explicitly to capture each activation.
        activations_fc1 = torch.relu(net.fc1(inputs.view(-1, 28 * 28)))
        activations_fc2 = torch.relu(net.fc2(activations_fc1))
        activations_fc3 = torch.relu(net.fc3(activations_fc2))

        # Accumulate per-neuron sums over the batch dimension.
        all_activation[0] += activations_fc1.sum(dim=0)
        all_activation[1] += activations_fc2.sum(dim=0)
        all_activation[2] += activations_fc3.sum(dim=0)
        n_seen += inputs.size(0)
net.train(was_training)

# Divide by the true number of samples. The last batch may be smaller than the
# loader's batch size, so len(trainloader) * batch_size would overcount.
all_activation = [total / n_seen for total in all_activation]

In [44]:
# Build the model to be pruned as an independent copy of the trained `net`.
# load_state_dict copies tensor values into new_net's own parameters.
# Assigning `net.fcK.weight` / `.bias` directly would instead make both
# models share the same Parameter objects, so updating one would change
# the other.
new_net = NeuralNetwork().to(device)
new_net.load_state_dict(net.state_dict())

In [45]:
# Pruning cut-off: a hidden neuron is kept only if its mean activation is
# strictly greater than this value.
threshold: float = 0.4

In [46]:
# Prune the first hidden layer: keep only neurons whose mean activation
# over the training set is strictly greater than `threshold`.
keep_neurons_l1 = all_activation[0] > threshold
assert keep_neurons_l1.any(), "threshold would remove every neuron in hidden layer 1"

# Rows of fc1 (and entries of its bias) correspond to hidden-layer-1 neurons.
new_net.fc1.weight = nn.Parameter(new_net.fc1.weight[keep_neurons_l1])
new_net.fc1.bias = nn.Parameter(new_net.fc1.bias[keep_neurons_l1])

# Columns of fc2 are the inputs it receives from hidden-layer-1 neurons.
new_net.fc2.weight = nn.Parameter(new_net.fc2.weight[:, keep_neurons_l1])

# nn.Linear stores its sizes separately from the weight tensor. Keep them in
# sync so that printing the module reflects the pruned shapes.
new_net.fc1.out_features = new_net.fc1.weight.size(0)
new_net.fc2.in_features = new_net.fc2.weight.size(1)

In [48]:
## Second hidden layer: rows of fc2 / its bias, columns of fc3
keep_neurons_l2 = all_activation[1] > threshold
assert keep_neurons_l2.any(), "threshold would remove every neuron in hidden layer 2"

new_net.fc2.weight = nn.Parameter(new_net.fc2.weight[keep_neurons_l2])
new_net.fc3.weight = nn.Parameter(new_net.fc3.weight[:, keep_neurons_l2])

new_net.fc2.bias = nn.Parameter(new_net.fc2.bias[keep_neurons_l2])

## Third hidden layer: rows of fc3 / its bias, columns of fc4
keep_neurons_l3 = all_activation[2] > threshold
assert keep_neurons_l3.any(), "threshold would remove every neuron in hidden layer 3"

new_net.fc3.weight = nn.Parameter(new_net.fc3.weight[keep_neurons_l3])
new_net.fc4.weight = nn.Parameter(new_net.fc4.weight[:, keep_neurons_l3])

new_net.fc3.bias = nn.Parameter(new_net.fc3.bias[keep_neurons_l3])

# Keep nn.Linear's stored sizes consistent with the new weight shapes.
for layer in (new_net.fc2, new_net.fc3, new_net.fc4):
    layer.out_features, layer.in_features = layer.weight.shape

In [50]:
# Report the trained model's parameter shapes: all weights first, then all biases.
original_layers = [net.fc1, net.fc2, net.fc3, net.fc4]
for param_name in ("weight", "bias"):
    for lin in original_layers:
        print(getattr(lin, param_name).data.shape)

torch.Size([512, 784])
torch.Size([256, 512])
torch.Size([128, 256])
torch.Size([10, 128])
torch.Size([512])
torch.Size([256])
torch.Size([128])
torch.Size([10])


In [51]:
print(new_net.fc1.weight.data.shape)
print(new_net.fc2.weight.data.shape)
print(new_net.fc3.weight.data.shape)
print(new_net.fc4.weight.data.shape)

print(new_net.fc1.bias.data.shape)
print(new_net.fc2.bias.data.shape)
print(new_net.fc3.bias.data.shape)
print(new_net.fc4.bias.data.shape)

torch.Size([132, 784])
torch.Size([164, 132])
torch.Size([97, 164])
torch.Size([10, 97])
torch.Size([132])
torch.Size([164])
torch.Size([97])
torch.Size([10])
