In [1]:
#Left - Remove the pruning layers as if they never existed and save the new model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
from collections  import namedtuple
import sys
import torch.nn.utils.prune as prune

device = torch.device('cuda')
SAVE_PATH = 'D://models//Pruned_net.pth'

In [3]:
transform = {
    'train':transforms.Compose([
        transforms.RandomResizedCrop(224), 
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229,0.224,0.225])]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}

In [4]:
data_dir = 'D:\\datasets\\ILSVRC2012_img_val - Retrain\\'
dataset = {x:datasets.ImageFolder(os.path.join(data_dir, x), transform[x]) for x in ['train', 'val']}

In [5]:
dataloader = {x:torch.utils.data.DataLoader(dataset[x], batch_size = 256, shuffle = False, num_workers = 6, pin_memory = True)
              for x in ['train', 'val']}

In [6]:
dataset_size = {x:len(dataset[x]) for x in ['train', 'val']}
class_names = dataset['train'].classes

In [7]:
class AlexNet(nn.Module):

    def __init__(self, init_model, num_classes=1000):
        super(AlexNet, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        
        self.load_state_dict(copy.deepcopy(init_model.state_dict()))
       
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [8]:
def check_accuracy(model, phase, record_grad, criterion = None, optimizer = None):
    
    global device
    
    model.to(device)
    model.eval()
#     if record_grad:
#         model.train()
#     else:
#         model.eval()

        
    done = 0
    acc = 0.0
    since = time.time()
    corrects = torch.tensor(0)
    total_loss = 0.0
    corrects = corrects.to(device)
    loss = 100.0
    
    for inputs, labels in dataloader[phase]:

        inputs = inputs.to(device)
        labels = labels.to(device)

        if record_grad:
            with torch.set_grad_enabled(True):
                optimizer.zero_grad()
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                corrects += torch.sum(preds == labels)
                loss = criterion(outputs, labels)
                print(model.features[10].weight.grad)
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * inputs.size(0)
            
            done += len(inputs)
            print('\r{}, {}, {:.2f}%, {:.2f}, {:.2f}'.format(corrects.item(), done, corrects.item() * 100.0 / done, loss.item(), total_loss), end = '')

        else:
            with torch.set_grad_enabled(False):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                corrects += torch.sum(preds == labels)

            done += len(inputs)
            print('\r{}, {}, {:.2f}%'.format(corrects.item(), done, corrects.item() * 100.0 / done), end = '')

    acc = corrects.double() / done
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    time_elapsed = time.time() - since
    print('Total time taken = {} seconds'.format(time_elapsed))

    if record_grad:
        return acc, total_loss
    else:
        return acc


In [9]:
def train_limited(model, criterion, optimizer, num_epochs = 100, do_baseline = True):
    
    global device
    
    print('          ', end = '\r')
    acc = {'train':0.0, 'val':0.0}
    best_acc = 0.0
    
    if do_baseline:
        acc['train'] = check_accuracy(model, phase = 'train', record_grad = False)
        print('.......... Baseline Evaluation Done ..............')
        best_acc = acc['train']
    
    since = time.time()
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_acc, epoch_loss = check_accuracy(model, phase='train', record_grad=True, criterion=criterion, optimizer=optimizer)
        if epoch_acc > best_acc:
            print('Saving')
            best_acc = epoch_acc
            torch.save(model.state_dict(), SAVE_PATH)
        print()
        
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    
    model.load_state_dict(torch.load(SAVE_PATH))
    return model

In [10]:
def train(model, criterion, optimizer, num_epochs = 100, do_baseline = True):
    
    global device
    
    print('          ', end = '\r')
    acc = {'train':0.0, 'val':0.0}
    best_acc = 0.0
    
    if do_baseline:
        acc['val'] = check_accuracy(model, phase = 'val', record_grad = False)
        acc['train'] = check_accuracy(model, phase = 'train', record_grad = False)
        print('.......... Baseline Evaluation Done ..............')
        best_acc = acc['val']
    
    since = time.time()
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'val':
                epoch_acc = check_accuracy(model, phase=phase, record_grad=False, criterion=criterion, optimizer=optimizer)
                if epoch_acc > best_acc:
                    print('Saving')
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), SAVE_PATH)
            else:
                epoch_acc, epoch_loss = check_accuracy(model, phase=phase, record_grad=True, criterion=criterion, optimizer=optimizer)
   
        print()
        
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    
    model.load_state_dict(torch.load(SAVE_PATH))
    return model

In [11]:
alexnet = models.alexnet(pretrained=True)
model = AlexNet(init_model=alexnet)

In [12]:
#check_accuracy(model, 'train', record_grad = False)

19086, 40000, 47.72%
train Acc: 47.7150 %
Total time taken = 94.63574719429016 seconds


tensor(0.4772, device='cuda:0', dtype=torch.float64)

In [13]:
# prune_kwargs = [
#     [model.features[0], 'weight', 0.1],
#     [model.features[3], 'weight', 0.1],
#     [model.features[6], 'weight', 0.1],
#     [model.features[8], 'weight', 0.1],
#     [model.features[10], 'weight', 0.1],
    
#     [model.classifier[1], 'weight', 0.1],
#     [model.classifier[4], 'weight', 0.1],
#     [model.classifier[6], 'weight', 0.1]
# ]

# for kwarg in prune_kwargs:
#     prune.l1_unstructured(kwarg[0], name = kwarg[1], amount = kwarg[2])

parameters_to_prune = (
    (model.features[0], 'weight'),
    (model.features[3], 'weight'),
    (model.features[6], 'weight'),
    (model.features[8], 'weight'),
    (model.features[10], 'weight'),
    (model.classifier[1], 'weight'),
    (model.classifier[4], 'weight'),
    (model.classifier[6], 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.6
)

torch.save(model.state_dict(), SAVE_PATH)

In [14]:
model.load_state_dict(torch.load(SAVE_PATH))
model.to(device)
torch.cuda.empty_cache()

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = 1e-10, momentum = 0.9)
#exp_lr_scheduler = lr_scheduler.StepLR(optimzer_ft, step_size = 7, gamma = 0.1)

model = train_limited(model, criterion, optimizer, do_baseline = False)

18881, 40000, 47.20%
train Acc: 47.2025 %
Total time taken = 95.20125913619995 seconds
.......... Baseline Evaluation Done ..............
Epoch 0/99
----------
None
161, 256, 62.89%, 1.72, 439.89None
332, 512, 64.84%, 1.52, 828.06None
512, 768, 66.67%, 1.37, 1178.78None
685, 1024, 66.89%, 1.41, 1539.06None
836, 1280, 65.31%, 1.98, 2046.29None
942, 1536, 61.33%, 2.50, 2687.15None
1057, 1792, 58.98%, 2.54, 3337.93None
1184, 2048, 57.81%, 2.39, 3949.91None
1312, 2304, 56.94%, 2.25, 4526.05None
1403, 2560, 54.80%, 3.06, 5308.31None
1512, 2816, 53.69%, 2.66, 5989.18None
1652, 3072, 53.78%, 1.81, 6451.60None
1789, 3328, 53.76%, 2.06, 6979.79None
1975, 3584, 55.11%, 1.29, 7310.34None
2169, 3840, 56.48%, 1.21, 7620.66None
2336, 4096, 57.03%, 1.57, 8021.49None
2481, 4352, 57.01%, 2.16, 8573.87None
2617, 4608, 56.79%, 2.20, 9136.01None
2745, 4864, 56.44%, 2.33, 9733.51None
2855, 5120, 55.76%, 2.73, 10431.38None
3021, 5376, 56.19%, 1.53, 10822.43None
3203, 5632, 56.87%, 1.32, 11160.10None
3377, 5

In [None]:
for module in list(model.features):
    if prune.is_pruned(module):
        print(module)
        prune.remove(module, 'weight')

for module in list(model.classifier):
    if prune.is_pruned(module):
        print(module)
        prune.remove(module, 'weight')

In [None]:
check_accuracy(model, 'train', record_grad = False)

In [None]:
torch.save(model.state_dict(), 'D://models//undone_pruned_net.pth')

In [None]:
# model.load_state_dict(torch.load(SAVE_PATH)) 
# model.to(device)
# torch.cuda.empty_cache()
# check_accuracy(model, 'train', record_grad = False)

# for lr in [10000, 1000, 100, 10, 1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]:
#     model.load_state_dict(torch.load(SAVE_PATH)) 
#     model.to(device)
#     torch.cuda.empty_cache()
    
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.SGD(model.parameters(), lr = lr, momentum = 0.9)
    
#     print('--------------------')
#     print('lr = {}'.format(lr))
#     print()
    
#     check_accuracy(model, 'train', record_grad = True, criterion = criterion, optimizer = optimizer)
#     check_accuracy(model, 'train', record_grad = True, criterion = criterion, optimizer = optimizer)
#     check_accuracy(model, 'train', record_grad = True, criterion = criterion, optimizer = optimizer)