# Daily Dose of Data Science

This notebook accompanies the code for our model compression blog.

Read the full blog here: [Machine Learning Model Compression: A Critical Step Towards Efficient Deep Learning](https://www.dailydoseofds.com/model-compression-a-critical-step-towards-efficient-machine-learning)

Author: Avi Chawla

# Imports

In [1]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import pandas as pd

from time import time
from tqdm import tqdm
from torch.utils.data import DataLoader

# Load the MNIST dataset

In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# Knowledge Distillation

## Teacher Model

In [3]:
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(5, 5)
        self.fc1 = nn.Linear(32 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x        

## Evaluation function

In [4]:
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

## Initialize and train the teacher model

In [5]:
teacher_model = TeacherNet()
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

In [6]:
for epoch in range(5):
    teacher_model.train()
    running_loss = 0.0
    
    for data in trainloader:
        inputs, labels = data
        teacher_optimizer.zero_grad()
        outputs = teacher_model(inputs)
        loss = teacher_criterion(outputs, labels)
        loss.backward()
        teacher_optimizer.step()
        
        running_loss += loss.item()
        
    teacher_accuracy = evaluate(teacher_model)
        
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {teacher_accuracy * 100:.2f}%")

Epoch 1, Loss: 0.23366064861265898, Accuracy: 97.60%
Epoch 2, Loss: 0.07699692965661889, Accuracy: 98.00%
Epoch 3, Loss: 0.058064278137973394, Accuracy: 98.44%
Epoch 4, Loss: 0.04937064894107677, Accuracy: 98.24%
Epoch 5, Loss: 0.04162352114517703, Accuracy: 98.53%


## Student Model

In [7]:
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Initialize and train the teacher model

In [8]:
student_model = StudentNet()
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)

## Loss function (KL Divergence)

In [9]:
def knowledge_distillation_loss(student_logits, teacher_logits):
    p_teacher = F.softmax(teacher_logits , dim=1)
    p_student = F.log_softmax(student_logits, dim=1)
    loss = F.kl_div(p_student, p_teacher, reduction='batchmean')
    return loss

In [10]:
# Train the student model with knowledge distillation
for epoch in range(5):  # You can adjust the number of epochs
    student_model.train()
    running_loss = 0.0
    
    for data in trainloader:
        inputs, labels = data
        student_optimizer.zero_grad()
        student_logits = student_model(inputs)
        teacher_logits = teacher_model(inputs).detach()  # Detach the teacher's output to avoid backpropagation
        loss = knowledge_distillation_loss(student_logits, teacher_logits)
        loss.backward()
        student_optimizer.step()
        
        running_loss += loss.item()
    
    student_accuracy = evaluate(student_model)
        
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(testloader)}, Accuracy: {student_accuracy * 100:.2f}%")


Epoch 1, Loss: 1.97617478094473, Accuracy: 93.53%
Epoch 2, Loss: 0.9071605966373044, Accuracy: 94.67%
Epoch 3, Loss: 0.6211776698874251, Accuracy: 96.30%
Epoch 4, Loss: 0.48355193005483244, Accuracy: 96.29%
Epoch 5, Loss: 0.4033386060778218, Accuracy: 96.34%


In [11]:
%timeit evaluate(teacher_model)

1.61 s ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%timeit evaluate(student_model) # student model runs faster

1.09 s ± 63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Zero-Pruning

## Model

In [14]:
# Define a simple neural network for MNIST classification
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

## Evaluation function

In [15]:
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

## Initialize and train the neural network

In [16]:
net = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(5):
    running_loss = 0.0
    net.train()
    for data in trainloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs.view(-1, 28*28))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
    running_loss += loss.item()
        
    accuracy = evaluate(net)
        
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {accuracy * 100:.2f}%")

Epoch 1, Loss: 0.0002394356866126884, Accuracy: 95.17%
Epoch 2, Loss: 0.0002629532933489346, Accuracy: 96.14%
Epoch 3, Loss: 5.20410822398627e-05, Accuracy: 96.73%
Epoch 4, Loss: 5.284582437482724e-05, Accuracy: 97.09%
Epoch 5, Loss: 2.2851257845123947e-06, Accuracy: 97.07%


In [17]:
# Save model
torch.save(net, "net.pt")

## Define the pruning threshold (λ) and apply Pruning

In [18]:
thresholds = np.linspace(0, 0.1, 11) # range of thresholds
results = []
total_params = np.sum([param.numel() for name, param in net.named_parameters() if 'weight' in name])

for threshold in tqdm(thresholds):
    
    # apply zero pruning
    for name, param in net.named_parameters():
        if 'weight' in name: # Apply pruning to weight parameters
            param.data[torch.abs(param.data) < threshold] = 0
            
    # Count zero param
    zero_params = np.sum([torch.sum(param == 0).item() for name, param in net.named_parameters() if 'weight' in name])
    
    accuracy = evaluate(net)
    
    results.append([threshold, accuracy, total_params, zero_params])


100%|███████████████████████████████████████████| 11/11 [00:14<00:00,  1.33s/it]


In [20]:
results = pd.DataFrame(results, columns = ["Threshold", "Accuracy", "Original Params", "Zero Params"])
results["Zero percentage"] = 100*results["Zero Params"]/results["Original Params"]
results

Unnamed: 0,Threshold,Accuracy,Original Params,Zero Params,Zero percentage
0,0.00,0.9707,566528,0,0.000000
1,0.01,0.9705,566528,120825,21.327278
2,0.02,0.9703,566528,238169,42.040111
3,0.03,0.9658,566528,345597,61.002634
4,0.04,0.9375,566528,422488,74.574955
...,...,...,...,...,...
6,0.06,0.8220,566528,488538,86.233690
7,0.07,0.7939,566528,505050,89.148286
8,0.08,0.7405,566528,517510,91.347647
9,0.09,0.6363,566528,527475,93.106607


## Select best threshold and apply pruning

In [21]:
threshold = 0.03
net = torch.load('net.pt')

# apply zero pruning
for name, param in net.named_parameters():
    if 'weight' in name: # Apply pruning to weight parameters
        param.data[torch.abs(param.data) < threshold] = 0

## Represent as sparse matrix

In [22]:
import scipy.sparse as sp

sparse_weights = []

# Convert the pruned weights to a sparse matrix
for name, param in net.named_parameters():
    if 'weight' in name:
        
        np_weight = param.data.cpu().numpy()
        sparse_weights.append(sp.csr_matrix(np_weight))

## Size before pruning

In [23]:
total_size = 0

for name, param in net.named_parameters():
    if 'weight' in name:
        tensor = param.data
        total_size += tensor.element_size() * tensor.numel()
        
# Convert bytes to a more human-readable format (e.g., megabytes)
tensor_size_mb = total_size/(1024**2)

tensor_size_mb

2.1611328125

## Size after pruning

In [24]:
total_size = 0

for w in sparse_weights:
    total_size += w.data.nbytes
    
csr_size_mb = total_size/(1024**2)
csr_size_mb

0.8427848815917969

# Activation pruning

## Model

In [25]:
# Define a simple teacher neural network with 4 fully connected layers
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x1 = torch.relu(self.fc1(x))
        x2 = torch.relu(self.fc2(x1))
        x3 = torch.relu(self.fc3(x2))
        x4 = self.fc4(x3)
        return x1, x2, x3, x4  # Return intermediate feature activations for activation pruning

In [26]:
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)[-1] # use last element returned by forward function
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

## Initialize and train the neural network

In [27]:
net = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(5):
    net.train()
    running_loss = 0.0
    
    for data in trainloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs[-1], labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    accuracy = evaluate(net)
        
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {accuracy * 100:.2f}%")

Epoch 1, Loss: 0.31967253410525476, Accuracy: 95.16%
Epoch 2, Loss: 0.1467606759618229, Accuracy: 96.23%
Epoch 3, Loss: 0.10635705036061532, Accuracy: 96.97%
Epoch 4, Loss: 0.09333869409032547, Accuracy: 97.34%
Epoch 5, Loss: 0.07620950450357208, Accuracy: 96.63%


## Evaluate the average activations of neurons on the training data

In [28]:
net.eval()
all_activations = [torch.zeros(512), torch.zeros(256), torch.zeros(128)]
data_size = len(trainloader.dataset.targets)

with torch.no_grad():
    for data in trainloader:
        inputs, _ = data
        
        activations_fc1 = torch.relu(net.fc1(inputs.view(-1, 28*28)))
        activations_fc2 = torch.relu(net.fc2(activations_fc1))
        activations_fc3 = torch.relu(net.fc3(activations_fc2))
        
        all_activations[0] += torch.sum(activations_fc1, dim=0)
        all_activations[1] += torch.sum(activations_fc2, dim=0)
        all_activations[2] += torch.sum(activations_fc3, dim=0)
            
for idx, activations in enumerate(all_activations):
    all_activations[idx] = activations/data_size

## Apply activation pruning across thresholds

In [29]:
thresholds = np.linspace(0, 1, 11)
results = []

original_total_params = sum(p.numel() for p in net.parameters())
for threshold in tqdm(thresholds):
    new_net = SimpleNet()
    
    new_net.fc1.weight = net.fc1.weight
    new_net.fc2.weight = net.fc2.weight
    new_net.fc3.weight = net.fc3.weight
    new_net.fc4.weight = net.fc4.weight
    
    new_net.fc1.bias = net.fc1.bias
    new_net.fc2.bias = net.fc2.bias
    new_net.fc3.bias = net.fc3.bias
    new_net.fc4.bias = net.fc4.bias
    
    new_net.fc1.weight = nn.Parameter(new_net.fc1.weight[all_activations[0]>=threshold])

    new_net.fc2.weight = nn.Parameter(new_net.fc2.weight[:, all_activations[0]>=threshold])
    new_net.fc2.weight = nn.Parameter(new_net.fc2.weight[all_activations[1]>=threshold])
    
    new_net.fc3.weight = nn.Parameter(new_net.fc3.weight[:, all_activations[1]>=threshold])
    new_net.fc3.weight = nn.Parameter(new_net.fc3.weight[all_activations[2]>=threshold])
    
    new_net.fc4.weight = nn.Parameter(new_net.fc4.weight[:, all_activations[2]>=threshold])
    
    
    new_net.fc1.bias = nn.Parameter(new_net.fc1.bias[all_activations[0]>=threshold])
    new_net.fc2.bias = nn.Parameter(new_net.fc2.bias[all_activations[1]>=threshold])
    new_net.fc3.bias = nn.Parameter(new_net.fc3.bias[all_activations[2]>=threshold])
    
    
    accuracies = 0
    total_time = 0
    for _ in range(7):
        
        start = time()
        accuracies += evaluate(new_net)    
        total_time += time() - start
        
    new_total_params = sum(p.numel() for p in new_net.parameters())

    results.append([threshold, 100*accuracies/7, original_total_params, new_total_params, total_time/7])

100%|███████████████████████████████████████████| 11/11 [01:44<00:00,  9.54s/it]


In [30]:
results = pd.DataFrame(results, columns = ["Threshold", "Accuracy", "Original Params", "New Params", "Inference Time"])
results["Size Reduction"] = 1-results["New Params"]/results["Original Params"]
results

Unnamed: 0,Threshold,Accuracy,Original Params,New Params,Inference Time,Size Reduction
0,0.0,96.63,567434,567434,1.405958,0.000000
1,0.1,96.44,567434,215901,1.196406,0.619513
2,0.2,96.37,567434,199489,1.361184,0.648437
3,0.3,96.11,567434,182201,1.425039,0.678904
4,0.4,94.24,567434,165041,1.300359,0.709145
...,...,...,...,...,...,...
6,0.6,87.49,567434,136019,1.325743,0.760291
7,0.7,86.40,567434,124319,1.423424,0.780910
8,0.8,82.55,567434,106310,1.495066,0.812648
9,0.9,77.36,567434,94773,1.370770,0.832980


# Low-rank Factorization

## Model

In [31]:
# Define a simple teacher neural network with 4 fully connected layers
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x1 = torch.relu(self.fc1(x))
        x2 = torch.relu(self.fc2(x1))
        x3 = torch.relu(self.fc3(x2))
        x4 = self.fc4(x3)
        return x4 

In [32]:
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs) # use last element returned by forward function
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

## Initialize and train the neural network

In [33]:
net = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(5):
    net.train()
    running_loss = 0.0
    
    for data in trainloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    accuracy = evaluate(net)
        
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {accuracy * 100:.2f}%")

Epoch 1, Loss: 0.32092429239596765, Accuracy: 94.65%
Epoch 2, Loss: 0.13917262043887332, Accuracy: 96.70%
Epoch 3, Loss: 0.10783804802925014, Accuracy: 96.63%
Epoch 4, Loss: 0.08807487449378394, Accuracy: 96.81%
Epoch 5, Loss: 0.07418891476260732, Accuracy: 97.24%


## Function to determine minimum matrix operations

Taken from https://www.geeksforgeeks.org/python-program-for-matrix-chain-multiplication-dp-8/

In [34]:
def MatrixChainOrder(p, i, j):
 
    if i == j:
        return 0
 
    _min = sys.maxsize
     
    for k in range(i, j):
     
        count = (MatrixChainOrder(p, i, k)
             + MatrixChainOrder(p, k + 1, j)
                   + p[i-1] * p[k] * p[j])
 
        if count < _min:
            _min = count;
     
    return _min;

## Apply LRF across thresholds

In [35]:
rank_values = [128, 100, 90, 80, 60, 50, 40, 30, 20, 10, 5, 2, 1]
results = []
original_total_params = 128*256
batch_size = 32

# factorize layer 3 weights
U, S, V = torch.svd(net.fc3.weight)

for rank in tqdm(rank_values):

    # Truncate U, S, and V to retain only the top 'rank' components
    U_low_rank = U[:, :rank]
    S_low_rank = torch.diag(S[:rank])
    V_low_rank = V[:, :rank]

    # Reconstruct the factorized weight matrix
    factorized_weight_matrix = torch.mm(U_low_rank, torch.mm(S_low_rank, V_low_rank.t()))

    # Replace the weight matrix of the chosen layer with the factorized weight matrix
    net.fc3.weight = nn.Parameter(factorized_weight_matrix)
    
    weight_list = [batch_size, 256, rank, rank, 128]
    
    if rank == 128:
        # use the usual weight matrix
        total_operations = batch_size*256*128
    else:
        # find the minimum number of operations
        total_operations = MatrixChainOrder(weight_list, 1, 4)

    accuracies = 0
    total_time = 0
    for _ in range(7):
        
        start = time()
        accuracies += evaluate(net)    
        total_time += time() - start
    
    # total parameters of three matrices
    new_total_params = 128*rank + rank**2 + rank*256

    results.append([rank, 100*accuracies/7, original_total_params, new_total_params, total_operations, total_time/7])

100%|███████████████████████████████████████████| 13/13 [02:16<00:00, 10.47s/it]


In [36]:
results = pd.DataFrame(results, columns = ["Threshold", "Accuracy", "Original Params", "New Params", "Operations", "Inference Time"])
results

Unnamed: 0,Threshold,Accuracy,Original Params,New Params,Operations,Inference Time
0,128,97.24,32768,65536,1048576,1.272552
1,100,97.24,32768,48400,1548800,1.350363
2,90,97.26,32768,42660,1365120,1.414365
3,80,97.26,32768,37120,1187840,1.439593
4,60,97.26,32768,26640,852480,1.411347
...,...,...,...,...,...,...
8,20,97.16,32768,8080,258560,1.571930
9,10,97.24,32768,3940,126080,1.524304
10,5,80.23,32768,1945,62240,1.817376
11,2,34.73,32768,772,24704,1.589601


# Quantization

In [37]:
# Quantize the model to int8 data type
quantized_model = torch.quantization.quantize_dynamic(
    net,  # Your model
    {torch.nn.Linear},  # Specify the layers to be quantized (e.g., linear layers)
    dtype=torch.qint8  # Data type for quantization (int8)
)

In [39]:
evaluate(quantized_model)

97.24%
