## Accelerate Inference: Neural Network Pruning

In [None]:
import os
import numpy as np
import cv2
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# untar
!ls
!tar -xvzf dataset.tar.gz
# load train
train_images = pickle.load(open('train_images.pkl', 'rb'))
train_labels = pickle.load(open('train_labels.pkl', 'rb'))
# load val
val_images = pickle.load(open('val_images.pkl', 'rb'))
val_labels = pickle.load(open('val_labels.pkl', 'rb'))

dataset.tar.gz	sample_data  train_images.pkl
train_images.pkl
train_labels.pkl
val_images.pkl
val_labels.pkl


In [None]:
# Convert the images to float32 tensors and move the last axis to position 1,
# i.e. right after the batch axis where nn.Conv2d expects the channels
# (presumably the pickles store images as (N, H, W, C) -- TODO confirm).
train_images = torch.tensor(train_images, dtype=torch.float32).permute(0, 3, 1, 2)
val_images = torch.tensor(val_images, dtype=torch.float32).permute(0, 3, 1, 2)

In [None]:
# Pair every image tensor with its label. squeeze() drops size-1 axes from the
# label arrays (presumably stored as (N, 1) -- TODO confirm) and the labels are
# cast to torch.long so they can be used as class indices by the loss.
train_dataset = TensorDataset(train_images,
                              torch.tensor(train_labels.squeeze(), dtype=torch.long))
val_dataset = TensorDataset(val_images,
                            torch.tensor(val_labels.squeeze(), dtype=torch.long))

In [None]:
# Mini-batches of 32 samples; only the training loader shuffles its data.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [None]:
class ConvNet(nn.Module):
    """Small CNN: two convolutional blocks followed by a two-layer classifier.

    The network returns raw class scores (logits). No softmax layer is
    applied, so it is meant to be trained with a loss that works on logits,
    such as ``nn.CrossEntropyLoss``.

    The first fully connected layer is fixed at 1024 = 64 * 4 * 4 inputs,
    which is what the convolutional stack produces for 3x25x25 inputs.

    Args:
        num_classes: Number of output scores. Defaults to 5.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # NOTE: the layer order (and therefore the indices inside the
        # Sequential) determines the state_dict keys, e.g. 'model.0.weight'.
        self.model = nn.Sequential(
            # First block: Conv -> ReLU -> Conv -> ReLU -> MaxPool -> Dropout
            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=0, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            # Second block: Conv -> ReLU -> Conv -> ReLU -> MaxPool -> Dropout
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=0, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            # Flatten layer
            nn.Flatten(),

            # Fully connected block: Dense -> ReLU -> Dropout -> Dense
            # (outputs logits; there is no softmax layer)
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch ``x`` of images."""
        return self.model(x)

In [None]:
# Build the network, the loss, and the optimizer used for training.
model = ConvNet()

# CrossEntropyLoss is applied to the raw scores returned by the model.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6)

In [None]:
# Move the model to the selected device, then print a per-layer summary
# (output shapes and parameter counts) for a single 3x25x25 input.
model = model.to(device)
summary(model, input_size=(3, 25, 25))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 25, 25]             896
              ReLU-2           [-1, 32, 25, 25]               0
            Conv2d-3           [-1, 32, 23, 23]           9,248
              ReLU-4           [-1, 32, 23, 23]               0
         MaxPool2d-5           [-1, 32, 11, 11]               0
           Dropout-6           [-1, 32, 11, 11]               0
            Conv2d-7           [-1, 64, 11, 11]          18,496
              ReLU-8           [-1, 64, 11, 11]               0
            Conv2d-9             [-1, 64, 9, 9]          36,928
             ReLU-10             [-1, 64, 9, 9]               0
        MaxPool2d-11             [-1, 64, 4, 4]               0
          Dropout-12             [-1, 64, 4, 4]               0
          Flatten-13                 [-1, 1024]               0
           Linear-14                  [

In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(train_loss, train_accuracy)``: the mean of the per-batch
        loss values and the percentage of correctly classified samples.
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    correct = 0
    total = 0

    # Progress bar for the training loop
    progress = tqdm(train_loader, desc="Training", leave=False)

    for batches_seen, (inputs, labels) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # running_loss accumulates one value per batch, so the running
        # average divides by the number of batches seen. Dividing by the
        # sample count would under-report the loss by about the batch size
        # and disagree with the value returned below.
        progress.set_postfix(loss=running_loss / batches_seen,
                             accuracy=100 * correct / total)

    train_accuracy = 100 * correct / total
    train_loss = running_loss / len(train_loader)
    return train_loss, train_accuracy

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(val_loss, val_accuracy)``: the mean of the per-batch loss
        values and the percentage of correctly classified samples.
    """
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0

    # Progress bar for the validation loop
    progress = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():  # Disable gradient calculations for validation
        for batches_seen, (inputs, labels) in enumerate(progress, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Track loss and accuracy
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # running_loss accumulates one value per batch, so the running
            # average divides by the number of batches seen. Dividing by the
            # sample count would under-report the loss by about the batch
            # size and disagree with the value returned below.
            progress.set_postfix(loss=running_loss / batches_seen,
                                 accuracy=100 * correct / total)

    val_accuracy = 100 * correct / total
    val_loss = running_loss / len(val_loader)
    return val_loss, val_accuracy

In [None]:
# Main training loop
# Each epoch runs one training pass followed by one validation pass and
# prints the resulting loss and accuracy for both.
num_epochs = 50
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Training
    train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)

    # Validation
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)

    # Print epoch results
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

Epoch 1/50




Epoch [1/50], Train Loss: 1.5401, Train Acc: 28.82%, Val Loss: 1.4229, Val Acc: 38.42%
Epoch 2/50




Epoch [2/50], Train Loss: 1.3875, Train Acc: 39.51%, Val Loss: 1.3339, Val Acc: 41.98%
Epoch 3/50




Epoch [3/50], Train Loss: 1.3250, Train Acc: 43.45%, Val Loss: 1.2868, Val Acc: 43.76%
Epoch 4/50




KeyboardInterrupt: 

In [None]:
# Save only the parameters (state_dict), not the whole module object.
# NOTE(review): _use_new_zipfile_serialization=False appears to select
# PyTorch's legacy (non-zip) checkpoint format -- confirm this is required.
torch.save(model.state_dict(), 'my_model_weights_1.pt', _use_new_zipfile_serialization=False)

### Loading the saved model

In [None]:
# map_location places the tensors on the device in use, so a checkpoint that
# was saved on a GPU machine can still be restored on a CPU-only one.
state_dict = torch.load('my_model_weights_1.pt', map_location=device, weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [None]:
# Sanity check: shape of the first conv layer's weight tensor.
print(model.model[0].weight.data.size())

torch.Size([32, 3, 3, 3])


### L1 Pruner

In [None]:
def l1_prune(model, prune_percent):
    """Zero the lowest-L1-norm filters of every conv layer in ``model.model``.

    The layers in ``model.model`` are visited in order. For each
    ``nn.Conv2d`` the ``int(prune_percent * out_channels)`` filters with the
    smallest L1 norm are set to zero (weights and bias). The input channels
    of the next conv layer that read those filters are zeroed as well; for
    the first ``nn.Linear`` after the conv layers the corresponding input
    columns are zeroed instead. A gradient hook is registered on every
    touched parameter so the zeroed entries receive zero gradient.

    The channel-to-column mapping assumes the features entering the linear
    layer are a flattened ``(C, H, W)`` block with the channel axis first,
    i.e. each channel owns ``in_features // C`` consecutive columns.

    Args:
        model: Object whose ``model`` attribute iterates over the layers in
            forward order.
        prune_percent: Fraction in ``[0, 1]`` of the filters to prune in each
            conv layer (``0.3`` prunes 30%).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If ``prune_percent`` is outside ``[0, 1]`` or a linear
            layer's ``in_features`` is not a multiple of the preceding conv
            layer's ``out_channels``.
    """
    if not 0.0 <= prune_percent <= 1.0:
        raise ValueError(f"prune_percent must be in [0, 1], got {prune_percent}")

    # Hook is made to make sure we don't retrain these weights after zeroing out
    def make_hook(mask):
        def hook(grad):
            return mask * grad
        return hook

    pruned = None    # filter indices pruned in the most recent conv layer
    out_chan = None  # out_channels of the most recent conv layer

    for module in model.model:
        if isinstance(module, nn.Conv2d):
            weight = module.weight
            bias = module.bias
            # ones_like already creates the masks on the parameter's device.
            weight_mask = torch.ones_like(weight)
            bias_mask = None if bias is None else torch.ones_like(bias)
            num_pruned = int(prune_percent * module.out_channels)

            with torch.no_grad():
                # Zero out pruned inputs. Input channels live on axis 1, so
                # the mask has to be cleared on axis 1 as well.
                if pruned is not None:
                    weight[:, pruned] = 0.0
                    weight_mask[:, pruned] = 0.0

                # Rank the filters by L1 norm (after the input pruning above).
                l1_vals = weight.abs().flatten(1).sum(dim=1)
                pruned = torch.topk(l1_vals, num_pruned, largest=False).indices

                weight[pruned] = 0.0
                weight_mask[pruned] = 0.0
                if bias is not None:
                    bias[pruned] = 0.0
                    bias_mask[pruned] = 0.0

            weight.register_hook(make_hook(weight_mask))
            if bias is not None:
                bias.register_hook(make_hook(bias_mask))

            out_chan = module.out_channels

        elif isinstance(module, nn.Linear) and pruned is not None:
            if module.in_features % out_chan != 0:
                raise ValueError(
                    f"in_features ({module.in_features}) is not a multiple of "
                    f"the preceding conv out_channels ({out_chan})"
                )
            chan_size = module.in_features // out_chan

            # Channel c owns columns [c * chan_size, (c + 1) * chan_size).
            offsets = torch.arange(chan_size, device=pruned.device)
            columns = (pruned.unsqueeze(1) * chan_size + offsets).flatten()

            weight_mask = torch.ones_like(module.weight)
            with torch.no_grad():
                module.weight[:, columns] = 0.0
            weight_mask[:, columns] = 0.0
            module.weight.register_hook(make_hook(weight_mask))

            # Later linear layers do not read conv channels directly.
            pruned = None

    return model


In [None]:
# Prune 30% of the filters in every conv layer; l1_prune modifies the model
# in place and returns the same object.
model = l1_prune(model, 0.3)

In [None]:
# Use the device selected earlier instead of hard-coding 'cuda', so this cell
# also runs on machines without a GPU.
model.to(device)

ConvNet(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): ReLU()
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=1024, out_features=512, bias=True)
    (14): ReLU()
    (15): Dropout(p=0.5, inplace=False)
    (16): Linear(in_features=512, out_features=5, bias=True)
  )
)

In [None]:
# Validation loss and accuracy of the current model, measured before the
# retraining loop is run.
val_loss, val_accuracy = validate(model, val_loader, criterion, device)
print(val_loss, val_accuracy)

                                                                                        

1.1742661542530302 53.386138613861384




In [None]:
# Retraining loop
# Start from a fresh optimizer. The hooks installed by l1_prune only zero the
# *gradients* of pruned weights, while Adam keeps running moment estimates per
# parameter. Reusing an optimizer that has already taken steps would let those
# stale moments move the pruned weights away from zero in the first updates.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6)

num_epochs = 40
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Training
    train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)

    # Validation
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)

    # Print epoch results
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

Epoch 1/40




Epoch [1/40], Train Loss: 0.9260, Train Acc: 64.27%, Val Loss: 0.9374, Val Acc: 62.53%
Epoch 2/40




Epoch [2/40], Train Loss: 0.9306, Train Acc: 63.48%, Val Loss: 0.9311, Val Acc: 62.65%
Epoch 3/40




Epoch [3/40], Train Loss: 0.9322, Train Acc: 63.84%, Val Loss: 0.9217, Val Acc: 63.09%
Epoch 4/40




Epoch [4/40], Train Loss: 0.9218, Train Acc: 64.13%, Val Loss: 0.9221, Val Acc: 63.13%
Epoch 5/40




Epoch [5/40], Train Loss: 0.9042, Train Acc: 64.64%, Val Loss: 0.9019, Val Acc: 63.49%
Epoch 6/40




Epoch [6/40], Train Loss: 0.8925, Train Acc: 65.38%, Val Loss: 0.8949, Val Acc: 64.44%
Epoch 7/40




Epoch [7/40], Train Loss: 0.8833, Train Acc: 65.78%, Val Loss: 0.8799, Val Acc: 65.58%
Epoch 8/40




Epoch [8/40], Train Loss: 0.8692, Train Acc: 66.50%, Val Loss: 0.8717, Val Acc: 64.75%
Epoch 9/40




Epoch [9/40], Train Loss: 0.8608, Train Acc: 66.99%, Val Loss: 0.8574, Val Acc: 66.06%
Epoch 10/40




Epoch [10/40], Train Loss: 0.8581, Train Acc: 66.78%, Val Loss: 0.8641, Val Acc: 65.82%
Epoch 11/40




Epoch [11/40], Train Loss: 0.8487, Train Acc: 66.96%, Val Loss: 0.8588, Val Acc: 65.27%
Epoch 12/40




Epoch [12/40], Train Loss: 0.8409, Train Acc: 67.51%, Val Loss: 0.8585, Val Acc: 65.62%
Epoch 13/40




Epoch [13/40], Train Loss: 0.8320, Train Acc: 67.66%, Val Loss: 0.8506, Val Acc: 66.26%
Epoch 14/40




Epoch [14/40], Train Loss: 0.8289, Train Acc: 68.21%, Val Loss: 0.8521, Val Acc: 66.18%
Epoch 15/40




Epoch [15/40], Train Loss: 0.8148, Train Acc: 68.20%, Val Loss: 0.8599, Val Acc: 65.78%
Epoch 16/40




Epoch [16/40], Train Loss: 0.8037, Train Acc: 69.21%, Val Loss: 0.8321, Val Acc: 66.53%
Epoch 17/40




Epoch [17/40], Train Loss: 0.8082, Train Acc: 68.93%, Val Loss: 0.8350, Val Acc: 67.09%
Epoch 18/40




Epoch [18/40], Train Loss: 0.7968, Train Acc: 69.25%, Val Loss: 0.8341, Val Acc: 66.85%
Epoch 19/40




Epoch [19/40], Train Loss: 0.7997, Train Acc: 69.33%, Val Loss: 0.8334, Val Acc: 67.64%
Epoch 20/40




Epoch [20/40], Train Loss: 0.7885, Train Acc: 69.59%, Val Loss: 0.8250, Val Acc: 67.80%
Epoch 21/40




Epoch [21/40], Train Loss: 0.7860, Train Acc: 69.75%, Val Loss: 0.8369, Val Acc: 66.93%
Epoch 22/40




Epoch [22/40], Train Loss: 0.7818, Train Acc: 70.25%, Val Loss: 0.8224, Val Acc: 66.73%
Epoch 23/40




Epoch [23/40], Train Loss: 0.7734, Train Acc: 70.22%, Val Loss: 0.8193, Val Acc: 67.80%
Epoch 24/40




Epoch [24/40], Train Loss: 0.7737, Train Acc: 70.49%, Val Loss: 0.8217, Val Acc: 67.56%
Epoch 25/40




Epoch [25/40], Train Loss: 0.7690, Train Acc: 70.74%, Val Loss: 0.8126, Val Acc: 68.12%
Epoch 26/40




Epoch [26/40], Train Loss: 0.7643, Train Acc: 70.88%, Val Loss: 0.8279, Val Acc: 67.37%
Epoch 27/40




Epoch [27/40], Train Loss: 0.7570, Train Acc: 70.89%, Val Loss: 0.8273, Val Acc: 67.60%
Epoch 28/40




Epoch [28/40], Train Loss: 0.7518, Train Acc: 70.92%, Val Loss: 0.8114, Val Acc: 68.44%
Epoch 29/40




Epoch [29/40], Train Loss: 0.7448, Train Acc: 71.60%, Val Loss: 0.8074, Val Acc: 68.44%
Epoch 30/40




Epoch [30/40], Train Loss: 0.7457, Train Acc: 71.33%, Val Loss: 0.8148, Val Acc: 67.64%
Epoch 31/40




Epoch [31/40], Train Loss: 0.7390, Train Acc: 71.93%, Val Loss: 0.8191, Val Acc: 67.49%
Epoch 32/40




Epoch [32/40], Train Loss: 0.7355, Train Acc: 72.01%, Val Loss: 0.8001, Val Acc: 68.63%
Epoch 33/40




Epoch [33/40], Train Loss: 0.7336, Train Acc: 72.07%, Val Loss: 0.8036, Val Acc: 68.63%
Epoch 34/40




Epoch [34/40], Train Loss: 0.7293, Train Acc: 72.33%, Val Loss: 0.8042, Val Acc: 68.32%
Epoch 35/40




Epoch [35/40], Train Loss: 0.7222, Train Acc: 72.17%, Val Loss: 0.8084, Val Acc: 68.20%
Epoch 36/40




Epoch [36/40], Train Loss: 0.7256, Train Acc: 72.28%, Val Loss: 0.8002, Val Acc: 68.75%
Epoch 37/40




Epoch [37/40], Train Loss: 0.7187, Train Acc: 72.35%, Val Loss: 0.7984, Val Acc: 68.75%
Epoch 38/40




Epoch [38/40], Train Loss: 0.7165, Train Acc: 72.67%, Val Loss: 0.8063, Val Acc: 68.44%
Epoch 39/40




Epoch [39/40], Train Loss: 0.7054, Train Acc: 73.10%, Val Loss: 0.8021, Val Acc: 68.71%
Epoch 40/40


                                                                                        

Epoch [40/40], Train Loss: 0.7114, Train Acc: 72.87%, Val Loss: 0.8014, Val Acc: 68.36%




In [None]:
def model_sparsity(model):
    """Return the fraction of parameter entries in ``model`` that equal zero.

    Every tensor yielded by ``model.parameters()`` is counted, so biases are
    included alongside weights.
    """
    # One (entry count, zero count) pair per parameter tensor.
    counts = [
        (tensor.numel(), (tensor.data == 0.0).sum().item())
        for tensor in model.parameters()
    ]
    n_entries = sum(size for size, _ in counts)
    n_zeros = sum(zeros for _, zeros in counts)
    return n_zeros / n_entries

In [None]:
# Fraction of all parameter entries (weights and biases) that are exactly zero.
model_sparsity(model)

0.31733433625721624

In [None]:
# Save the parameters of the pruned and retrained model (state_dict only).
# NOTE(review): "lr" in the file name may be a typo for "l1" -- confirm before
# renaming, since other code may load this path.
torch.save(model.state_dict(), 'my_model_weights_lr_prune_30.pt', _use_new_zipfile_serialization=False)