In [1]:
import torch 
from torch import nn, Tensor
from typing import Union, Tuple
import torchvision
from torchvision.transforms import v2 
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import random
import os
from model import Model 
# import math
import pickle

seed_number = 42
torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
random.seed(seed_number)
torch.manual_seed(seed_number)
torch.cuda.manual_seed(seed_number)
np.random.seed(seed_number)
os.environ['PYTHONHASHSEED'] = str(seed_number)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64 


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def downloadData(batch_size, download=True):


    transforms = v2.Compose(
        [
            v2.ToImage(), 
            v2.ToDtype(torch.float32, scale=True),
            v2.Grayscale(1),
            v2.Normalize((0.5, ), (0.5, ))
        ]
    )

    trainset = torchvision.datasets.MNIST('./', train=True, transform=transforms, download=download)
    train_subset, val_subset = torch.utils.data.random_split(
        trainset, [0.9, 0.1], generator=torch.Generator().manual_seed(42)
    )

    testset = torchvision.datasets.MNIST('./', train=False, transform=transforms, download=download)
        
    train_data = torch.utils.data.DataLoader(train_subset,  batch_size=batch_size, shuffle=True, num_workers=2)
    test_data = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    val_data = torch.utils.data.DataLoader(val_subset,  batch_size=batch_size, shuffle=True, num_workers=2)
    return train_data, val_data, test_data


In [3]:
def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))
    plt.show()

def calculateAccuracy(predicted, targets):
    predicted = nn.functional.softmax(predicted, dim=0)    # print(predicted[0], targets[0])
    pred_no = torch.argmax(predicted, dim=1)
    # print(predicted)
    right = torch.sum(torch.eq(pred_no, targets).int())
    return right / len(pred_no)


In [4]:
def validateModel(model, testIter, loss, device):
    loss_per_batch = []
    acc_per_batch = []
    with torch.no_grad():
        model.eval()
        for _, (X, y) in enumerate(testIter):
            X, y = X.to(device), y.to(device)
            out = model(X)
            l = loss(out, y)
            a = calculateAccuracy(out, y)
            loss_per_batch.append(l.item())
            acc_per_batch.append(a.item())
        mean_acc = sum(acc_per_batch) / len(acc_per_batch)
        meanloss = sum(loss_per_batch)/len(loss_per_batch)
    return loss_per_batch, meanloss, acc_per_batch, mean_acc


In [5]:
def train(trainIter, testIter, model,
          device=device,
          epochs=100, 
          optim=None, 
          loss=None,
          scheduler=None
          ):
    logs_dic = {
        "valildationLoss": [],
        "trainingLoss" : [],
        "validationAccuracy": [],
        "trainingAccuracy": []
    }
    for epoch in range(epochs):
        train_loss_per_batch = []
        train_acc_per_batch = []
        with tqdm(trainIter, unit="batches") as tepoch:
            for _, (X, y) in enumerate(tepoch):
                model.train()
                optim.zero_grad()
                X, y = X.to(device), y.to(device)
                out = model(X)
                l = loss(out, y)
                acc = calculateAccuracy(out, y)
                train_acc_per_batch.append(acc.item())
                train_loss_per_batch.append(l.item())
                tepoch.set_description(f"Epoch {epoch + 1}")
                tepoch.set_postfix(loss=l.item(), accuracy=acc.item())
                l.backward()
                optim.step()
        val_loss, mean_val_loss, val_acc, mean_val_acc = validateModel(model, testIter, loss=loss, device=device)
        print(f"The validation loss is: {mean_val_loss}")
        print(f"The validation accuracy is: {mean_val_acc}")
        logs_dic['valildationLoss'].append(val_loss)
        logs_dic['trainingLoss'].append(train_loss_per_batch)
        logs_dic['trainingAccuracy'].append(train_acc_per_batch)
        logs_dic['validationAccuracy'].append(val_acc)
        if scheduler: 
            if scheduler.__module__ == lr_scheduler.__name__:
                scheduler.step()
            else:
                for param_group in optim.param_groups:
                    lr = scheduler(epoch)
                    param_group['lr'] = lr

    return logs_dic


In [6]:
def test(testIter, model, device=device):
    with torch.no_grad():
        model.eval()
        acc_batch = []
        for _, (X, y) in enumerate(tqdm(testIter)):
            X, y = X.to(device), y.to(device)
            out = model(X)
            a = calculateAccuracy(out, y)
            acc_batch.append(a.item())
    total_acc = sum(acc_batch) / len(acc_batch)
    return total_acc


In [7]:
# trainloader, valloader, testloader = downloadData(BATCH_SIZE, download=False)
# model = Model()
# model.load_state_dict(torch.load("base.pth", map_location=device))
# # optim = optim.Adam(model.parameters(), lr=0.0001 )
# optim = optim.SGD(model.parameters(), lr=0.009, momentum=0.93, dampening=0.05, weight_decay=0.009)
# # optim = optim.SGD(model.parameters(), lr=0.0023)
# acc = test(testloader, model)
# loss = nn.CrossEntropyLoss()
# print("The test accuracy is: ", acc)
# logs = train(trainloader, valloader, model, 
#                 device=device, epochs=20, optim=optim, loss=loss)
# acc = test(testloader, model)
# print("The test accuracy is: ", acc)


In [8]:
def maskWeights(weight_matrix, prune_rate):
    mask = torch.ones_like(weight_matrix)
    mask = mask * (weight_matrix!=0)
    prev_l, curr_l = weight_matrix.shape
    abs_wght_matrix = torch.abs(weight_matrix)
    sorted_tensor, sorted_indices = torch.sort(abs_wght_matrix.reshape(1, -1))
    non_zeros = sorted_tensor.nonzero()
    # print(sorted_tensor, non_zeros)
    first_non_zero_index = non_zeros[0, 1]
    amt_pruning = torch.floor((len(torch.squeeze(sorted_tensor)) - first_non_zero_index) * float(prune_rate/100))
    # print(first_non_zero_index + amt_pruning.int())
    p_indices = sorted_indices[0][first_non_zero_index: first_non_zero_index+(amt_pruning.int())]
    # print(p_indices.shape)
    # convert to the original indices
    izz, jzz = p_indices // curr_l, p_indices % curr_l
    mask[izz, jzz] = 0
    # print(weight_matrix)
    return mask
    
def saveMasks(masks, filename):
    print("pickling masks.")
    with open(filename, 'wb') as fp:
        pickle.dump(masks, fp)
    print("successfully pickled") 

def retrieveMasks(filename):
    print("unpickling masks.")
    with open(filename, 'rb') as fp:
        masks = pickle.load( fp)
    print("successfully unpickled") 
    return masks

In [9]:
def freezeGrads(mask):
    def callback(module, grad_input, grad_output):
        module.weight.grad.data = module.weight.grad.data * mask
    return callback


In [10]:
def iterativePruning(trainIter=None, 
                    model=None, 
                    masks=None,
                    prune_rate=None,
                    device=device,
                    optim=None, 
                    loss=None,
                    ):
    logs_dic = {
        # "valildationLoss": [],
        "trainingLoss" : [],
        # "validationAccuracy": [],
        "trainingAccuracy": []
    }
    # for epoch in range(epochs):
    #     train_loss_per_batch = []
    #     train_acc_per_batch = []
    # with tqdm(trainIter, unit="batches") as tepoch:
    for _, (X, y) in enumerate(trainIter):
        model.train()
        optim.zero_grad()
        X, y = X.to(device), y.to(device)
        X.requires_grad = True
        out = model(X)
        l = loss(out, y)
        acc = calculateAccuracy(out, y)
        print("The training loss is: ", l)
        print("The training accuracy is: ", acc)
        
        # train_acc_per_batch.append(acc.item())
        # train_loss_per_batch.append(l.item())
        # tepoch.set_description(f"Epoch {epoch + 1}")
        # tepoch.set_postfix(loss=l.item(), accuracy=acc.item())
        hook1 = model.layer1.register_full_backward_hook(freezeGrads(masks[0]))
        hook2 = model.layer1.register_full_backward_hook(freezeGrads(masks[1]))
        hook3 = model.layer2.register_full_backward_hook(freezeGrads(masks[2]))
        hook4 = model.layer3.register_full_backward_hook(freezeGrads(masks[3]))
        hook5 = model.outputLayer.register_full_backward_hook(freezeGrads(masks[4]))
        
        l.backward()
        optim.step()
        mask1 = maskWeights(model.layer1.weight, prune_rate)
        mask2 = maskWeights(model.layer2.weight, prune_rate)
        mask3 = maskWeights(model.layer3.weight, prune_rate)
        mask4 = maskWeights(model.layer4.weight, prune_rate)  
        mask5 = maskWeights(model.outputLayer.weight, prune_rate/4)   
        hook1.remove()
        hook2.remove()
        hook3.remove()
        hook4.remove()
        hook5.remove()
        saveMasks([mask1, mask2, mask3, mask4, mask5], "./masks/" + str(prune_rate) +"percentMasks")
        # prune_rate = prune_rate * 
        # print(mask1)
        # save the masks
        
        
        break
    return 0
        # val_loss, mean_val_loss, val_acc, mean_val_acc = validateModel(model, testIter, loss=loss, device=device)
        # print(f"The validation loss is: {mean_val_loss}")
        # print(f"The validation accuracy is: {mean_val_acc}")
        # logs_dic['valildationLoss'].append(val_loss)
        # logs_dic['trainingLoss'].append(train_loss_per_batch)
        # logs_dic['trainingAccuracy'].append(train_acc_per_batch)
        # logs_dic['validationAccuracy'].append(val_acc)
        # if scheduler: 
        #     if scheduler.__module__ == lr_scheduler.__name__:
        #         scheduler.step()
        #     else:
        #         for param_group in optim.param_groups:
        #             lr = scheduler(epoch)
        #             param_group['lr'] = lr

    # return logs_dic


In [11]:
import torch.optim as optimi

# model = Model()
# torch.save(model.state_dict(), 'base.pth')
trainloader, valloader, testloader = downloadData(BATCH_SIZE, download=False)




In [12]:

def pruneIterations():
    model = Model()
    model.load_state_dict(torch.load('base.pth', map_location='cpu'))
    optim = optimi.SGD(model.parameters(), lr=0.009)
    loss = nn.CrossEntropyLoss()
    # mask1 = torch.ones_like(model.layer1.weight)
    # mask2 = torch.ones_like(model.layer2.weight)
    # mask3 = torch.ones_like(model.layer3.weight)
    # mask4 = torch.ones_like(model.layer4.weight)
    # mask5 = torch.ones_like(model.outputLayer.weight)
    
    # iterativePruning(trainIter=trainloader,  model=model, masks=[mask1, mask2, mask3, mask4, mask5], prune_rate=3, device=device, optim=optim, loss=loss)
    for i in range(90, 100, 1):
        print(i)
        masks = retrieveMasks('masks/' + str(i-1) + 'percentMasks')
    # apply masks to weights of each layer
        model = Model()
        model.load_state_dict(torch.load("base.pth", map_location='cpu'))
        model.layer1.weight.data = model.layer1.weight.data * masks[0]
        # print(model.layer1.weight.data)
        model.layer2.weight.data = model.layer2.weight.data * masks[1]
        model.layer3.weight.data = model.layer3.weight.data * masks[2]
        model.layer4.weight.data = model.layer4.weight.data * masks[3]
        model.outputLayer.weight.data = model.outputLayer.weight.data * masks[4]
        
        optim = optimi.SGD(model.parameters(), lr=0.09)
        # print(device)
        iterativePruning(trainIter=trainloader,  model=model, masks=masks, prune_rate=i, device=device, optim=optim, loss=loss)
     
# pruneIterations()

In [13]:
def trainPrunedModels(model, masks, trainIter, testIter, epochs, loss, optim ):
    logs_dic = {
        "valildationLoss": [],
        "trainingLoss" : [],
        "validationAccuracy": [],
        "trainingAccuracy": []
    }
    model.layer1.weight.data = model.layer1.weight.data * masks[0]
    model.layer2.weight.data = model.layer2.weight.data * masks[1]
    model.layer3.weight.data = model.layer3.weight.data * masks[2]
    model.layer4.weight.data = model.layer4.weight.data * masks[3]
    model.outputLayer.weight.data = model.outputLayer.weight.data * masks[4]
    
    hook1 = model.layer1.register_full_backward_hook(freezeGrads(masks[0]))
    hook2 = model.layer2.register_full_backward_hook(freezeGrads(masks[1]))
    hook3 = model.layer3.register_full_backward_hook(freezeGrads(masks[2]))
    hook4 = model.layer4.register_full_backward_hook(freezeGrads(masks[3]))
    hook5 = model.outputLayer.register_full_backward_hook(freezeGrads(masks[4]))
    
    for epoch in range(epochs):
        train_loss_per_batch = []
        train_acc_per_batch = []
        with tqdm(trainIter, unit="batches") as tepoch:
            for _, (X, y) in enumerate(tepoch):
                model.train()
                optim.zero_grad()
                X, y = X.to(device), y.to(device)
                X.requires_grad=True
                out = model(X)
                l = loss(out, y)
                acc = calculateAccuracy(out, y)
                train_acc_per_batch.append(acc.item())
                train_loss_per_batch.append(l.item())
                tepoch.set_description(f"Epoch {epoch + 1}")
                tepoch.set_postfix(loss=l.item(), accuracy=acc.item())
                l.backward()
                
                optim.step()
        val_loss, mean_val_loss, val_acc, mean_val_acc = validateModel(model, testIter, loss=loss, device=device)
        print(f"The validation loss is: {mean_val_loss}")
        print(f"The validation accuracy is: {mean_val_acc}")
        logs_dic['valildationLoss'].append(val_loss)
        logs_dic['trainingLoss'].append(train_loss_per_batch)
        logs_dic['trainingAccuracy'].append(train_acc_per_batch)
        logs_dic['validationAccuracy'].append(val_acc)
    hook1.remove()
    hook2.remove()
    hook3.remove()
    hook4.remove()
    hook5.remove()
    return logs_dic

In [92]:
masks = retrieveMasks('masks/96percentMasks')
# apply masks to weights of each layer
model = Model()
model.load_state_dict(torch.load("base.pth", map_location='cpu'))
# model.layer1.weight.data = model.layer1.weight.data * masks[0]
# # print(model.layer1.weight.data)
# model.layer2.weight.data = model.layer2.weight.data * masks[1]
# model.layer3.weight.data = model.layer3.weight.data * masks[2]
# model.layer4.weight.data = model.layer4.weight.data * masks[3]
# model.outputLayer.weight.data = model.outputLayer.weight.data * masks[4]


# optim = optimi.SGD(model.parameters(), lr=0.009)
# iterativePruning(trainloader, valloader, model, 60, device, optim, loss)

# print(model.outputLayer.weight.data)
# torch.sum(model.outputLayer.weight.data == torch.zeros_like(model.outputLayer.weight))
 

unpickling masks.
successfully unpickled


<All keys matched successfully>

In [93]:
optim = optimi.SGD(model.parameters(), lr=0.01, momentum=0.94)
loss = nn.CrossEntropyLoss()
logs = trainPrunedModels(model, masks, trainloader, valloader, epochs=10, loss=loss, optim=optim)

Epoch 1: 100%|██████████| 844/844 [00:19<00:00, 43.87batches/s, accuracy=0.646, loss=0.864]


The validation loss is: 0.8534586531050662
The validation accuracy is: 0.5295877659574468


Epoch 2: 100%|██████████| 844/844 [00:19<00:00, 44.02batches/s, accuracy=0.5, loss=0.477]  


The validation loss is: 0.5131179236985267
The validation accuracy is: 0.5938608158142009


Epoch 3: 100%|██████████| 844/844 [00:18<00:00, 44.58batches/s, accuracy=0.708, loss=0.456]


The validation loss is: 0.4038897573630861
The validation accuracy is: 0.6555851063829787


Epoch 4: 100%|██████████| 844/844 [00:19<00:00, 44.36batches/s, accuracy=0.75, loss=0.367] 


The validation loss is: 0.3671668024456247
The validation accuracy is: 0.6798537234042553


Epoch 5: 100%|██████████| 844/844 [00:18<00:00, 44.46batches/s, accuracy=0.625, loss=0.471]


The validation loss is: 0.42395737006309187
The validation accuracy is: 0.706283244680851


Epoch 6: 100%|██████████| 844/844 [00:18<00:00, 44.46batches/s, accuracy=0.854, loss=0.349] 


The validation loss is: 0.3807481036699833
The validation accuracy is: 0.7146498224836715


Epoch 7: 100%|██████████| 844/844 [00:19<00:00, 44.38batches/s, accuracy=0.875, loss=0.478] 


The validation loss is: 0.31993414112862123
The validation accuracy is: 0.7511081562397328


Epoch 8: 100%|██████████| 844/844 [00:18<00:00, 44.51batches/s, accuracy=0.729, loss=0.219] 


The validation loss is: 0.2903837932551161
The validation accuracy is: 0.7524933510638298


Epoch 9: 100%|██████████| 844/844 [00:19<00:00, 44.32batches/s, accuracy=0.708, loss=0.21]  


The validation loss is: 0.3079137896445203
The validation accuracy is: 0.7420766841857991


Epoch 10: 100%|██████████| 844/844 [00:19<00:00, 43.98batches/s, accuracy=0.75, loss=0.237]  


The validation loss is: 0.29596637720440294
The validation accuracy is: 0.7386414009205838


In [94]:
logs['lr'] = 0.01
logs['epochs'] = 10
logs['momentum'] = 0.94
# logs['dampening'] = 0.05

In [95]:
def saveModel(model, filename):
    torch.save(model.state_dict(), filename + ".pth")
    print("saved Successfully")
    
def saveLogs(logs, filename):
    print("pickling masks.")
    with open(filename, 'wb') as fp:
        pickle.dump(logs, fp)
    print("successfully pickled") 
    
saveModel(model, "./prunedModels/96percent")

saved Successfully


In [96]:
saveLogs(logs, 'logs/96percent')

pickling masks.
successfully pickled
