# MNIST CNN Experiments with different conv layers modified

# Imports

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
import torch.nn.utils.prune as prune

import numpy as np
import math
import cv2
import time, datetime
from prettytable import PrettyTable
from numpy.linalg import svd

In [None]:
# Define the device to train on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# Load data

In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor
train_data = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)
test_data = datasets.MNIST(
    root = 'data', 
    train = False, 
    transform = ToTensor()
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74849154.29it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 113635735.29it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 27851058.92it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7509077.17it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [None]:
print(train_data)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()


In [None]:
print(test_data)

Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()


In [None]:
from torch.utils.data import DataLoader
loaders = {
    'train' : torch.utils.data.DataLoader(train_data, 
                                          batch_size=100, 
                                          shuffle=True, 
                                          num_workers=1),
    
    'test'  : torch.utils.data.DataLoader(test_data, 
                                          batch_size=100, 
                                          shuffle=True, 
                                          num_workers=1),
}
loaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f23ed68d2b0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f23ed68d670>}

In [None]:
len(loaders["test"]), len(loaders["train"])

(100, 600)

# Utils (training, testing and visualization fucntions)

In [None]:
from torchvision.transforms.autoaugment import TrivialAugmentWide
def train(cnn, optimizer, loss_func, loaders, num_epochs):
    
    cnn.train()
        
    # Train the model
    total_step = len(loaders['train'])
        
    for epoch in range(num_epochs):
        running_loss = 0.0
        visualize = True

        for i, (images, labels) in tqdm(enumerate(loaders['train']), total=len(loaders['train'])):
            b_x = Variable(images)   # batch x
            b_y = Variable(labels)   # batch y

            output = cnn(b_x)[0]
            b_y = b_y.to(device)
            output = output.to(device)     
            loss = loss_func(output, b_y)
            
            # clear gradients
            optimizer.zero_grad()           
            
            # backpropagation
            loss.backward()  

            optimizer.step()     

            running_loss += loss.item()     
                
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, running_loss/len(loaders['train'])))

In [None]:
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

def test(cnn, loaders):
    # Test the model
    cnn.eval()
    with torch.no_grad():
        all_labels = []
        all_preds = []
        for images, labels in loaders['test']:
            test_output, last_layer = cnn(images)

            preds = torch.max(test_output, 1)[1].data.squeeze().cpu().numpy()
            labels = labels.cpu().numpy()
            
            all_labels.extend(labels)
            all_preds.extend(preds)
        
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, average='weighted')
        recall = recall_score(all_labels, all_preds, average='weighted')
        f1 = f1_score(all_labels, all_preds, average='weighted')

    # print('Accuracy of the model on the 10000 test images: %.2f' % accuracy)
    print('Precision on the 10000 test images: %.4f' % precision)
    print('Recall on the 10000 test images: %.4f' % recall)
    print('F1 score on the 10000 test images: %.4f' % f1)

In [None]:
def visualize_conv(conv_layer):
    fig, axs = plt.subplots(1, 2)

    k = 10
    sample_image = test_data[k]
    axs[0].imshow(sample_image[0][0], cmap='gray')

    conv_output = conv_layer(sample_image[0])
    axs[1].imshow(conv_output[0].cpu().detach().numpy(), cmap='gray')

    plt.show()

# Original CNN

## Model

In [None]:
class MNISTCNN(nn.Module):
    def __init__(self):
        super(MNISTCNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x, visualize=False):
        x = self.conv1(x)
        x = self.conv2(x)

        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

## Training

In [None]:
cnn = MNISTCNN()
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:45<00:00, 13.25it/s]

Epoch [1/10], Loss: 0.1704



100%|██████████| 600/600 [00:33<00:00, 17.75it/s]

Epoch [2/10], Loss: 0.0627



100%|██████████| 600/600 [00:35<00:00, 16.93it/s]

Epoch [3/10], Loss: 0.0540



100%|██████████| 600/600 [00:34<00:00, 17.60it/s]

Epoch [4/10], Loss: 0.0486



100%|██████████| 600/600 [00:33<00:00, 17.82it/s]

Epoch [5/10], Loss: 0.0485



100%|██████████| 600/600 [00:33<00:00, 17.70it/s]

Epoch [6/10], Loss: 0.0451



100%|██████████| 600/600 [00:34<00:00, 17.45it/s]

Epoch [7/10], Loss: 0.0405



100%|██████████| 600/600 [00:33<00:00, 17.67it/s]

Epoch [8/10], Loss: 0.0450



100%|██████████| 600/600 [00:33<00:00, 17.81it/s]

Epoch [9/10], Loss: 0.0435



100%|██████████| 600/600 [00:34<00:00, 17.24it/s]

Epoch [10/10], Loss: 0.0454
Training time: 353





## Testing

In [None]:
test(cnn, loaders)

Precision on the 10000 test images: 0.9799
Recall on the 10000 test images: 0.9797
F1 score on the 10000 test images: 0.9797


# CNN with fractional filters

## Model

In [None]:
class FractionalConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(FractionalConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Initializing the learnable parameters

        # In A < 1, because the gamma function is not defined for negative integers or zero.
        # In the paper, they mention using a regularization term 
        # to ensure that A is greater than or equal to 1, which would avoid this issue.
        self.A = torch.randn(self.out_channels, self.in_channels)
        self.A = nn.Parameter(torch.abs(self.A) + 1)

        self.sigma = nn.Parameter(torch.randn(self.out_channels, self.in_channels).abs())
        self.x0 = nn.Parameter(torch.randn(self.out_channels, self.in_channels))
        self.y0 = nn.Parameter(torch.randn(self.out_channels, self.in_channels))
        self.a = nn.Parameter(torch.rand(self.out_channels, self.in_channels) * 2)
        self.b = nn.Parameter(torch.rand(self.out_channels, self.in_channels) * 2)

        self.weights = None
        self.compute_weights()

    def compute_weights(self):
        # Computing weights
        weights = torch.zeros(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        dx = self._fractional_derivative(self.a, self.A, self.sigma, self.x0)
        dy = self._fractional_derivative(self.a, self.A, self.sigma, self.y0)

        kernel = torch.einsum('abc,abd->acbd', dx, dy)
        weights = kernel.reshape(dx.shape[0], dx.shape[1], dx.shape[2], dy.shape[2])

        self.weights = weights.to(device)

    def forward(self, x):
        self.weights = self.weights.clone().detach()
        x = x.to(device)
        out = F.conv2d(x, self.weights, stride=self.stride, padding=self.padding)
        return out

    def _fractional_derivative(self, alpha, A, sigma, x0):
        N = 10
        h = torch.tensor(1 / self.kernel_size).repeat(self.out_channels, self.in_channels, self.kernel_size)

        def gamma_func(a):
            return torch.exp(torch.lgamma(a))

        def G(x):
            return torch.exp(-(torch.square(x-x0))/torch.square(sigma))
        
        def f(x, n):
            return (gamma_func(alpha + 1) * G(x)) / ((-1)**n * gamma_func(n+1) * gamma_func(1-n+alpha))

        dx = torch.zeros (self.out_channels, self.in_channels, self.kernel_size)
        for x in range(1, self.kernel_size+1):
            x = torch.tensor(x).repeat(self.out_channels, self.in_channels)
            sum_term = 0
            for n in range(N+1):
                n = torch.tensor(n).repeat(self.out_channels, self.in_channels)
                sum_term += f(x, n)

            dx[..., x-1] = sum_term

        dx = dx.to(device)
        h = h.to(device)
        A = A.unsqueeze(2).repeat(1, 1, self.kernel_size)
        A = A.to(device)
        dx = (A / h) * dx

        return dx

In [None]:
class MNISTCNN_frac(nn.Module):
    def __init__(self):
        super(MNISTCNN_frac, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            FractionalConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,           
                stride=1,                   
                padding=2,                  
            ),    
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10).to(device)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)

        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        x.to(device)
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_frac_2conv(nn.Module):
    def __init__(self):
        super(MNISTCNN_frac_2conv, self).__init__()
        self.conv1 = nn.Sequential(         
            FractionalConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            FractionalConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,           
                stride=1,                   
                padding=2,                  
            ),    
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10).to(device)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)

        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        x.to(device)
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_frac_first_conv(nn.Module):
    def __init__(self):
        super(MNISTCNN_frac_first_conv, self).__init__()
        self.conv1 = nn.Sequential(         
            FractionalConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        ).to(device)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,           
                stride=1,                   
                padding=2,                  
            ),    
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        ).to(device)
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10).to(device)
    def forward(self, x):
        x = self.conv1(x)
        # x.to('cpu')
        x = self.conv2(x)

        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        x.to(device)
        output = self.out(x)
        return output, x

## Training

In [None]:
cnn_frac = MNISTCNN_frac()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_frac.parameters(), lr = 0.01)   

start = datetime.datetime.now()
train(cnn_frac, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:19<00:00, 31.08it/s]

Epoch [1/10], Loss: 0.4763



100%|██████████| 600/600 [00:20<00:00, 29.99it/s]

Epoch [2/10], Loss: 0.2275



100%|██████████| 600/600 [00:19<00:00, 30.80it/s]

Epoch [3/10], Loss: 0.1920



100%|██████████| 600/600 [00:20<00:00, 28.98it/s]

Epoch [4/10], Loss: 0.1686



100%|██████████| 600/600 [00:20<00:00, 29.63it/s]

Epoch [5/10], Loss: 0.1560



100%|██████████| 600/600 [00:19<00:00, 31.38it/s]

Epoch [6/10], Loss: 0.1452



100%|██████████| 600/600 [00:20<00:00, 29.80it/s]

Epoch [7/10], Loss: 0.1378



100%|██████████| 600/600 [00:18<00:00, 31.84it/s]

Epoch [8/10], Loss: 0.1319



100%|██████████| 600/600 [00:20<00:00, 29.61it/s]

Epoch [9/10], Loss: 0.1249



100%|██████████| 600/600 [00:19<00:00, 30.65it/s]

Epoch [10/10], Loss: 0.1218
Training time: 198





In [None]:
cnn_frac_2conv = MNISTCNN_frac_2conv()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_frac_2conv.parameters(), lr = 0.01)   

start = datetime.datetime.now()
train(cnn_frac_2conv, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:10<00:00, 59.45it/s]

Epoch [1/10], Loss: 22633.0462



100%|██████████| 600/600 [00:10<00:00, 56.11it/s]

Epoch [2/10], Loss: 10718.6109



100%|██████████| 600/600 [00:10<00:00, 56.54it/s]

Epoch [3/10], Loss: 10629.7964



100%|██████████| 600/600 [00:10<00:00, 57.52it/s]

Epoch [4/10], Loss: 12139.9114



100%|██████████| 600/600 [00:10<00:00, 57.03it/s]

Epoch [5/10], Loss: 11781.7047



100%|██████████| 600/600 [00:10<00:00, 55.95it/s]

Epoch [6/10], Loss: 12066.8267



100%|██████████| 600/600 [00:10<00:00, 57.98it/s]

Epoch [7/10], Loss: 12283.6638



100%|██████████| 600/600 [00:10<00:00, 59.56it/s]

Epoch [8/10], Loss: 12263.6066



100%|██████████| 600/600 [00:10<00:00, 57.06it/s]

Epoch [9/10], Loss: 13846.3920



100%|██████████| 600/600 [00:10<00:00, 57.29it/s]

Epoch [10/10], Loss: 12686.8145
Training time: 105





In [None]:
cnn_frac_first_conv = MNISTCNN_frac_first_conv()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_frac_first_conv.parameters(), lr = 0.01)   

start = datetime.datetime.now()
train(cnn_frac_first_conv, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:10<00:00, 56.71it/s]

Epoch [1/10], Loss: 2.3021



100%|██████████| 600/600 [00:10<00:00, 56.53it/s]

Epoch [2/10], Loss: 2.3020



100%|██████████| 600/600 [00:10<00:00, 55.48it/s]

Epoch [3/10], Loss: 2.3019



100%|██████████| 600/600 [00:10<00:00, 56.46it/s]

Epoch [4/10], Loss: 2.3019



100%|██████████| 600/600 [00:11<00:00, 50.34it/s]

Epoch [5/10], Loss: 2.3019



100%|██████████| 600/600 [00:10<00:00, 57.06it/s]

Epoch [6/10], Loss: 2.3020



100%|██████████| 600/600 [00:10<00:00, 54.97it/s]

Epoch [7/10], Loss: 2.3021



100%|██████████| 600/600 [00:10<00:00, 55.19it/s]

Epoch [8/10], Loss: 2.3020



100%|██████████| 600/600 [00:10<00:00, 56.22it/s]

Epoch [9/10], Loss: 2.3020



100%|██████████| 600/600 [00:10<00:00, 55.27it/s]

Epoch [10/10], Loss: 2.3020
Training time: 108





## Testing

In [None]:
test(cnn_frac, loaders)

Precision on the 10000 test images: 0.9683
Recall on the 10000 test images: 0.9680
F1 score on the 10000 test images: 0.9680


In [None]:
test(cnn_frac_first_conv, loaders)

Precision on the 10000 test images: 0.9789
Recall on the 10000 test images: 0.9788
F1 score on the 10000 test images: 0.9788


In [None]:
test(cnn_frac_2conv, loaders)

Precision on the 10000 test images: 0.7905
Recall on the 10000 test images: 0.7898
F1 score on the 10000 test images: 0.7883


# CNN with pruning

## Model

In [None]:
class PrunedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, sparsity=0.5):
        super(PrunedConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.sparsity = sparsity

        self.prune()

    def forward(self, x):
        return self.conv(x)

    def prune(self):
        prune.random_unstructured(self.conv, name="weight", amount=0.3)

    def get_sparsity(self):
        return self.sparsity

    def set_sparsity(self, sparsity):
        self.sparsity = sparsity
        self.prune()

In [None]:
class MNISTCNN_pruned_first(nn.Module):
    def __init__(self):
        super(MNISTCNN_pruned_first, self).__init__()
        self.conv1 = nn.Sequential(         
            PrunedConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2,                  
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_pruned_second(nn.Module):
    def __init__(self):
        super(MNISTCNN_pruned_second, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            PrunedConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2,                  
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_pruned_both(nn.Module):
    def __init__(self):
        super(MNISTCNN_pruned_both, self).__init__()
        self.conv1 = nn.Sequential(         
            PrunedConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            PrunedConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2,                  
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

## Training

In [None]:
cnn_pruned_first = MNISTCNN_pruned_first()
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_pruned_first.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_pruned_first, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:43<00:00, 13.91it/s]

Epoch [1/10], Loss: 0.1174



100%|██████████| 600/600 [00:35<00:00, 17.01it/s]

Epoch [2/10], Loss: 0.0471



100%|██████████| 600/600 [00:36<00:00, 16.55it/s]

Epoch [3/10], Loss: 0.0414



100%|██████████| 600/600 [00:35<00:00, 17.07it/s]

Epoch [4/10], Loss: 0.0365



100%|██████████| 600/600 [00:34<00:00, 17.29it/s]

Epoch [5/10], Loss: 0.0332



100%|██████████| 600/600 [00:36<00:00, 16.54it/s]

Epoch [6/10], Loss: 0.0306



100%|██████████| 600/600 [00:34<00:00, 17.17it/s]

Epoch [7/10], Loss: 0.0298



100%|██████████| 600/600 [00:35<00:00, 17.02it/s]

Epoch [8/10], Loss: 0.0234



100%|██████████| 600/600 [00:34<00:00, 17.27it/s]

Epoch [9/10], Loss: 0.0283



100%|██████████| 600/600 [00:35<00:00, 16.95it/s]

Epoch [10/10], Loss: 0.0326
Training time: 361





In [None]:
cnn_pruned_second = MNISTCNN_pruned_second()
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_pruned_second.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_pruned_second, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:35<00:00, 16.87it/s]

Epoch [1/10], Loss: 0.1368



100%|██████████| 600/600 [00:35<00:00, 17.08it/s]

Epoch [2/10], Loss: 0.0530



100%|██████████| 600/600 [00:34<00:00, 17.33it/s]

Epoch [3/10], Loss: 0.0431



100%|██████████| 600/600 [00:35<00:00, 17.13it/s]

Epoch [4/10], Loss: 0.0409



100%|██████████| 600/600 [00:37<00:00, 15.82it/s]

Epoch [5/10], Loss: 0.0407



100%|██████████| 600/600 [00:36<00:00, 16.60it/s]

Epoch [6/10], Loss: 0.0352



100%|██████████| 600/600 [00:36<00:00, 16.59it/s]

Epoch [7/10], Loss: 0.0345



100%|██████████| 600/600 [00:36<00:00, 16.46it/s]

Epoch [8/10], Loss: 0.0394



100%|██████████| 600/600 [00:36<00:00, 16.43it/s]

Epoch [9/10], Loss: 0.0343



100%|██████████| 600/600 [00:36<00:00, 16.61it/s]

Epoch [10/10], Loss: 0.0356
Training time: 360





In [None]:
cnn_pruned_both = MNISTCNN_pruned_both()
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_pruned_both.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_pruned_both, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

100%|██████████| 600/600 [00:36<00:00, 16.29it/s]

Epoch [1/10], Loss: 0.1979



100%|██████████| 600/600 [00:37<00:00, 16.18it/s]

Epoch [2/10], Loss: 0.0726



100%|██████████| 600/600 [00:38<00:00, 15.77it/s]

Epoch [3/10], Loss: 0.0626



100%|██████████| 600/600 [00:36<00:00, 16.38it/s]

Epoch [4/10], Loss: 0.0563



100%|██████████| 600/600 [00:36<00:00, 16.39it/s]

Epoch [5/10], Loss: 0.0516



100%|██████████| 600/600 [00:36<00:00, 16.35it/s]

Epoch [6/10], Loss: 0.0507



100%|██████████| 600/600 [00:36<00:00, 16.34it/s]

Epoch [7/10], Loss: 0.0506



100%|██████████| 600/600 [00:36<00:00, 16.31it/s]

Epoch [8/10], Loss: 0.0495



100%|██████████| 600/600 [00:34<00:00, 17.28it/s]

Epoch [9/10], Loss: 0.0453



100%|██████████| 600/600 [00:35<00:00, 16.70it/s]

Epoch [10/10], Loss: 0.0437
Training time: 366





## Testing

In [None]:
test(cnn_pruned_first, loaders)

Precision on the 10000 test images: 0.9869
Recall on the 10000 test images: 0.9869
F1 score on the 10000 test images: 0.9869


In [None]:
test(cnn_pruned_second, loaders)

Precision on the 10000 test images: 0.9846
Recall on the 10000 test images: 0.9845
F1 score on the 10000 test images: 0.9845


In [None]:
test(cnn_pruned_both, loaders)

Precision on the 10000 test images: 0.9829
Recall on the 10000 test images: 0.9828
F1 score on the 10000 test images: 0.9828


# CNN with Low-rank approximation

## Model

In [None]:
class LowRankConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, r=3, method='constant', decomposition='cur'):
        super(LowRankConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.r = r

        self.conv = nn.Conv2d(
                in_channels=self.in_channels,              
                out_channels=self.out_channels,            
                kernel_size=self.kernel_size,           
                stride=self.stride,                   
                padding=self.padding,                  
            )
        # print(self.conv.weight.shape)
        
        if decomposition == 'cur':
            if method == 'constant':
                c = 0.99
            elif method == 'log':
                c = np.log(self.conv.weight.shape[-1])
            self.conv.weight.data = torch.from_numpy(self.cur_low_rank(self.conv.weight, c=c, r=r))
        elif decomposition == 'svd':
            self.conv.weight.data = torch.from_numpy(self.traditional_low_rank(self.conv.weight, r=r))

        # print(self.conv.weight.shape)

    def forward(self, x):
        return self.conv(x)

    def traditional_low_rank(self, A, r):
        # A is a 4D weight matrix
        # k is the target rank
        A = A.detach().numpy()
        
        # Flatten the 4D weight matrix into a 2D matrix
        n1, n2, n3, n4 = A.shape
        A = A.reshape((n1*n2*n3, n4))
        print(f"A before approx: {np.linalg.matrix_rank(A)}, {A.shape}")

        # Compute the SVD of A
        U, s, Vt = svd(A, full_matrices=False)

        # Truncate the SVD to the target rank
        U = U[:, :r]
        s = s[:r]
        Vt = Vt[:r, :]

        # print(U.shape, s.shape, Vt.shape)

        # Compute the low-rank approximation
        A_approx = np.dot(U, np.dot(np.diag(s), Vt))
        # A_approx = U @ s @ Vt

        print(f"A after approx: {np.linalg.matrix_rank(A_approx)}, {A_approx.shape}")

        # Reshape the low-rank approximation to a 4D weight matrix
        A_approx = A_approx.reshape((n1, n2, n3, n4))

        return A_approx

    def cur_low_rank(self, A, c, r):
        # Flatten the 4D weight matrix into a 2D matrix
        A = A.detach().numpy()
        n1, n2, n3, n4 = A.shape
        A_2d = A.reshape((n1*n2*n3, n4))
        m, n = A_2d.shape
        print(f"A before approx: {np.linalg.matrix_rank(A_2d)}")

        curr_r = np.linalg.matrix_rank(A_2d)
        
        # Computing C
        def choose_col_by_prob(A):
            U, s, Vt = svd(A, full_matrices=False)

            leverage_scores = np.linalg.norm(Vt[:curr_r], axis=0) ** 2 / curr_r
            column_probabilities = np.minimum(c, leverage_scores) / np.sum(np.minimum(c, leverage_scores))
            selected_columns = np.random.choice(A.shape[1], curr_r, replace=False, p=column_probabilities)
            
            return A[:, selected_columns]

        C = choose_col_by_prob(A_2d)
        R = choose_col_by_prob(np.transpose(A_2d))
        C_pinv = np.linalg.pinv(C)
        R_pinv = np.linalg.pinv(R)

        U = C_pinv @ A_2d @ R_pinv

        # print(C.shape, U.shape, R.shape)
        # Truncate the CUR to the target rank
        C = C[:, :r]
        U = U[:r, :r]
        R = R[:r, :]

        A_approx = C @ U @ R

        print(f"A after approx: {np.linalg.matrix_rank(A_approx)}")

        # Reshape the low-rank approximation to a 4D weight matrix
        A_approx = A_approx.reshape((n1, n2, n3, n4))
        return A_approx

In [None]:
class MNISTCNN_lowrank_first(nn.Module):
    def __init__(self, r=3, method='log', decomposition='cur'):
        super(MNISTCNN_lowrank_first, self).__init__()
        self.conv1 = nn.Sequential(         
            LowRankConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2, 
                r=r,
                method=method,
                decomposition = decomposition              
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2             
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_lowrank_second(nn.Module):
    def __init__(self, r=3, method='log', decomposition='cur'):
        super(MNISTCNN_lowrank_second, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2            
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            LowRankConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2, 
                r=r,
                method=method,
                decomposition = decomposition                
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

In [None]:
class MNISTCNN_lowrank_both(nn.Module):
    def __init__(self, r=3, method='log', decomposition='cur'):
        super(MNISTCNN_lowrank_both, self).__init__()
        self.conv1 = nn.Sequential(         
            LowRankConv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,          
                stride=1,                   
                padding=2, 
                r=r,
                method=method,
                decomposition = decomposition              
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(
            LowRankConv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,             
                stride=1,                   
                padding=2, 
                r=r,
                method=method,
                decomposition = decomposition                
            ),  
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x

## Training

In [None]:
cnn_lowrank_first = MNISTCNN_lowrank_first(r=3, method='log', decomposition='cur')
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_lowrank_first.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_lowrank_first, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

A before approx: 5
A after approx: 3


100%|██████████| 600/600 [00:35<00:00, 16.85it/s]

Epoch [1/10], Loss: 0.1429



100%|██████████| 600/600 [00:35<00:00, 16.75it/s]

Epoch [2/10], Loss: 0.0590



100%|██████████| 600/600 [00:34<00:00, 17.43it/s]

Epoch [3/10], Loss: 0.0530



100%|██████████| 600/600 [00:35<00:00, 16.79it/s]

Epoch [4/10], Loss: 0.0466



100%|██████████| 600/600 [00:35<00:00, 16.69it/s]

Epoch [5/10], Loss: 0.0436



100%|██████████| 600/600 [00:35<00:00, 16.80it/s]

Epoch [6/10], Loss: 0.0412



100%|██████████| 600/600 [00:35<00:00, 16.85it/s]

Epoch [7/10], Loss: 0.0401



100%|██████████| 600/600 [00:34<00:00, 17.18it/s]

Epoch [8/10], Loss: 0.0423



100%|██████████| 600/600 [00:37<00:00, 16.13it/s]

Epoch [9/10], Loss: 0.0395



100%|██████████| 600/600 [00:36<00:00, 16.55it/s]

Epoch [10/10], Loss: 0.0387
Training time: 357





In [None]:
cnn_lowrank_second = MNISTCNN_lowrank_second(r=3, method='log', decomposition='cur')
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_lowrank_second.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_lowrank_second, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

A before approx: 5
A after approx: 3


100%|██████████| 600/600 [00:35<00:00, 16.67it/s]

Epoch [1/10], Loss: 0.2249



100%|██████████| 600/600 [00:35<00:00, 16.71it/s]

Epoch [2/10], Loss: 0.0703



100%|██████████| 600/600 [00:34<00:00, 17.21it/s]

Epoch [3/10], Loss: 0.0605



100%|██████████| 600/600 [00:35<00:00, 16.86it/s]

Epoch [4/10], Loss: 0.0556



100%|██████████| 600/600 [00:35<00:00, 16.71it/s]

Epoch [5/10], Loss: 0.0503



100%|██████████| 600/600 [00:35<00:00, 16.70it/s]

Epoch [6/10], Loss: 0.0516



100%|██████████| 600/600 [00:36<00:00, 16.25it/s]

Epoch [7/10], Loss: 0.0467



100%|██████████| 600/600 [00:35<00:00, 16.97it/s]

Epoch [8/10], Loss: 0.0466



100%|██████████| 600/600 [00:35<00:00, 17.06it/s]

Epoch [9/10], Loss: 0.0438



100%|██████████| 600/600 [00:35<00:00, 16.82it/s]

Epoch [10/10], Loss: 0.0447
Training time: 357





In [None]:
cnn_lowrank_both = MNISTCNN_lowrank_both(r=3, method='log', decomposition='cur')
loss_func = nn.CrossEntropyLoss()   
optimizer = optim.Adam(cnn_lowrank_both.parameters(), lr = 0.01)   


start = datetime.datetime.now()
train(cnn_lowrank_both, optimizer, loss_func, loaders, num_epochs=10)
end = datetime.datetime.now()
diff = (end - start)
print("Training time:", int(diff.total_seconds()))

A before approx: 5
A after approx: 3
A before approx: 5
A after approx: 3


100%|██████████| 600/600 [00:35<00:00, 17.03it/s]

Epoch [1/10], Loss: 0.1859



100%|██████████| 600/600 [00:34<00:00, 17.19it/s]

Epoch [2/10], Loss: 0.0727



100%|██████████| 600/600 [00:33<00:00, 17.84it/s]

Epoch [3/10], Loss: 0.0586



100%|██████████| 600/600 [00:34<00:00, 17.34it/s]

Epoch [4/10], Loss: 0.0519



100%|██████████| 600/600 [00:35<00:00, 16.78it/s]

Epoch [5/10], Loss: 0.0500



100%|██████████| 600/600 [00:35<00:00, 17.06it/s]

Epoch [6/10], Loss: 0.0463



100%|██████████| 600/600 [00:35<00:00, 16.97it/s]

Epoch [7/10], Loss: 0.0429



100%|██████████| 600/600 [00:34<00:00, 17.24it/s]

Epoch [8/10], Loss: 0.0427



100%|██████████| 600/600 [00:34<00:00, 17.14it/s]

Epoch [9/10], Loss: 0.0399



100%|██████████| 600/600 [00:34<00:00, 17.16it/s]

Epoch [10/10], Loss: 0.0410
Training time: 349





## Testing

In [None]:
test(cnn_lowrank_first, loaders)

Precision on the 10000 test images: 0.9819
Recall on the 10000 test images: 0.9817
F1 score on the 10000 test images: 0.9817


In [None]:
test(cnn_lowrank_second, loaders)

Precision on the 10000 test images: 0.9803
Recall on the 10000 test images: 0.9800
F1 score on the 10000 test images: 0.9800


In [None]:
test(cnn_lowrank_both, loaders)

Precision on the 10000 test images: 0.9826
Recall on the 10000 test images: 0.9825
F1 score on the 10000 test images: 0.9825


# Count parameters for all

In [None]:
def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: 
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [None]:
cnn_params = count_parameters(cnn)

+----------------+------------+
|    Modules     | Parameters |
+----------------+------------+
| conv1.0.weight |    400     |
|  conv1.0.bias  |     16     |
| conv2.0.weight |   12800    |
|  conv2.0.bias  |     32     |
|   out.weight   |   15680    |
|    out.bias    |     10     |
+----------------+------------+
Total Trainable Params: 28938


In [None]:
cnn_frac_params = count_parameters(cnn_frac)

+----------------+------------+
|    Modules     | Parameters |
+----------------+------------+
| conv1.0.weight |    400     |
|  conv1.0.bias  |     16     |
|   conv2.0.A    |    512     |
| conv2.0.sigma  |    512     |
|   conv2.0.x0   |    512     |
|   conv2.0.y0   |    512     |
|   conv2.0.a    |    512     |
|   conv2.0.b    |    512     |
|   out.weight   |   15680    |
|    out.bias    |     10     |
+----------------+------------+
Total Trainable Params: 19178


In [None]:
cnn_frac_params = count_parameters(cnn_frac_2conv)

+---------------+------------+
|    Modules    | Parameters |
+---------------+------------+
|   conv1.0.A   |     16     |
| conv1.0.sigma |     16     |
|   conv1.0.x0  |     16     |
|   conv1.0.y0  |     16     |
|   conv1.0.a   |     16     |
|   conv1.0.b   |     16     |
|   conv2.0.A   |    512     |
| conv2.0.sigma |    512     |
|   conv2.0.x0  |    512     |
|   conv2.0.y0  |    512     |
|   conv2.0.a   |    512     |
|   conv2.0.b   |    512     |
|   out.weight  |   15680    |
|    out.bias   |     10     |
+---------------+------------+
Total Trainable Params: 18858


In [None]:
cnn_frac_params = count_parameters(cnn_frac_first_conv)

+----------------+------------+
|    Modules     | Parameters |
+----------------+------------+
|   conv1.0.A    |     16     |
| conv1.0.sigma  |     16     |
|   conv1.0.x0   |     16     |
|   conv1.0.y0   |     16     |
|   conv1.0.a    |     16     |
|   conv1.0.b    |     16     |
| conv2.0.weight |   12800    |
|  conv2.0.bias  |     32     |
|   out.weight   |   15680    |
|    out.bias    |     10     |
+----------------+------------+
Total Trainable Params: 28618


In [None]:
from prettytable import PrettyTable

def count_parameters_pruning(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: 
            continue
        
        param = parameter.numel()
        if 'conv2.0.conv.weight' in name:
            pruned_mask = list(model.conv2.named_buffers())
            param = int(torch.sum(pruned_mask[0][1]).item())
            name = name[:-5]
        
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


In [None]:
cnn_pruned_params = count_parameters_pruning(cnn_pruned)

+--------------------------+------------+
|         Modules          | Parameters |
+--------------------------+------------+
|    conv1.0.conv.bias     |     16     |
| conv1.0.conv.weight_orig |    400     |
|    conv2.0.conv.bias     |     32     |
|   conv2.0.conv.weight    |    8960    |
|        out.weight        |   15680    |
|         out.bias         |     10     |
+--------------------------+------------+
Total Trainable Params: 25098


In [None]:
cnn_lowrank_params = count_parameters(cnn_lowrank)

+---------------------+------------+
|       Modules       | Parameters |
+---------------------+------------+
| conv1.0.conv.weight |    400     |
|  conv1.0.conv.bias  |     16     |
| conv2.0.conv.weight |   12800    |
|  conv2.0.conv.bias  |     32     |
|      out.weight     |   15680    |
|       out.bias      |     10     |
+---------------------+------------+
Total Trainable Params: 28938
