In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# class MNIST_CNN(nn.Module):
#     def __init__(self):
#         super(MNIST_CNN, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
#         self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.fc1 = nn.Linear(64 * 7 * 7, 128)
#         self.fc2 = nn.Linear(128, 10)
#         self.dropout = nn.Dropout(0.5)

#     def forward(self, x):
#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         x = x.view(-1, 64 * 7 * 7)
#         x = F.relu(self.fc1(x))
#         x = self.dropout(x)
#         x = self.fc2(x)
#         return x

class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        
        # Convolutional layers
                            #Init_channels, channels, kernel_size, padding) 
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        
        # Pooling layers
        self.pool = nn.MaxPool2d(2,2)
        
        # FC layers
        # Linear layer (64x4x4 -> 500)
        self.fc1 = nn.Linear(64 * 4 * 4, 500)
        
        # Linear Layer (500 -> 10)
        self.fc2 = nn.Linear(500, 10)
        
        # Dropout layer
        self.dropout = nn.Dropout(0.25)
        
    def forward(self, x):
        x = self.pool(F.elu(self.conv1(x)))
        x = self.pool(F.elu(self.conv2(x)))
        x = self.pool(F.elu(self.conv3(x)))
        
        # Flatten the image
        x = x.view(-1, 64*4*4)
        x = self.dropout(x)
        x = F.elu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = CNNNet()
print(model)

CNNNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)


In [3]:
# Modify this to change the way we prune weights

def prune_weights(model, threshold):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name:
                print(f'Param max: {torch.max(param)}')
                print(f'Param min: {torch.min(param)}')
                mask = torch.abs(param) > threshold
                # mask = torch.zeros_like(param)
                param.mul_(mask)


In [3]:
# Select devices and create model

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# Load data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.CIFAR10('~/.pytorch/CIFAR10_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testset = datasets.CIFAR10('~/.pytorch/CIFAR10_data/', download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# Model, Loss, and Optimizer
# model = MNIST_CNN().to(device)
model = models.resnet50(weights=models.ResNet50_Weights(models.ResNet50_Weights.DEFAULT)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)


Using device: mps
Files already downloaded and verified
Files already downloaded and verified


In [4]:
def test_model_acc(post_prune=False):
    # Testing the model
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(100 * correct / total)
    print(f'Accuracy of the model on the 10000 test images ' + ('post' if post_prune else 'pre') + f'-pruning: {100 * correct / total}%')


In [5]:
# Training the model
# epochs = 5
# for epoch in range(epochs):
#     running_loss = 0
#     for images, labels in trainloader:
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         output = model(images)
#         loss = criterion(output, labels)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item()
#     print(f"Epoch {epoch+1} - Training loss: {running_loss/len(trainloader)}")


test_model_acc()
# Apply pruning
prune_weights(model, threshold=0.8)
test_model_acc(post_prune=True)


0.21
Accuracy of the model on the 10000 test images pre-pruning: 0.21%
Param max: 1.9789636135101318
Param min: -1.9180800914764404
Param max: 9.424724578857422
Param min: 1.4622493982315063
Param max: 1.0011051893234253
Param min: -1.512972116470337
Param max: 5.6243062019348145
Param min: 0.9666146636009216
Param max: 1.3543440103530884
Param min: -1.4747580289840698
Param max: 5.002761363983154
Param min: 0.689778745174408
Param max: 0.8202264308929443
Param min: -1.3517900705337524
Param max: 8.049958229064941
Param min: -1.5706143379211426
Param max: 1.017331838607788
Param min: -1.8782446384429932
Param max: 8.588117599487305
Param min: 0.21759121119976044
Param max: 0.8031531572341919
Param min: -1.0186502933502197
Param max: 7.427504539489746
Param min: 1.3327645063400269
Param max: 0.9823306202888489
Param min: -0.7846152186393738
Param max: 4.049496650695801
Param min: 1.242072343826294
Param max: 0.7892203330993652
Param min: -0.6373439431190491
Param max: 7.529585838317871
