# Post-training pruning: iterative magnitude pruning (IMP)

## When iteration is one, IMP downgrades to one-shot magnitude pruning.

In [1]:
import sys
sys.path.append('..')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
from copy import deepcopy

from models import *
from utils.utils import progress_bar, train, test
from utils.pruner import pruning_model_random, check_sparsity, pruning_model, prune_model_custom, extract_mask, remove_prune

In [2]:
lr = 0.1
epochs = 2 # how many epochs for each training period?
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# 1 if OMP is applied. However, to get a fully retrained OMP model, pruning_times should be set to 2. 
# The model before the last pruning step is the fully trained OMP model given the desired pruning ratio.
pruning_times = 2
pruning_ratio = 0.2
rewinding_epoch = 1

save_dir = 'checkpoint'

In [3]:
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Model
print('==> Building model..')
net = ResNet18()
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

print('######################################## Start Standard Training Iterative Pruning ########################################')

for state in range(pruning_times):

    print('******************************************')
    print('pruning state', state)
    print('******************************************')

    check_sparsity(net)

    for epoch in range(start_epoch, start_epoch+epochs):
        train(epoch, device, net, trainloader, optimizer, criterion)
        if state == 0:
            if (epoch+1) == rewinding_epoch:
                torch.save(net.state_dict(), os.path.join(save_dir, 'epoch_{}_rewind_weight.pt'.format(epoch+1)))
                rewind_init = deepcopy(net.state_dict())
        test(epoch, device, net, testloader, criterion, best_acc)
        scheduler.step()
    
    # model pruning and rewinding
    pruning_model(net, pruning_ratio)
    current_mask = extract_mask(net.state_dict())
    remove_prune(net)

    net.load_state_dict(rewind_init, strict=True)
    prune_model_custom(net, current_mask)
    
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for _ in range(rewinding_epoch):
        scheduler.step()
    
check_sparsity(net)
print("Finished!")

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
######################################## Start Standard Training Iterative Pruning ########################################
******************************************
pruning state 0
******************************************
* remain weight =  100.0 %

Epoch: 0
Saving..

Epoch: 1
Saving..
start unstructured pruning
remove pruning
start unstructured pruning with custom mask
******************************************
pruning state 1
******************************************
* remain weight =  80.00000358447606 %

Epoch: 0




Saving..

Epoch: 1
Saving..
start unstructured pruning
remove pruning
start unstructured pruning with custom mask
* remain weight =  64.00000465981887 %
Finished!
